?

基于多鑒別器生成對抗網絡的時間序列生成模型

2023-01-09 12:33陸彥輝柳寒李航朱光旭
通信學報 2022年10期
關鍵詞:鑒別器頻域誤差

陸彥輝,柳寒,李航,朱光旭

(1.鄭州大學電氣與信息工程學院,河南 鄭州 450001;2.深圳市大數據研究院,廣東 深圳 518115)

0 引言

近年來,隨著計算能力的提升和5G 網絡的普及,數據生成規模逐步擴大,在生產生活中的作用也日益顯著。越來越多的商業公司和組織機構依賴于大數據分析得到有效的決策[1]。大數據分析中一個重要類別是分析與時間相關的數據,涉及金融、氣象、石油和醫學等多個領域。例如,通過分析金融時間序列來預測股票價格[2];通過分析氣候時間序列來分析植被的變化[3];通過分析石油產量時間序列來預測石油的產量[4];通過分析COVID-19 隨時間變化的確診人數來預測未來的確診人數[5]。

時間序列是按照一定的時間間隔持續記錄一段時間的數據,它們通常包含著豐富且復雜的信息,具備較強的研究和商業價值。然而,這些數據在收集過程中存在著各種各樣的問題,例如,數據往往包含隱私信息,無法進行公開傳播與實驗[6];傳感器數據在收集過程中存在數據缺失[7];數據收集困難導致可用數據集過小,難以滿足模型訓練需求[8]。一種可行的解決方案是通過機器學習方法生成大量與真實數據相似度較高的數據,從而滿足模型訓練、驗證等應用。

現有基于機器學習的生成模型主要包括變分自動編碼器(VAE,variational auto-encoder)[9]和生成對抗網絡(GAN,generative adversarial network)[10]。其中,GAN 的研究得到了廣泛的關注,已有工作提出了多種GAN 模型,可用于生成逼真的圖像和視頻。鑒于GAN 在圖像生成方面的優異性能,開發高質量、多樣化和特殊性的時間序列數據的工作得以進一步展開。

本文采取多鑒別器對時間序列的多種特征進行鑒別,提出了多鑒別器生成對抗網絡(MDGAN,multi-discriminator generative adversarial network)模型。本文主要研究工作如下。

1) 本文提出了一種新型的MDGAN 模型,包含時域鑒別器、頻域鑒別器、時頻域鑒別器和自相關鑒別器,能夠對生成數據進行多角度評估,進而提高生成器的合成數據質量,使合成數據更加符合真實時間序列的分布和特征。

2) 在對所提模型進行訓練時,本文引入了二分類交叉熵模型,優化了原始的GAN 損失函數,使其適配多鑒別器網絡,從而提升了模型訓練效果。

3) 本文采用了不同類型的數據集對模型進行橫向和縱向的對照實驗,驗證了本文所提模型能夠有效提升合成時間序列的質量。

1 相關工作

生成對抗網絡最早由Goodfellow 提出,其核心主要體現了零和博弈思想。在生成對抗網絡中,同時訓練生成器網絡和鑒別器網絡這2 個網絡。整個網絡的損失函數定義為

其中,pdata表示真實數據x的分布,符合隨機分布pz的噪聲z表示生成器的輸入,G(z)表示生成器生成的合成數據,D(·) 表示鑒別器對數據的評價結果,E 表示數學期望。生成器致力于學習真實數據的特征,以此生成符合真實數據分布的合成數據;鑒別器致力于分辨輸入是來源于真實數據還是合成數據。在訓練鑒別器的過程中,希望真實數據x通過鑒別器的結果D(x)更接近真實的評價,合成數據G(z)通過鑒別器的結果D(G(z))更接近虛假的評價。而在訓練生成器的過程中,希望合成數據G(z)通過鑒別器的結果D(G(z))更接近真實的評價。當訓練達到納什平衡時,認為生成器的合成數據的主要特征已經符合真實數據的主要特征。

