?

一種建立在GPT-2模型上的數據增強方法

2024-04-09 01:42張小川陳盼盼邢欣來楊昌萌滕達
智能系統學報 2024年1期
關鍵詞:語義分類樣本

張小川,陳盼盼,邢欣來,楊昌萌,滕達

(重慶理工大學 兩江人工智能學院, 重慶 401135)

句子分類[1](sentence classification,SC)是最基本和常見的自然語言處理(natural language process,NLP)任務之一,廣泛應用于NLP的很多子領域,如意圖識別、情感分析、問題分類等。當給定一個句子作為輸入時,其任務是將其分配給一個預定義標簽。深度神經網絡往往需要大規模的高質量標記的訓練數據來實現高性能,然而在特定領域,由于人工標注數據集代價昂貴,常常只有少量樣本可供使用。本文研究在數據匱乏情況下的句子分類任務準確率較低的問題,訓練數據的不足使得句子分類任務模型無法得到有效的訓練,從而導致泛化能力差。為解決這一問題,數據增強是一種有效的方法。

通常,數據生成的語義一致性和多樣性對目標任務至關重要[2],語義保留即前后語義保持一致是數據增強最基本的要求,訓練樣本的豐富表達能使神經網絡更好地學習權重。一些學者的研究工作已經開始注重數據的多樣性和質量。如在計算機視覺中,文獻[3]使用代理網絡來學習如何增強多樣性。孫曉等[4]利用生成對抗網絡生成同一個人的不同面部表情實現數據增強。NLP中的一些研究[5]對原句進行隨機替換、隨機交換、插入和刪除操作實現增強數據的多樣性,為了避免簡單數據增強方法(easy data augmentation,EDA)方法引入過多噪聲,一種更簡單的數據增強方法(an easier data augmentation, AEDA)[6]將隨機插入token改為隨機插入標點符號,一定程度上緩解了噪聲引起的語義偏差問題,然而隨機插入標點符號可能會不恰當地斷句,語義保留和多樣性仍無法同時有效控制。隨著大規模預訓練語言模型的問世,一些研究將其應用于數據增強,Anaby等[7]提出基于語言模型的數據增強方法(language-model-based data augmentation, LAMBADA),采用訓練數據微調GPT-2模型[8],在訓練過程中將相應的標簽拼接到每個樣本,以便為該類生成新數據,在句子分類方面取得了顯著的改進。然而,該方法采用top-k和top-p采樣的方式增加多樣性,這種方式很有可能會導致累計誤差的產生,使得生成句子質量低下。

從本質上講,語義一致性和多樣性的目標其實是相互沖突的,即生成多樣性高的樣本更可能導致語義發生變化,因此,需要同時考慮多樣性與語義一致性,對生成數據進行控制,得到較為平衡的數據。本文提出一種引入懲罰項的數據增強方法(punishing generative pre-trained transformer for data augmentation, PunishGPT-DA),用于生成增強數據來改進句子分類任務。此方法的數據增強過程建立在預訓練語言模型GPT-2基礎上,通過設計懲罰項、超參數,使用雙向編碼器表征模型(bidirectional encoder representations from transformers,BERT)[9]作為過濾器完成數據增強。實驗結果表明了該方法的有效性。

1 數據增強相關工作

從增強數據的多樣性來看,數據增強方法可以大致分為基于復述的方法、基于噪聲的方法和基于采樣的方法3類。

基于復述的方法包括在詞匯、短語、句子層面的重寫。Zhang等[10]首先利用詞庫(a electronic lexical database, WordNet)替換句子中的同義詞應用于數據增強;條件BERT(conditional bert, CBERT)[11]掩蓋句子的部分字符,由BERT生成替換詞;Jiao等[12]使用數據增強來獲得特定任務的蒸餾訓練數據,利用BERT將單詞標記為多個單詞片段,并形成候選集;回譯以生成的方式重寫整個句子,被應用于低資源句子分類[13],使用不同的二級語言提高了分類精度,Hou等[14]通過L層變換器對串聯的多個輸入話語進行編碼,利用重復感知注意和面向多樣性的正則化生成更多樣化的句子。Kober等[15]使用對抗生成網絡(generative adversarial network, GAN)生成與原始數據非常相似的樣本。

