?

結合注意力轉移與特征融合算法的在線知識蒸餾

2023-02-02 09:25梁興柱
湖北理工學院學報 2023年1期
關鍵詞:錯誤率分支注意力

梁興柱,徐 慧,胡 干

(1.安徽理工大學 計算機科學與工程學院,安徽 淮南 232001;2.安徽理工大學環境友好材料與職業健康研究院(蕪湖),安徽 蕪湖 241003)

近年來,深度神經網絡憑借強大的特征學習能力在計算機視覺處理中取得了令人欣喜的成績[1-2]。但是,功能強大的模型往往會伴隨大量的參數,占據較大的內存,不利于模型的部署與應用[2],因而提出了深度學習模型壓縮技術,包括剪枝、二值化、輕量化模型設計[3-4]、知識蒸餾(Knowledge Distillation , KD)[5]等。其中,知識蒸餾又可以劃分為離線知識蒸餾(Offline KD)和在線知識蒸餾(Online KD)。傳統Offline KD[5]是一個兩階段蒸餾方法,必須先預訓練一個功能強、參數多的教師模型,然后再將教師模型學到的知識遷移到性能較弱、參數較少的學生模型上,從而得到一個速度快、能力強的網絡,達到減少參數、提高學生模型性能的目的,但存在訓練時間長、計算成本高、占據內存大等缺點。Zhang等[6]利用深度相互學習(Deep Mutual Learning,DML)使學生模型直接從其他學生模型的預測中學習。Lan等[7]提出采用一個門控單元作為網絡共享低層,在動態建立教師模型的同時訓練多分支網絡。Romberg等[8]進一步優化了FitNet 網絡并率先在知識蒸餾領域提出利用教師與學生間的注意力圖學習來代替特征圖學習。Chen等[9]在OKDDip兩級蒸餾方法的研究基礎上使用分類器多樣化損失函數和特征融合模塊來提高學生模型的多樣性和網絡中注意機制的性能。

現有的大部分方法很難構建一個功能強大的教師角色,且忽略了單個子分支性能的自我提升。因此,本文提出一種結合注意力機制(Attention Mechanism,AM)與特征融合(Feature Fusion Module,FFM)的在線知識蒸餾方法(KD-ATFF),利用特征融合構建強大的教師角色指導模型訓練,同時將深層神經元的注意力轉移到淺層網絡,進一步提升子分支的性能。

1 KD-ATFF

KD-ATFF擁有n個分支在線集成網絡模型,每個分支網絡由M個block組成。在沒有教師模型指導的前提下,每個block在訓練過程中將自己的特征圖轉化為注意力圖,模塊之間相互學習彼此的注意力地圖,增加知識差異性。各個分支由淺層網絡到深層網絡相互學習得到的多樣性知識被保留到最后1個block的特征,然后再將各分支的最后1個block的特征送入特征融合模塊進行信息融合,最后根據模塊內部為各分支分配的權重組成集成教師指導各子模型訓練。KD-ATFF模型結構如圖1所示。

圖1 KD-ATFF模型結構

1.1 CL模塊

(1)

(2)

采用L2范數作為2個模塊進行互學習的損失函數,將2個模塊的注意力圖進行二次互學習得到的值進行相加后取1/2作為一次損失。模塊間的互學習目標函數為:

(3)

1.2 特征融合模塊

KD-ATFF共有n個相同的網絡架構分支,多分支的特性比來自單分支的特性包含的信息要豐富得多。為便于表示,所有的學生模型從1到n進行索引。由于網絡的深層會產生更豐富的語義信息,故將來自多個分支的最后1個block的特征作為特征融合模塊的輸入。這樣可以利用高級的語義信息來豐富特征,所產生的權值能取得較好的效果。加權集成目標ze可表示為:

(4)

式(4)中,f(·)為特征融合模塊中心塊的函數,為每個分支輸出相應的重要性分數;Fa為來自第a個分支的最后一個塊的特征圖;za為來自第a個分支的logits。

以3個分支作為輸入為例,來自每個分支的最后1個block的特征映射Fa將被連接在一起,然后送入中心卷積塊。中心卷積塊是由多個卷積層、批處理歸一化和ReLU激活函數組成,中心塊的最后一層是全連接層,用于融合來自多個分支的語義信息。與其他方法相比,特征融合模塊可以獲得更多的語義信息,能夠有效地提高模塊的性能。最終目標由各輔助分支的logits輸出zi加權和得到。特征融合模塊結構如圖2所示。

圖2 特征融合模塊結構

1.3 蒸餾模型的損失函數

1.3.1傳統標簽損失

(5)

式(5)中,T為溫度參數,取T=3。利用最小化交叉熵訓練,得到標簽學習的損失為:

Llabel=-∑ieilogqi

(6)

式(6)中,ei是標注的標簽分布;qi是最小化預測的類概率。

1.3.2注意力轉移損失函數

(7)

式(7)中,αm和αm+1分別為第m以及第m+1個block輸出的注意力圖。

1.3.3集成教師損失函數

KD-ATFF每個分支不僅從地面真實標簽中學習,還從通過特征融合模塊獲得的加權集成目標中學習。知識轉移是通過將學生模型生成的概率分布q與目標分布z對齊實現的。用KL散度表示其損失函數,第a個學生模型的預測分布為qa(a=1,2,…,n),每個輔助分支學習ze中提取的知識,故所有分支的蒸餾損失為:

(8)