現有工作以GAN 為基礎進行了不同方面的改進。Radford 等[11]提出的深度卷積生成對抗網絡(DCGAN,deep convolutional generative adversarial network)將卷積神經網絡應用到GAN 中,在網絡架構上改進了原始GAN。Arjovsky 等[12]提出的WGAN(Wasserstein generative adversarial network)采用Wasserstein 距離指導整個模型的訓練,在鑒別器中使用權重剪枝技術。Isola 等[13]提出的基于GAN 的Pix2Pix 算法用于圖像像素間的轉換,利用條件生成對抗網絡(CGAN,conditional generative adversarial network)生成圖像。Zhu 等[14]提出了循環一致性生成對抗網絡(CycleAN,cycle-consistent adversarial network),以Pix2Pix 為基礎,主要應用于非配對的圖片生成和轉換,可以實現圖片的風格轉換。Karras 等[15]提出了可以控制樣式的StyleGAN(style-based generator architecture for generative adversarial network),通過修改樣式的特定尺度來控制圖像的生成?,F有工作已經將GAN 成功應用于圖像、視頻以及自然語言等方向。

循環神經網絡(RNN,recurrent neural network)具有獨特的環狀結構,很適用于處理連續時間序列[16]。然而它缺乏學習長期依賴關系的能力,而這種關系對于根據過去預測未來是至關重要的。RNN 的變體長短期記憶(LSTM,long short term memory)網絡具有長時間記憶信息的能力,進而可以學習序列信息的長期依賴關系[17]。Mogren[18]提出了具有GAN 的連續循環神經網絡(C-RNN-GAN,continuous recurrent neural network with adversarial training)模型,是最早利用RNN 的GAN 生成連續序列數據的例子。該模型的生成器是一個LSTM 網絡,鑒別器是一個雙向的LSTM 網絡,通過時間反向傳播和正則化的小批量隨機梯度下降,訓練生成器和鑒別器的網絡參數。

Esteban 等[19]提出了循環條件生成對抗網絡(RCGAN,recurrent conditional generative adversarial network)模型。它的生成器和鑒別器都采用RNN,和C-RNN-GAN 不同的是,RCGAN 的生成器和鑒別器的輸入需要加入附加條件來控制結果。此模型的損失函數采用二分類交叉熵(BCE,binary cross entropy),能夠描述真實數據與合成數據之間的關系。RCGAN 模型是很多后續工作的模型參照。

Yoon 等[20]提出了一種時間序列生成對抗網絡(TimeGAN,time-series generative adversarial network),并利用了傳統的無監督GAN 訓練方法和更可控的監督學習方法。具體而言,該網絡能夠生成具有時間動態特性的時間序列。TimeGAN 由嵌入網絡、恢復網絡、生成器和鑒別器4 個網絡組件組成。自動編碼網絡(前2 個網絡)與生成對抗網絡(后2 個網絡)聯合訓練,嵌入網絡和恢復網絡負責數據到隱式特征的轉換,生成對抗網絡在此空間內學習數據的潛在有效特征。

TimeGAN 主要用于生成短時間序列,因為長時間序列會大大增加生成建模的維數要求,導致復雜度過高。為了解決這個問題,Ni 等[21]提出一個名為Signature Wasserstein-1的度量并將其作為鑒別器的評價結果,同時提出了一種新的生成器,稱為條件自回歸前饋神經網絡,它抓住了時間序列的自回歸性質,加快了訓練的速度,整個模型被稱為SigWGAN(signature Wasserstein generative adversarial network)。

盡管已有工作能夠實現多種類型時間序列的生成,但是上述模型也存在不足。一是原始GAN面臨梯度消失的問題。在訓練初期,生成器的合成數據與真實數據相差很大,鑒別器可以利用高置信度區分二者,但損失函數無法為生成器提供足夠大的梯度,最終導致梯度消失。二是時間序列的特征提取和利用的問題。時間序列數據的特征有多方面,涉及周期性、相關性和頻域的特征等。單一鑒別器能夠完成對時間序列特征的鑒別,但是不具有針對性。

對于上述2 個代表性問題,本文設計了多鑒別器的模型。多鑒別器針對時間序列的不同特征進行針對性的鑒別,在初期訓練中合成數據不會因為某一項特征不明顯而直接導致梯度消失,同時也有助于提高生成器合成數據的質量。

2 多鑒別器生成對抗網絡模型