基于噪聲的方法添加微弱噪聲,使其適當偏離原始句子。EDA[5]通過隨機插入、刪除、替換、交換操作得到增強數據。Peng等[16]通過刪除對話語句中的槽值來獲得更多的組合;Sahin等[17]通過依賴樹變形對句子進行旋轉。Sun等[18]將混合技術應用到基于Transformer的預訓練模型中進行數據增強(Mixup-Transformer),將Mixup與基于Transformer的預訓練結構相結合,進行數據增強;Feng等[19]在提示部分隨機刪除、交換和插入文本字符,用于微調文本生成器;Andreas[20]提出了一種簡單的數據增強規則,通過采用出現在一個類似環境中的其他片段替換真實的訓練樣本的某個片段,來合成新的樣本。Guo等[21]提出一種序列到序列模型的混合方法(sequence-level mixed sample data augmentation,SeqMix),通過組合訓練集中的輸入輸出序列來創建新的合成樣本。丁家杰等[22]通過對原始數據集中的噪聲進行處理擴充數據集,在問答任務上實現了良好效果。

基于采樣的方法掌握數據分布,并在其中采樣新的樣本。大型語言模型(large language models, LLMs)的出現為生成類似于人類標注的文本樣本創造了新的條件。LLMs的參數空間允許它們存儲大量知識,大規模預訓練使得LLMs能夠編碼用于文本生成的豐富知識。如生成式預訓練語言模型(generative pre-trained transformer, GPT)系列,GPT~GPT-3[8,23-24]采用預訓練+微調的方式,其中預訓練階段通過大規模的無標注數據對模型進行訓練,使其學習到通用的語言表示和語義理解能力,微調階段利用有標注數據進行監督學習,使模型能夠適應特定的任務要求,提高性能和準確度。GPT系列目前已經發展到4.0, 聊天生成預訓練轉換器(chat generative pre-trained transformer, ChatGPT)遵循指導生成預訓練轉換器(instruct generative pre-trained transformer,InstructGPT)[25]的訓練方式,利用帶有人類反饋的強化學習(reinforcement learning from human feedback, RLHF),使其在對話領域能夠對輸入產生更豐富的響應。這些最先進的模型也被廣泛地用來進行數據增強,Abonizio等[26]通過連接樣本中的3個隨機token作為GPT-2模型生成階段的前綴生成樣本。Kumar等[27]研究了不同類型的基于Transformer的預訓練語言模型,表明將類標簽處理到文本序列為微調預訓練模型進行數據增強提供了一種簡單有效的方法;Bayer等[28]設計了一種基于GPT-2的方法,通過設計不同的前綴分別處理短文本和長文本的生成,在短文本任務和長文本任務上都取得了很好的改進。類似的,Claveau等[29]使用特定于類的數據微調GPT-2模型,并從原始文本中輸入一個隨機單詞進行生成。然后應用分類器對生成的數據樣本進行過濾。Liu[30]凍結GPT-2模型softmax之前的層,采用強化學習對softmax之后的層進行微調。隨著ChatGPT的問世,Dai等[31]提出了ChatAug,利用ChatGPT為文本生成增強數據,獲得了顯著提升。

引入噪聲的方法可以有效提升數據的多樣性,利用預訓練語言模型的數據增強方法可以更好地學習到語言規律和語義信息,因此,基于上述工作,本文提出懲罰生成式預訓練語言模型的數據增強方法(punishing generative pre-trained transformer for data augmentation, PunishGPT-DA),通過設計損失函數微調預訓練語言模型GPT-2,有效保證增強數據的質量。

2 PunishGPT-DA

2.1 方法概述

句子分類是一種基于句子數據進行分類的任務,屬于監督學習問題的一個實例。給定訓練集Dtrain=,包含N個訓練樣本,其中xi是由{xi1,xi2,···,xip}組成的文本序列,包含p個字符,li∈{1, 2, ···,q}表示在含有q個標簽的集合中,樣本xi對應的標簽。xi∈X,X代表整個樣本空間,假設對于所有N,存在函數f,使li=f(xi),監督學習的目標是在僅給定數據集Dtrain的情況下在整個X上近似f,從Dtrain的域推廣到整個X,即在Dtrain上訓練分類算法F,使其能夠近似f,然而如果Dtrain非常小,將顯著地影響算法F的性能。數據增強試圖通過合成額外的訓練數據來解決這個問題,給定訓練集Dtrain和算法F,本文的目標是生成Daug=,Daug=Dtrain∪Dfilter,其中Dfilter是方法每次迭代后生成的數據,Daug是最終數據集,包含T個樣本,yj是由{yj1,yj2,···,yjm}組成的文本序列,包含m個字符,對應標簽為lj。

