?

基于融合CNN和Transformer的圖像分類模型

2022-12-18 07:19何明智朱華生李永健唐樹銀孫占鑫
南昌工程學院學報 2022年4期
關鍵詞:分支全局局部

何明智,朱華生,李永健,唐樹銀,孫占鑫

(南昌工程學院 信息工程學院,江西 南昌 330099)

近十年來,卷積神經網絡(CNN)在圖像分類任務中一直扮演著非常重要的角色,它具有十分優秀的特征提取能力。如Krizhevsky[1]等提出由分層卷積構成的AlexNet曾獲得2012年的圖像分類大賽冠軍。而Simonyan[2]等提出的VGGNet將卷積操作推向更深層。Szegedy[3]提出的Inception模型則從通過組合不同大小的卷積核來提升網絡的性能。He[4]等提出帶有殘差連接的ResNet解決了深層卷積神經網絡出現的梯度爆炸和梯度消失的問題。Huang[5]等提出帶有密集連接的DenseNet更能充分利用卷積層間的特征流。盡管CNN網絡擁有出色的局部特征捕獲性能,但在獲取全局特征上能力不足。

近幾年來,發展迅速的視覺注意力機制在一定程度上幫助傳統的CNN網絡獲取全局特征信息。如Hu[6]等提出的SENet利用全局自適應池化層在通道域上捕獲全局信息后壓縮再加權到特征通道上。而Woo[7]等提出的CBAM則同時利用最大池化層和平均池化層分別對通道域和空間域信息重新整合后再加權。雖然利用池化運算能以較小的參數量和計算量獲得全局特征表示,但池化運算不可避免地會忽略一些重要的細節信息。受非局部均值方法[8](non-local mean)的啟發,Wang[9]等提出由自注意力機制構成的non-local模塊,使特征圖的每個位置的響應是全局位置特征的加權和,使得CNN獲得全局特征信息。但基于整體2D特征圖的自注意力運算量大,不利于在空間高分辨率進行的視覺任務。

而近年來,Transformer架構[10]在自然語言處理任務上獲得的成功,讓研究人員將Transformer引入到視覺任務中。Dosovitskiy[11]等在原生的Transformer架構上改進,提出基于視覺任務的Vision Transformer(ViT)。ViT將輸入的圖像劃分成固定大小的特征塊,經過線性變換后得到特征序列,然后對特征序列進行多頭自注意力運算,既能充分獲得長距離的特征依賴,同時也降低了運算量。作為優秀的特征提取骨干網絡,ViT也被廣泛應用于目標檢測[12]和目標跟蹤[13]等任務。但由于ViT直接對特征圖劃分成特征塊序列,導致提取邊緣以及局部特征信息能力減弱。因此在沒有超大規模數據集預訓練下,ViT在圖像分類任務表現較差。針對這個問題,Chen[14]等提出加入卷積算子的友好型Transformer架構Visformer模型,它在較小規模的數據集上表現出色。而D’Ascoli[15]等提出的ConViT模型,則是將CNN的歸納偏置帶到Transformer中,提升Transformer對圖像樣本利用率。Graham[16]等提出的Levit模型在圖像劃分前利用級聯多個小卷積能獲取圖像的局部特征,同時增大卷積步長,對圖像進行下采樣,有效降低模型的參數量。而針對ViT模型復雜的位置編碼,Zhang[17]等提出的ResT模型利用深度可分離卷積對特征塊嵌入相對位置信息。

與以上現有模型不同,本文提出了一種基于融合CNN和Transformer的圖像分類模型FCT(Fusion of CNN and Transformer,簡稱為FCT)。FCT模型由CNN分支和Transformer分支融合構成。在FCT模型中CNN分支不僅在低層次中向Transformer分支補充基礎的局部特征信息,并且在模型的中、高層次中,CNN分支也能向Transformer架構提供不同的局部和全局特征信息,增強模型獲取特征信息的能力,提升圖像分類的準確率。

1 模型結構

1.1 整體模型結構

在基于深度學習的圖像分類領域,局部特征和全局表示一直是許多優秀模型不可缺少的組成部分。CNN模型通過級聯卷積操作分層地收集局部特征,并保留局部線索作為特征圖。Vision Transformer則通過級聯的自注意力模塊以一種軟的方式在壓縮的特征塊之間聚合全局表示。