本文以GAN 和RNN 為基礎提出了MDGAN的模型。此模型主要由3 個部分組成,分別是數據處理、生成器和多鑒別器。多鑒別器GAN 結構如圖1 所示。在整個模型中,生成器輸出的合成數據為G(ZN),其中ZN為輸入的隨機噪聲。合成數據經過數據處理得到T(G(ZN)),真實時間序列XN經過數據處理得到T(XN)。處理后的數據通過多鑒別器進行真/假判定。最后,通過計算鑒別器的損失函數D loss 和生成器的損失函數G loss 分別更新鑒別器和生成器的網絡參數。

圖1 多鑒別器GAN 結構

下面,分別介紹模型的組成部分、模型訓練中的損失函數和訓練方法。

2.1 數據處理

數據處理的目的是得到數據的不同特征。本文以真實時間序列的處理過程為例,介紹數據處理的流程。數據處理流程如圖2 所示。

圖2 數據處理流程

真實時間序列XN是一段長度為N的序列。序列可以描述為

在數據處理的過程中,時間序列XN通過傅里葉變換得到頻域數據F(XN);通過對時域和頻域數據的處理和拼接得到時頻域數據TF(XN);通過自相關處理得到自相關函數ACF(XN)。處理后的數據按順序組合為T(XN),排序方式為

T(XN)是將3 種數據組合在一起。接下來,對式(3)中的3 個部分分別進行介紹。

2.1.1 傅里葉變換

離散傅里葉變換(DFT,discrete Fourier transform)是信號分析最基本的方法[22]。該方法將時間序列從時間域變換到頻率域,分析時間序列的頻域結構與變化規律。本文對長度為N的時間序列XN做M點的離散傅里葉變換。M的取值是2的整數冪,且大于或等于時間序列的長度N。XN的表達式為

其中,x(n)是時間序列XN中的第n個值,X(k)是傅里葉變換后的值。在模型中使用的方法是快速傅里葉變化(FFT,fast Fourier transform)。

離散傅立葉變換后的數據是一組復數,其中一半數據和另一半數據是共軛關系。本文只取一半數據F(XN)。F(XN)的表達式為

2.1.2 時域與頻域拼接處理

傅里葉變換只反映數據在頻域的特征,為了將時域和頻域的特征聯系在一起,常用短時傅里葉變換方法,其實質是加窗的傅里葉變換。這種方法是一種數據變形處理。但是本文希望從原始數據出發,得到一種同時包含時域數據和頻域數據的形式。所以本文采取時域數據和頻域數據拼接的方法分析特征。

具體的拼接方法是首先對頻域數據取模后得到|F(XN)|。取模是一種對復數進行計算的方法,假設復數z=a+bi,復數模值計算為

F(XN)中的每一個值都是復數,對每一個值取模之后,本文可以得到|F(XN)|的表達式,即

然后,將頻域數據的模值|F(XN)|和時域數據XN拼接的數據看作一組同時包含時域和頻域特征的數據,定義為時頻域數據TF(XN)。時頻域數據TF(XN)的表達式為

2.1.3 自相關函數處理

自相關函數(ACF,autocorrelation function)在信號處理中經常用來分析數據并描述數據的相似性[23]。通過使用自相關函數對時間序列進行處理,進一步對數據在時域上的特征進行分析。本文將自相關函數定義為ACF(XN)。離散序列的自相關函數的表達式為

其中,x(n)表示時間序列XN中的第n個值,m表示時間間隔。

2.2 生成器和鑒別器的網絡結構

生成器和鑒別器的網絡由LSTM 網絡構成。LSTM 網絡是RNN 的變體,一般用于與時間序列相關的任務,它由一系列結構相同的神經元構成,該神經元在每個時間步中重復使用。LSTM 的神經元內部有一個記憶狀態,在處理序列數據時,輸入不僅有序列數據,還有上一個時刻的記憶狀態,并向下一個時刻輸出當前的記憶狀態。因此LSTM 網絡是處理時間序列常用的網絡。

2.2.1 生成器網絡

生成器的網絡結構主要由LSTM層和全連接層構成。生成器在每個時間步的輸入獲取不同的隨機噪聲向量。隨機噪聲向量由標準正態分布采樣得到,并通過LSTM 網絡進行計算。LSTM 網絡的激活函數是tanh 函數。全連接層將LSTM 層的輸出轉換為指定的長度。生成器的網絡結構如圖3所示。

圖3 生成器的網絡結構