為此,本文提出了一種面向句子分類的數據增強方法PunishGPT-DA。PunishGPT-DA由生成器Gθ和過濾器F2個模塊組成。圖1說明了本方法的步驟:1)通過改進的損失函數微調生成器的語言模型,訓練生成器學習在原始句子的基礎上合成新樣本,得到參數被微調之后的生成器Gθ。2)對Dtrain進行處理作為Gθ的輸入生成數據Dsyn,Dsyn相較于原損失函數訓練出的生成器生成的數據擁有更高的多樣性,但也不可避免地引入了噪聲。3)針對此問題,采用原始數據Dtrain微調過濾器F,將每次迭代生成的樣本Dsyn由F過濾,丟棄低質量的樣本,得到過濾后的增強樣本Dfilter,Dfilter并入原始數據集中作為新的Dtrain進行下一次迭代,經過一定次數的迭代后得到最終的數據集Daug。

圖1 PunishGPT-DA數據增強過程Fig.1 PunishGPT-DA data augmentation process

2.2 生成器

PunishGPT-DA采用預訓練語言模型GPT-2生成數據,GPT-2是一個在海量數據集上訓練的語言模型,采用“預訓練+微調”的二段式訓練策略,它利用龐大的語料庫進行預訓練,語料庫被處理成由token組成的長序列,由U=w1,w2,···,wj,···,wT表示,生成模型采用無監督自回歸訓練的方式,以最大化生成目標序列的概率為目標,根據極大似然估計,可以最大化目標序列U出現的概率,即最大化P(U),根據條件概率的鏈式法則,可以將生成目標序列的概率表示為條件概率的乘積:

將式(1)取對數并加上負號,得到負對數似然損失函數為

在數據增強任務中,同預訓練一致,以句子自身指導模型的微調,即以最大化生成目標序列的概率為目標,因此,以負對數似然函數作為損失函數的生成模型鼓勵生成與原數據相似的句子,使生成的文本趨于重復和“枯燥”,當以此為目標訓練得非常好時,甚至會生成與輸入句子完全一致的樣本數據。

為了關注生成數據的多樣性,本文引入懲罰項來中和現有的損失函數,同時為了平衡多樣性與語義一致性,引入超參數α,改進后的損失函數為

式(3)是一種加權損失函數,由Jθ和exp(-Jθ)2部分組成。其中Jθ,即式(2)是負對數似然損失,用于衡量生成的序列和目標序列之間的差距;exp(-Jθ)將其視為懲罰項,用于懲罰過度相似的生成結果,這意味著,如果生成器產生與目標序列中過于相似的token,它將受到懲罰。本文擬通過添加exp(-Jθ),使模型會在給定上下文條件下,根據語言的語法和語義規則,更加關注可能性較小但仍然有一定意義與合理性的輸出。這些輸出可能是預測概率較小但仍然合理的單詞、短語、句子結構等,在某些情況下可能會提供更有趣、更具創造性的文本。α是一個用于控制Jθ和exp(-Jθ)2部分在損失函數中重要程度的超參數,當α較小時,exp(-Jθ)的影響更大,從而鼓勵生成多樣性更高的樣本。相反,當α較大時,Jθ的影響更大,從而鼓勵生成語義一致性更高的樣本。因此,式(3)可以看作在保證生成序列準確的基礎上,通過懲罰過度自信的生成結果來鼓勵生成更多的多樣性,通過調整α的值,可以在一致性和多樣性之間進行平衡,獲得高質量的生成結果。

此外,在預測階段,通常采用序列的前i個字符作為前綴提示后續詞語的生成,然而,Dtrain中存在多個序列前i個字符相同,以相同的前綴作為提示會導致原本不同標簽的2個句子對應的增強樣本可能相同,使得增強樣本語義標簽不明。因此,本文為每條訓練數據添加了數字序號作為該數據的唯一標志,數字序號隨訓練數據一起參與訓練。在預測階段,數字序號與前i個字符一起作為前綴,確保了前綴的唯一性,并為生成器提供了額外的上下文,形式為(〈SOS〉,w1,w2,···,wi),其中〈SOS〉是數字序號,(w1,w2,···,wi)是樣本的前i個字符。這種操作確保了增強樣本彼此不同,但仍然基于實際數據。