為了充分利用局部特征和全局表示,本文提出了一個融合網絡結構FCT。FCT模型利用來自卷積分支的局部特征逐步提供到Transformer分支用以豐富局部細節,使得FCT網絡模型獲得局部特征和全局表示。

如圖1所示,FCT模型主要由卷積stem塊、CNN分支、Transformer分支以及全局自適應池化層和全連接層組成。stem塊由大小為7、步長為2、填充為3的卷積和大小為3、步長為2、填充為1的最大池化層構成,它用于提取初始局部特征(例如邊緣和紋理信息),然后將初步處理后的特征圖傳遞給兩個分支。CNN分支與Transformer分支分別由多個卷積模塊和Transformer模塊組成,這種并行結構可以使CNN分支和Transformer可以分別最大限度地保留局部特征和全局表示。而Patch Embedding則作為一個橋梁模塊,用于將完整的特征圖線性映射成特征塊序列,并逐步地把局部特征圖傳遞給Transformer分支,使CNN分支的局部表示特征圖能和Transformer分支的全局特征表示圖相融合。為了使網絡枝干產生層次表示,隨著網絡的深入,Patch Merging在Transformer分支中起到下采樣的作用,它可以減少特征塊序列的數量,使特征塊數量減少到原來的四分之一,從而有效地降低整體網絡的運算量和參數量。最后,將特征圖輸入到自適應平均池化層中,壓縮成1×1序列,然后通過全連接層輸出參數結果。

1.2 CNN分支

如圖1所示,CNN分支采用特征金字塔結構,其中特征圖的分辨率隨著網絡深度的增加而降低,同時通道數在不斷增加。本文將整個分支分為4個stage,每個stage包含兩組卷積,而根據ResNet-18[4]所定義,每個卷積組由兩個大小為3、填充為1的卷積,以及輸入和輸出之間的殘差連接組成。從stage2開始,每個stage的第一個卷積的步長為2,其余為1。在整個CNN分支中,每個stage都擁有兩個卷積組。ViT模型通過一個步驟的Patch Embedding將一副圖像線性映射為特征塊序列,導致局部細節的丟失。而在CNN網絡中,卷積核在有重疊的特征映射上窗口滑動,這樣能保留精細的局部特征。因此,CNN分支能夠連續地為Transformer分支提供局部特征細節。

圖1 FCT模型結構圖

1.3 Transformer分支

1.3.1 Transformer塊

與ViT模型不改變特征序列的數量和通道數不同,本文的Transformer分支通過Patch Merging下采樣構成特征金字塔結構,其中特征序列的數量隨著網絡深度的增加而減少,同時通道數與CNN分支相對應的stage相同,也在不斷增加,用以更好地與CNN分支傳遞的特征信息相融合。本文將整個分支區分為4個stage,每個stage包含不同數量的Transformer塊。每個Transformer塊由多頭自注意力(MHSA)模塊和多層感知機(MLP)模塊(包含向上映射全連接層和向下映射全連接層,以及包含兩層GELU非線性激活層)組成。每一層的多頭自注意力塊和MLP塊中的殘余連接之前都使用層次歸一化[18](LayerNorm,LN)。Transformer模塊可用下式所表示:

(1)

(2)

其中z為輸入的特征序列,l為Transformer模塊的層次。

在整個Transformer分支中,stage1~stage4的MHSA模塊的頭部數量分別為1、2、4、8。而每個stage中Transformer模塊的數量分別為2、2、6、2。

1.3.2 特征塊線性映射Patch Embedding

標準的Transformer架構的輸入為等長的特征序列,以ViT為例,它在Patch Embedding層將一幅三維圖像x∈h×w×3分割成大小為p×p的特征塊。這些特征圖塊被線性映射為二維特征塊,其中x∈n×c,而n=hw/p2。一般地,ViT模型將特征塊尺寸設計為14×14,特征塊數量為16×16。為了減少參數量和運算量,在本文中,FCT模型將特征塊設計為2×2的大小。如圖1所示,第一個Patch Embedding模塊將寬高為96×96,通道數為64的特征圖劃分成寬高分別為48個2×2的特征塊,即48p×48p,通道數仍然為64。