LSTM 網絡的層數為2,隱藏層的神經單元個數為64。全連接層采用Linear 函數進行轉換,并將每個時間步的全連接層的輸出組合后得到合成數據。

2.2.2 鑒別器網絡

鑒別器是對合成時間序列和真實時間序列的每個時間步的輸出進行鑒別,最后取均值得到真/假的評價。鑒別器的網絡結構如圖4 所示。

圖4 鑒別器的網絡結構

Data 表示輸入鑒別器網絡的數據,是真實數據或合成數據以及它們的變體。鑒別器的網絡結構和生成器的網絡結構類似。鑒別器的全連接層使用Sigmoid 函數,將最后的輸出轉化為[0,1]區間的值。輸出代表鑒別器對輸入的評價。本文提出的模型包含多個鑒別器,不同的數據需要通過不同的鑒別器。

合成數據和真實數據的處理過程相同,本文以真實數據的鑒別過程為例說明多鑒別器如何對數據進行鑒別。多鑒別器的處理流程如圖5 所示。

圖5 多鑒別器的處理流程

每個鑒別器網絡的輸出y的取值范圍為[0,1],將4 個鑒別器的輸出數值進行平均,定義最終結果大于或等于0.5 的是真實數據(評價為真),小于0.5 的是合成數據(評價為假)。因此,輸出結果可表示為

經過數據處理的數據T(XN)在通過頻域鑒別器、時頻域鑒別器和自相關鑒別器時分別提取出與之相對應的數據。將不同鑒別器的評價結果進行平均得到最終結果。

2.3 模型訓練

MDGAN 模型的訓練分2 個部分介紹,第一部分介紹模型的損失函數,第二部分介紹模型的訓練過程。

2.3.1 損失函數

MDGAN 模型的訓練包括鑒別器和生成器2 個部分的訓練。在訓練中本文使用二分類交叉熵計算損失函數。BCE 的計算式為

鑒別器的目的是分辨出真實數據和合成數據。在訓練中本文使用二分類交叉熵對鑒別器的預測和數據的標簽進行計算。真實數據的標簽為1,合成數據的標簽為0。

越是優秀的鑒別器對真實時間序列的鑒別結果越接近1,對合成時間序列的鑒別結果越接近0。因此在鑒別器訓練時,本文最小化數據通過鑒別器的結果與對應標簽的二分類交叉熵。鑒別器的損失函數為

因為模型有多個鑒別器,需要分別計算結果。將計算結果代入式(12)中,然后利用式(12)對4 種鑒別器的網絡參數進行更新。4 種鑒別器的計算結果分別為

生成器的目的是隨機噪聲通過生成器生成與真實數據類似的合成數據。因此生成器生成的合成數據在通過鑒別器時,希望得到的評價是真實的。越是優秀的生成器生成的合成數據通過鑒別器的預測值越接近1。因此在生成器訓練時,本文最小化合成數據通過鑒別器的結果與真實標簽的二分類交叉熵。生成器的損失函數為

式(12)~式(17)中,Dt代表時域鑒別器,DF代表頻域鑒別器,DTF代表時頻域鑒別器,DACF代表自相關鑒別器,G代表生成器,yD代表鑒別器結果,XN代表真實時間序列,G(ZN)代表合成數據(1 代表真實,0 代表虛假)。

2.3.2 訓練過程

在訓練過程中,本文需要先對數據集進行預處理再進行訓練。

數據集的預處理是先取出所有數據并進行歸一化計算,然后將數據分為多個固定長度的序列進行隨機組合。例如,把10 000 個數據按20 的固定大小分為500 組,然后將這500 組數據進行隨機組合,目的是混合數據并使其類似于獨立同分布。將預處理之后的真實時間序列分布定義為pr,隨機噪聲數據的分布pz是正態分布。

在鑒別器和生成器的訓練過程中,先對鑒別器進行訓練,更新鑒別器參數,同時固定生成器的參數;然后對生成器進行訓練,更新生成器參數,同時固定鑒別器的參數。重復上述過程。訓練中對參數更新的方法采用Adam 優化算法[24]。多鑒別器生成對抗網絡生成樣本算法如算法1 所示。

算法1多鑒別器生成對抗網絡生成樣本算法

輸入批量值m,隨機噪聲z,真實樣本x,學習率γ,鑒別器更新次數nd,Adam 超參β

輸出生成器G,鑒別器D

