超越CBAM,全新注意力機制!GAM:不計成本提高精度 附Pytorch實現

作者丨ChaucerG

編輯丨極市平臺

超越CBAM,全新注意力機制!GAM:不計成本提高精度 附Pytorch實現

為了提高計算機視覺任務的效能,人們研究了各種注意力機制。然而,以往的方法忽略了保留通道和空間方面的資訊以增強跨維度互動的重要性。因此,本文提出了一種透過減少資訊彌散和放大全域性互動表示來提高深度神經網路效能的全域性注意力機制。本文引入了3D-permutation 與多層感知器的通道注意力和卷積空間注意力子模組。在CIFAR-100和ImageNet-1K上對所提出的影象分類機制的評估表明,本文的方法穩定地優於最近的幾個注意力機制,包括ResNet和輕量級的MobileNet。

1 簡介

卷積神經網路已廣泛應用於計算機視覺領域的許多工和應用中。研究人員發現,CNN在提取深度視覺表徵方面表現良好。隨著CNN相關技術的改進,ImageNet資料集的影象分類準確率在過去9年裡從63%提高到了90%。這一成就也歸功於ImageNet資料集的複雜性,這為相關研究提供了難得的機會。由於它覆蓋的真實場景的多樣性和規模,有利於傳統的影象分類、表徵學習、遷移學習等研究。特別是,它也給注意力機制帶來了挑戰。

近年來,注意力機制在多個應用中不斷提高效能,引起了研究興趣。Wang等人使用編碼-解碼器residual attention模組對特徵圖進行細化,以獲得更好的效能。Hu 等人分別使用空間注意力機制和通道注意力機制,獲得了更高的準確率。然而,由於資訊減少和維度分離,這些機制利用了有限的感受野的視覺表徵。在這個過程中,它們失去了全域性空間通道的相互作用。

本文的研究目標是跨越空間通道維度研究注意力機制。提出了一種“全域性”注意力機制,它保留資訊以放大“全域性”跨維度的互動作用。因此,將所提出的方法命名為全域性注意力機制(GAM)。

2 相關工作

注意力機制在影象分類任務中的效能改進已經有很多研究。

SENet在抑制不重要的畫素時,也帶來了效率較低的問題。

CBAM依次進行通道和空間注意力操作,而BAM並行進行。但它們都忽略了通道與空間的相互作用,從而丟失了跨維資訊。

考慮到跨維度互動的重要性,TAM透過利用每一對三維通道、空間寬度和空間高度之間的注意力權重來提高效率。然而,注意力操作每次仍然應用於兩個維度,而不是全部三個維度。

為了放大跨維度的互動作用,本文提出了一種能夠在所有三個維度上捕捉重要特徵的注意力機制。

3 GAM注意力機制

本文的目標是設計一種注意力機制能夠在減少資訊彌散的情況下也能放大全域性維互動特 徵。作者採用序貫的通道-空間注意力機制並重新設計了CBAM子模組。整個過程如圖1 所示, 並在公式1和2。給定輸入特徵對映

, 中間狀態

和輸出

定義為:

其中

分別為通道注意力圖和空間注意力圖;

表示按元素進行乘法操作。

超越CBAM,全新注意力機制!GAM:不計成本提高精度 附Pytorch實現

通道注意力子模組

通道注意子模組使用三維排列來在三個維度上保留資訊。然後,它用一個兩層的MLP(多層感知器)放大跨維通道-空間依賴性。(MLP是一種編碼-解碼器結構,與BAM相同,其壓縮比為r);通道注意子模組如圖2所示:

超越CBAM,全新注意力機制!GAM:不計成本提高精度 附Pytorch實現

空間注意力子模組

在空間注意力子模組中,為了關注空間資訊,使用兩個卷積層進行空間資訊融合。還從通道注意力子模組中使用了與BAM相同的縮減比r。與此同時,由於最大池化操作減少了資訊的使用,產生了消極的影響。這裡刪除了池化操作以進一步保留特性對映。因此,空間注意力模組有時會顯著增加引數的數量。為了防止引數顯著增加,在ResNet50中採用帶Channel Shuffle的Group卷積。無Group卷積的空間注意力子模組如圖3所示:

超越CBAM,全新注意力機制!GAM:不計成本提高精度 附Pytorch實現

Pytorch實現GAM注意力機制

import torch。nn as nn import torch class GAM_Attention(nn。Module): def __init__(self, in_channels, out_channels, rate=4): super(GAM_Attention, self)。__init__() self。channel_attention = nn。Sequential( nn。Linear(in_channels, int(in_channels / rate)), nn。ReLU(inplace=True), nn。Linear(int(in_channels / rate), in_channels) ) self。spatial_attention = nn。Sequential( nn。Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3), nn。BatchNorm2d(int(in_channels / rate)), nn。ReLU(inplace=True), nn。Conv2d(int(in_channels / rate), out_channels, kernel_size=7, padding=3), nn。BatchNorm2d(out_channels) ) def forward(self, x): b, c, h, w = x。shape x_permute = x。permute(0, 2, 3, 1)。view(b, -1, c) x_att_permute = self。channel_attention(x_permute)。view(b, h, w, c) x_channel_att = x_att_permute。permute(0, 3, 1, 2) x = x * x_channel_att x_spatial_att = self。spatial_attention(x)。sigmoid() out = x * x_spatial_att return out if __name__ == ‘__main__’: x = torch。randn(1, 64, 32, 48) b, c, h, w = x。shape net = GAM_Attention(in_channels=c, out_channels=c) y = net(x)

4實驗

4。1 CIFAR-100

超越CBAM,全新注意力機制!GAM:不計成本提高精度 附Pytorch實現

4。2 ImageNet-1K

超越CBAM,全新注意力機制!GAM:不計成本提高精度 附Pytorch實現

4。3 消融實驗

超越CBAM,全新注意力機制!GAM:不計成本提高精度 附Pytorch實現

超越CBAM,全新注意力機制!GAM:不計成本提高精度 附Pytorch實現

5 參考

[1]。Global Attention Mechanism: Retain Information to Enhance Channel-Spatial Interactions