1.3.3 相對位置編碼

位置編碼(Position Embedding)對于利用特征塊序列的順序至關重要。在ViT中將一組可學習的參數添加到輸入標記中來編碼位置關系。設x∈n×c為輸入,θ∈n×c為可學習的位置參數,則位置編碼的嵌入可表示為

(3)

然而,使用這種可學習的相對位置編碼需要固定特征塊的長度,這限制了改變特征塊長度的處理。在本文的模型中,利用深度可分離卷積獲取特征序列的位置編碼關系后,加權到輸入序列中[15],可表示為

(4)

其中f為深度可分離卷積操作。

1.3.4 特征塊融合Patch Merging

隨著網絡的深入,特征塊的融合能減少特征序列的數量。每個特征塊融合層將每組2×2相鄰特征塊連接,并對連接特征應用線性層,這樣可以使特征序列的數目減少到四分之一,輸出的通道數增大到輸入通道數的2倍。通過加入Patch Merging,使整體網絡模型形成層次結構,使CNN分支的每個stage輸出特征通道數與Transformer分支的每個stage輸入特征序列的通道數相等。而CNN分支的每個stage輸出的特征圖尺寸大小為Transformer分支每個stage輸入序列數量的兩倍。CNN分支向Transformer分支傳遞特征圖,特征圖經過Patch Embedding處理后,得到的特征塊序列與上一個stage的Transformer層輸出特征塊序列的大小、數量以及通道數都相等,因此CNN分支傳輸的特征信息能和Transformer分支融合。

1.4 分支融合

由于CNN分支與Transformer分支上處理的特征結構有所差異,因此,CNN分支的特征圖需要先映射成特征序列,再加入相對位置編碼,Transformer分支上的特征塊序列則需要下采樣,減少特征塊序列的數量。融合其兩個分支可由下式表示:

zl=PM(zl-1)+Pos(PE(f(xl-1))),

(5)

(6)

其中z∈n×c為輸入的特征塊序列,x∈h×w×c為輸入的特征圖,l表示層次,PM表示特征塊融合(Patch Merging),T表示Transformer模塊,Pos表示嵌入位置信息(Position Embedding),PE表示特征塊線性映射(Patch Embedding),f為卷積模塊。由圖1可知,FCT模型將每個stage的CNN分支都將該層次的局部特征以及全局表示信息傳遞到Transformer分支,使得模型融合了豐富多樣的局部和全局特征信息。

2 實驗與分析

2.1 實驗環境

本文實驗的設備CPU為Xeon(R)CPU E5-2680 v4,GPU為NVIDIA GeForce RTX 3060。本文使用的Python版本為3.7.4,Pytorch版本為1.9.0。

2.2 實驗數據集

本文使用Oxford Flowers-102[19]和Caltech-101[20]作為實驗數據集。Oxford Flowers-102為英國常見的102個花卉類別的圖像數據集,每個類包含40到258幅圖像,每幅圖像具有較大的比例、姿勢和光線變化,一共包含8 189幅圖像。Caltech-101由101個類別的物體圖片組成,每個圖像都使用單個對象進行標記,每個類包含大約40到800幅圖像,圖像大小不一,總共8 677幅圖像。以上兩個數據集均按照6∶2∶2的比例隨機劃分成訓練集、驗證集和測試集。本文實驗的數據增強策略僅使用隨機剪裁和隨機水平翻轉。隨機剪裁是在數據訓練時將輸入的圖像數據首先按不同的大小和寬高比進行隨機裁剪,然后縮放所裁剪得到的圖像為384×384分辨率。隨機水平翻轉是在隨機剪裁操作后,以0.5的概率隨機水平翻轉。本實驗僅使用上述數據集所包含的圖像,不使用額外的圖像進行訓練。

2.3 訓練參數

本文實驗模型訓練的優化器為AdamW[21],學習率為0.000 1,權重衰減率為0.01,迭代次數為110。為了能加快模型收斂,本文使用學習率余弦衰退周期地對學習率進行動態調整,設置迭代20次為一個周期。

2.4 熱力圖可視化