初始化生成器參數θg,鑒別器參數θd

1) whileθghas not converged do

2) fort=0,1,…,nddo

3) 獲取真實數據 (x(1),…,x(m))~pr

4) 獲取噪聲數據 (z(1),…,z(m))~pz

6)endfor

7)獲取噪聲數據(z(1),…,z(m))~pz

9) end while

10) returnG,D

3 實驗結果分析

本節介紹實驗使用的數據集和評價指標,通過評價指標對實驗結果進行分析。在實驗中,為了更好地評估模型的性能,本文進行了橫向和縱向對比??v向對比中使用MDGAN 與頻域鑒別器GAN、自相關鑒別器GAN、時頻域鑒別器GAN 進行比較。橫向比較中使用3 種具有代表性的時間序列生成模型與MDGAN 進行比較,分別是RCGAN[19]、TimeGAN[20]和SigCWGAN[21]。

3.1 數據集

本文實驗使用的數據集是地磁數據集和牛津大學金融學院股票數據集中的標準普爾500 指數數據集。

地磁數據集共包含11 500 條數據。該數據是由手機自帶的地磁傳感器收集的一段5 min 內隨手機姿態變化的地磁數據。地磁數據集經常用來分析和預測實驗者使用時手機的不同姿態。

標準普爾500 指數數據集是牛津大學金融學院收集的股票數據,包括2000—2021 年的標準普爾500 指數數據集,共有5 515 條數據。每條數據包括每天的開盤價格、收盤價格和價格波動率。股票數據集經常用來分析和預測股票的趨勢。

3.2 性能評估

實驗中采取3 種常用的評估方法,分別是loss函數收斂性、主成分分析法(PCA,principal component analysis)和誤差分析,分別從定性和定量的角度說明MDGAN 的性能。

1) loss 函數收斂性。loss 函數的收斂性主要用于評價模型的訓練速度。

2) 主成分分析法。主成分分析法用于評價合成數據的分布情況,是最常用的線性降維方法。它的目標是通過某種線性投影將高維的數據映射到低維的空間中,并期望在所投影的維度上數據的信息量最大,實現使用較少的數據維度保留較多的原數據點特性。

3) 誤差分析。誤差分析評價合成數據的準確性。本文對合成時間序列和真實時間序列進行誤差分析,并使用均方誤差(MSE,mean square error)、均方根誤差(RMSE,root mean squared error)、平均絕對誤差(MAE,mean absolute error)和平均絕對誤差百分比(MAPE,mean absolute percentage error)這4 種誤差評價指標。

3.3 縱向對比結果

在縱向對比中,本文只使用地磁數據集對模型進行比較??v向比較的模型有MDGAN、頻域鑒別器GAN、時頻域鑒別器GAN 和自相關鑒別器GAN。MDGAN 中包含所有數據處理過程和對應的鑒別器,其他模型只包含一種數據處理過程和對應的鑒別器??v向對比是為了說明多鑒別器GAN 的合成數據比只包含一種鑒別器的GAN 模型的合成數據更加接近真實數據。

因為數據處理方式不同,4 種模型在loss 函數收斂性和主成分分析上的對比意義不是很重要,所以在縱向對比中本文只使用誤差分析對模型合成數據的準確性進行分析。誤差對比如表1 所示。

表1 模型誤差對比

從表1 可以看出,時頻域鑒別器GAN 的誤差大多略優于頻域鑒別器GAN 和自相關鑒別器GAN 的誤差。但是MDGAN 模型的誤差明顯優于另外3 種模型的誤差。所以本文MDGAN 模型生成的合成數據更加準確。

3.4 橫向對比結果

3.4.1 loss 函數收斂性分析

為了對比模型的loss 函數收斂性,本文使用地磁數據集對MDGAN、SigCWGAN、TimeGAN和RCGAN 這4 種模型進行訓練,損失函數的變化如圖6 所示。其中,Sig loss 表示SigCWGAN模型的損失函數。

圖6 訓練過程中損失函數的變化