2.3 過濾器

使用增強樣本的一個障礙是它可能引入的噪聲和誤差。雖然在微調生成器時同時考慮了語義保留和豐富表達,避免了模型過度生成低頻詞,但自然語言具有復雜性,有可能微小的改動便會影響句子的語義,導致增強數據集中的低質量樣本對下游任務模型的性能產生影響。為此,如圖1所示,本文使用基于BERT的過濾器F對其進行過濾選擇,過濾器F包括BERT層、線性層、ReLU激活函數層。輸入數據首先經過BERT層獲取特征表示,其次通過Dropout技術進行正則化處理,以減少過擬合風險,然后將Dropout層的輸出輸入到一個具有786個輸入特征和類別數量輸出特征的線性變換層,將特征表示映射到分類標簽的空間,最后經過ReLU激活函數得到最終的分類結果。對于生成的樣本 (y,l),驗證是否F(y)=l,若分類正確則保留,不正確舍棄。因此,每一次完整的迭代后會得到增強數據集Dfilter,Dfilter并入原始集作為新的訓練集。

3 實驗結果與分析

3.1 數據集

本文共使用了3個公開的句子分類數據集,分別是由法國公司SNIPS在人機交互過程中收集的數據集SNIPS,包含7個意圖類別共14 484條數據。由文本檢索會議(text retrieval conference, TERC)標注的細粒度問題分類數據集TREC,包含6種問題類型共5 952條數據。由斯坦福大學自然語言處理組標注的情感分析數據集(stanford sentiment treebank v2, SST-2), SST-2屬于電影評論情感分類的數據集,用2個標簽(positive和negative)標注,共8 741條數據。

3.2 實驗設置

根據先前工作[25]模擬用于句子分類少樣本場景的設置,本文針對每個任務的訓練集進行子采樣,每個類隨機選擇10個樣本,每個數據增強模型均對其進行16倍擴充。為避免數據集的隨機性帶來誤差,本文一個任務下的對比實驗均采用相同的子數據集。為更好地測試模型的性能,本文的驗證集和測試集采用完整的數據集。

在微調GPT-2階段,設置批量大小為2,迭代次數為100,學習率設定為1×10-5,樣本最大長度為20,超過則截斷;生成數據時每條句子的提示為“i w1w2”。BERT在大量數據上進行預訓練,并在幾個句子分類任務上表現出最先進的性能。因此,本文使用BERT模型構建過濾器及句子分類器,本文使用“BERT-Base-Uncased” 模型,該模型有12層,768個隱藏狀態和12個頭。PunishGPT-DA使用BERT模型第1個特殊字符([CLS])的輸出作為句子的特征表示,在傳入下一層進行分類之前,以0.1的dropout設置應用于句子表示。訓練過程采用自適應矩估計算法(adaptive moment estimation,Adam)進行優化,學習率設置為4×10-5,本文對模型進行100個epoch的訓練,并在驗證集上選擇表現最好的模型進行評估。

所有的實驗均在Intel Core i5-9 500 3.00 GHz處理器,GeForce RTX 2028 SUPER顯卡,Ubuntu 20.04.4 LTS,python 3.8.0下進行。

本文實驗將與以下模型進行對比:

1) GPT-2[7]:為驗證本文提出損失函數的有效性,本文以GPT-2作為基準模型,該模型以式(1)為損失函數,其余條件與PunishGPT-DA保持一致。

2) EDA[4]:以詞替換、交換、插入和刪除為基礎的數據增強方法。

3) AEDA[5]:在句子中隨機插入標點符號實現數據增強。

4) GPTcontext[25]:采用文獻[6]中的方式,將標簽與序列連接起來構造訓練集:y1SEPx1EOSy2, ···,ynSEPxnEOS。在此基礎上以yiSEPw1, ···,wk作為生成階段的提示,生成增強數據。

3.3 實驗結果與分析

本文對比了在意圖識別、問題分類及情感分析任務少樣本情景下的數據增強策略,表1總結了多種數據增強方法下同一模型在不同數據集中的分類準確率。