則,整個KD-ATFF的損失函數為:

(9)

Llabel是第α個子模型的傳統知識蒸餾損失;θ,β是調節軟硬標簽比例的超參數。

1.4 KD-ATFF的算法流程

與兩階段傳統蒸餾訓練不同,Online KD中學生網絡和集成教師同時進行訓練。在每個子網絡進行相同的隨機梯度下降,并訓練整個網絡直到收斂,作為標準的單模型增量批處理訓練。批處理貫穿整個訓練過程,在每個batch進行子模型參數的更新和執行訓練。KD-ATFF的算法流程如下。

輸入:訓練數據集D;訓練Epoch數;分支數n

輸出:n個訓練好的模型{θ1,θ2,θ3,…,θn}

2.whilee≤do

3.使用公式(5)計算所有分支的預測{q1,q2,…,qn}

5.計算每個子模型的輸出:(z1,z2,…,zn)

6.通過FFM獲取每個分支的權重

7.使用公式(4)計算目標logits

8.利用公式(8)計算蒸餾損失LKL

10.e=e+1

11.end while

2 實驗分析

2.1 實驗設置與數據集

實驗采用CIFAR-10和CIFAR-100多類別分類基準數據集。CIFAR-10是1個自然圖像數據集,包含從10個對象類中提取的50 000/10 000個訓練/測試樣本(總共60 000個圖像),每個類有6 000個大小為32×32像素的圖像。CIFAR-100與CIFAR-10類似,也包含50 000/10 000個訓練/測試圖像,但覆蓋100個細粒度類,每個類別有600張圖片。實驗在CIFAR-10/100數據集上的batchsize設置為256。

實驗所使用的學生網絡包括:ResNet-32(3.3 M)[2],ResNet-110(0.5 M)[2]和MobileNet(1.7 M)[10]。分支m設置為2;θ,β,T分別為1,1,3;選取隨機梯度下降,SGD為優化器;模型的學習率初始化為0.1,并且每80周期減少為原來的1/10。采用top-1分類錯誤率,將所有模型的訓練結果取平均值。模型訓練和測試的計算成本,使用浮點運算(FLOPs)標準。

2.2 實驗結果分析

將KD-ATFF方法與幾種有代表性的蒸餾方法進行比較,采用不同的骨干網絡分別在CIFAR-10/100數據集上進行實驗。不同骨干網絡在CIFAR-10/100上的top-1錯誤率見表1。

表1 不同骨干網絡在CIFAR-10/100上的top-1錯誤率

由表1可知, KD-ATFF與其他幾種方法相比,top-1錯誤率有明顯降低,能適用于不同網絡且能得到良好的分類效果。具體來說,KD-ATFF模型在CIFAR-10上以ResNet-32或MobileNet為骨干網絡時,比原始Baseline的top-1錯誤率降低了約30%;在CIFAR-100上以ResNet-110為骨干網絡時,top-1錯誤率比DML降低了1.57%,比ONE降低了1.31%;以MobileNet為骨干網絡時,top-1錯誤率比DML降低了1.76%;在CIFAR-10上以ResNet-32為骨干網絡以及在CIFAR-100上以ResNet-110為骨干網絡時,與最新的OKDDip方法相比,top-1錯誤率都能夠與其比肩甚至比其更優。實驗結果表明,KD-ATFF提高了模型的泛化性,訓練出的模型更加高效,對提升模型準確率有很大的貢獻。

2.3 消融實驗

為驗證CL模塊和特征融合模塊的有效性,在CIFAR-100數據集上使用ResNet-110為骨干網絡進行消融研究,將這2個模塊與ONE中的Gate模塊進行比較。消融實驗結果見表2。由表2可知,當只使用CL模塊時的性能已經略超過其他方法。這說明CL模塊在不同維度的注意力轉移能學習到更多的知識。與ONE中的Gate模塊相比,CL的top-1錯誤率降低了0.76%。FFM與CL模塊同時工作時整體模型的改善更為明顯。與獨立的CL模塊相比,top-1錯誤率降低了1.31%。實驗結果表明,CL模塊在整體性能改善中發揮了重要的作用,特征融合模塊可以明顯增強分支間的多樣性。

表2 消融實驗結果

3 結論

結合注意力轉移與特征融合的在線知識蒸餾方法(KD-ATFF)是一種可以在不需預先訓練教師模型的前提下訓練學生模型的改進在線蒸餾模型?;谀K間的差異性,KD-ATFF引入了注意力機制,讓不同維度的模塊互相學習,在各子模型的最后1個block的特征輸出加入特征融合模塊,分配不同權重組成集成教師指導各子模型訓練,以提升整體模型的性能。與其他幾種代表性的在線知識蒸餾方法相比,KD-ATFF的top-1錯誤率明顯降低,驗證了注意力轉移和CL模塊以及特征融合模塊的有效性。

猜你喜歡
錯誤率分支注意力
讓注意力“飛”回來
巧分支與枝
小學生分數計算高錯誤率成因及對策
一類擬齊次多項式中心的極限環分支
“揚眼”APP:讓注意力“變現”
正視錯誤,尋求策略
A Beautiful Way Of Looking At Things
解析小學高段學生英語單詞抄寫作業錯誤原因
降低學生計算錯誤率的有效策略
生成分支q-矩陣的零流出性
91香蕉高清国产线观看免费-97夜夜澡人人爽人人喊a-99久久久无码国产精品9-国产亚洲日韩欧美综合