本文在一幅有多朵花的圖像上分別對FCT、ResNet-18、ViT-base使用Grad-CAM[22]計算得到3個不同的模型的注意力熱力圖,如圖2所示。ResNet-18能精確地識別出圖像里面位于中心位置的花朵,但是沒有識別出另外的花朵;ViT-base模型雖然都能感受到所有花朵的位置,但是無法獲得更精確的細節信息。與ResNet-18和ViT-base模型識別的結果相比,FCT模型既能感受到所有花朵的位置,又能獲取花朵的關鍵局部細節。因此FCT可以通過Transformer分支的自注意力模塊獲得特征全局表示,也能從CNN分支的卷積模塊獲得局部細節的信息,令局部信息和全局信息有效融合。

圖2 熱力圖可視化

2.5 對比實驗

在Oxford Flowers-102和Caltech-101上,除了測試FCT模型以外,還測試了傳統的CNN模型(ResNet-18、ResNet-50)、原始的Transformer模型(ViT-base)和將Transformer融合卷積的模型。由表1可知,在沒有大規模數據的預訓練下,原始的Transformer模型ViT-base表現較差。其他文獻基于Transformer架構融合卷積的方法雖然能改善ViT的準確率,但是分類的準確率仍然無法達到傳統CNN模型效果。在Oxford Flowers-102上測試本文提出的FCT模型分類準確率比ResNet-18高5.84%,比ResNet-50高2.09%,比Visformer-tiny高9.64%。在Caltech-101數據集上,FCT模型分類準確率比傳統CNN模型高約2%,比其他Transformer架構模型優勢明顯。本文提出的融合CNN與Transformer的圖像分類模型FCT,充分地利用了CNN的低層次局部特征以及高層次的全局特征,使網絡模型擁有豐富的特征信息,提高模型的分類準確率。

表1 不同模型的測試集準確率

2.6 消融實驗

圖3為stem卷積層和CNN分支、Transformer分支分別融合后測試結果的可視化圖。由圖3可知,CNN+stem分支(ResNet-18)對原圖中三個花朵的部分,更集中關注能明顯識別為花蕊的區域,其感受區域比較局部。而Transformer+stem分支對原圖中三個花蕊的位置都能感受到,但是感受區域過大,超出了花蕊的位置。FCT模型融合了CNN的局部感受和Transformer的全局感受,能將三個花蕊的位置識別出來。

圖3 模型各分支識別效果可視化

本在Transformer分支上不同層次上的分類準確率,結果如表2所示。由表2可知,隨著從低層到高層卷積分支的融合,模型分類準確率在不斷提高。實驗結果說明,CNN分支不僅在低層次傳遞局部細節特征能提升模型的分類性能,而且在高層次的全局特征表示上也對提升模型的分類效果發揮了重要的作用。

表2 逐步融合不同層次卷積的分類準確率

3 結束語

針對傳統的CNN模型擁有出色的局部特征提取能力,但捕獲全局表示能力較弱,而視覺Transformer模型可以捕獲特征全局表示,但容易忽略局部細節的不足等問題,本文提出了一種基于融合CNN和Transformer的圖像分類模型FCT。FCT利用CNN分支的卷積算子來提取圖像的局部特征,利用Transformer分支的多頭自注意力機制來捕獲全局表示。在Oxford Flower-102和Caltech-101數據集上驗證,FCT模型的圖像分類準確率明顯優于傳統的CNN模型和Vision Transformer模型。下一步將探索融合模型中,Transformer分支向CNN分支傳遞全局特征信息的結構設計,使Transformer分支以及CNN分支同時擁有優秀的獲取局部特征和全局表示的能力,進一步提高模型的分類準確率。

猜你喜歡
分支全局局部
基于改進空間通道信息的全局煙霧注意網絡
領導者的全局觀
一類離散時間反饋控制系統Hopf分支研究
爨體蘭亭集序(局部)
軟件多分支開發代碼漏合問題及解決途徑①
凡·高《夜晚露天咖啡座》局部[荷蘭]
巧分支與枝
二分搜索算法在全局頻繁項目集求解中的應用
落子山東,意在全局
丁學軍作品
91香蕉高清国产线观看免费-97夜夜澡人人爽人人喊a-99久久久无码国产精品9-国产亚洲日韩欧美综合