表1 不同增強策略下的模型準確率Table 1 Model accuracy under different augmentation strategies%

如表1所示,與基線模型GPT-2相比,本文提出的數據增強方法在3個數據集上的準確率相對提升了1.1%、4.9%和8.7%,這說明本文提出的損失函數能有效提升增強數據的質量;相較于EDA、AEDA和GPTcontex方法,本文提出的數據增強方法在3個數據集上的準確率均有提升,表明了本文增強方法的普遍性。

本文對比了不同超參數α設置下PunishGPTDA的性能,采用SNIPS 的子采樣后的數據集,每個類別包含10個樣本,對其進行16倍擴充。如圖2所示,α=0.3之前模型準確率較低,這是因為在超參數控制下增強數據多樣性較強,為數據集引入了過多的噪聲;隨著α增大,曲線逐漸上升,直到α=0.45時下游任務模型準確率達到最高,此時生成模型能夠很好地控制數據多樣性和一致性之間的平衡,使模型準確率達到最好的效果;隨著α繼續增大,一致性占據優勢,使得生成數據相較于原數據只有微小的改動,致使模型準確率下降,趨于平緩。這表明,本文提出的損失函數能夠同時控制語義和多樣化的表達,有效平衡數據的一致性和多樣性。

圖2 不同超參數下模型準確率Fig.2 Model accuracy under different hyperparameters

本文研究了過濾機制對PunishGPT-DA性能的影響,分別在3個子采樣后的數據集上進行了消融實驗。實驗結果如表2所示,刪除了過濾機制后,模型準確率均有下降。這表明過濾器對整個增強過程至關重要。

表2 過濾機制對PunishGPT-DA的影響Table 2 Influence of filtering mechanism on PunishGPT-DA%

此外,本文還研究了在不同數據集大小情況下PunishGPT-DA對下游任務模型性能的影響。表3為模型在SNIPS 數據集上進行實驗的結果,每種意圖類別分別取為5、10、20、50、100條數據作為訓練樣本,構成少樣本數據集,并進行16倍擴充。如表3所示,隨著訓練數據的增多,本文的數據增強方法對下游任務模型性能的提升作用越來越弱。這表明在少樣本情境下,本文所提出的數據增強方法可以有效提升句子分類任務模型性能,當訓練數據較為充足時,已經能為下游任務模型提供較為豐富的信息,數據增強帶來的效益也就隨之減弱。

表3 PunishGPT-DA在不同數據集大小下的準確率Table 3 Accuracy of PunishGPT-DA under different dataset sizes%

為了更加明確損失函數的作用機制,本文分別對采用2種損失函數生成的數據進行了探索,如表4所示,本文分別摘取了部分數據。通過觀察損失函數式(3)生成的數據及過濾后的數據可以發現,數據較原始數據有較大的多樣性,但大體上符合標簽語義;采用損失函數式(2)生成的數據較原始數據只有個別單詞的變化,多樣性引入不足。由此可以發現本文提出損失函數的有效性。

表4 生成數據示例Table 4 Generate data samples

4 結束語

針對少樣本句子分類任務中訓練數據不足的問題,本文提出一種平衡語義一致性和多樣性的數據增強方法PunishGPT-DA,與當前主流方法相同,此方法建立在大規模的預訓練語言模型的基礎上,同時又區別于當前主流方法修改提示指導生成模型生成階段的做法,本文提出的方法從訓練角度指導模型生成數據。實驗結果表明,在小樣本情景下,本文方法可以更有效地保證數據質量,有效提高句子分類模型的分類準確率。盡管本文解決了增強樣本質量不高的問題,然而通過損失函數控制數據的生成,可能會導致語法不可控地變化,不符合人類正常的閱讀習慣,因此,在句子結構多樣性方面還有一定的提升空間。下一步將探索句子結構方面的改進,使其更加自然流暢。

猜你喜歡
語義分類樣本
分類算一算
用樣本估計總體復習點撥
語言與語義
分類討論求坐標
推動醫改的“直銷樣本”
數據分析中的分類討論
教你一招:數的分類
隨機微分方程的樣本Lyapunov二次型估計
“上”與“下”語義的不對稱性及其認知闡釋
村企共贏的樣本
91香蕉高清国产线观看免费-97夜夜澡人人爽人人喊a-99久久久无码国产精品9-国产亚洲日韩欧美综合