由圖6 可以看出,TimeGAN 和RCGAN 模型的loss 函數在1 000 次左右還沒有趨于穩定,但是SigCWGAN 和MDGAN 模型的loss 函數在400 次左右已經趨于穩定。這是因為TimeGAN 和RCGAN采用單一鑒別器,在訓練過程中這2 種模型會在生成器和鑒別器之間的博弈花費更多的時間,不如多鑒別器GAN 的訓練效率高。MDGAN 擁有多個鑒別器,在與生成器的博弈過程中會更加準確地對序列進行評價,這樣有利于生成器快速地獲得數據特征。而SigCWGAN 將生成器和鑒別器的損失函數合為一個損失函數,因此會提高訓練的速度。綜上,本文所使用的MDGAN 在模型訓練的收斂速度上要優于TimeGAN 和RCGAN,與SigCWGAN 不相上下。

3.4.2 主成分分析

為了直觀地觀察數據的分布,本文采用了主成分分析法將原始數據和合成數據的特征降維到二維平面,來觀察數據之間的差異。

本文使用2 個數據集進行實驗,對4 種模型進行評價。對比結果分別如圖7 和圖8 所示。合成數據覆蓋部分越大,說明模型越優秀。對比2 個數據集在4 組模型中的實驗可以看出,MDGAN 模型在2 個數據集訓練得到的合成數據分布均優于TimeGAN、SigCWGAN 和RCGAN 的合成數據分布。因為MDGAN 模型采用多鑒別器對合成數據的多個特征進行鑒別,所以合成數據的分布更加接近真實數據的分布。

3.4.3 誤差分析

從圖7 和圖8 中能直觀看到合成數據的分布是接近真實數據數據分布的,但是不能客觀地評價合成數據的好壞,因此本文對2 個數據集的合成數據進行誤差分析,分別如表2 和表3 所示。其中,股票數據集在預處理階段已進行歸一化處理。

表2 地磁數據集不同模型誤差對比

圖7 地磁數據集PCA 可視化結果

圖8 股票數據集PCA 可視化結果

從表2 和表3 可以看出,MDGAN 的誤差略低于 TimeGAN,但是明顯低于 SigCWGAN 和RCGAN。這說明本文所提模型的準確性要高于其他3 種模型。

表3 股票數據集不同模型誤差對比

3.4.4 總體分析

在loss函數收斂性方面,MDGAN與SigCWGAN不相上下,明顯高于TimeGAN 和RCGAN。在主成分分析中,MDGAN 模型合成數據的分布最接近真實數據的分布。在誤差分析中,MDGAN 的誤差略低于TimeGAN,但是明顯低于SigCWGAN 和RCGAN。

從模型的綜合性能比較,本文所提MDGAN 要略優于 SigCWGAN 和 TimeGAN,明顯高于RCGAN。

4 結束語

本文設計了基于生成對抗網絡的多鑒別器時間序列生成模型,該模型采用4 種不同的鑒別器對合成數據進行鑒別,進而更好地識別時間序列的數據特征,使生成器能夠快速合成高質量的數據。實驗表明,對于地磁和股票這2 種不同類型的數據集,所提模型均能夠合成出與真實數據近似度較高的數據,在模型收斂性、合成數據分布以及合成數據誤差3 個方面都保持了良好的性能。

本文所設計的MDGAN 模型能夠為一些需要大量時間序列數據集的用戶提供一個獲取數據的有效手段。盡管本文所提模型只通過2 種數據集進行了實驗驗證,但該模型的設計思路是可以借鑒并拓展的。在面對更加廣泛的時間數據集時,可以采取針對性的特征鑒別,適當調整鑒別器的結構,使其達到復雜度和精度的最優折中。未來可進一步對特征提取的環節進行研究,使生成器輸出的合成數據具有更強的可控性。

猜你喜歡
鑒別器頻域誤差
基于雙鑒別器生成對抗網絡的單目深度估計方法
基于DDR-CycleGAN的紅外圖像數據增強
基于頻域的聲信號計權改進算法
角接觸球軸承接觸角誤差控制
Beidou, le système de navigation par satellite compatible et interopérable
壓力容器制造誤差探究
頻域稀疏毫米波人體安檢成像處理和快速成像稀疏陣列設計
網絡控制系統有限頻域故障檢測和容錯控制
九十億分之一的“生死”誤差
基于改進Radon-Wigner變換的目標和拖曳式誘餌頻域分離
91香蕉高清国产线观看免费-97夜夜澡人人爽人人喊a-99久久久无码国产精品9-国产亚洲日韩欧美综合