CBAM: Convolutional Block Attention Module
Official code: https://github.com/Jongchan/attention-module, the official PyTorch implementation of "BAM: Bottleneck Attention Module" (BMVC 2018) and "CBAM: Convolutional Block Attention Module" (ECCV 2018).
Channel attention module
For an input feature map $F$, average-pooling and max-pooling are applied separately to obtain $F^c_{avg}$ and $F^c_{max}$. Both descriptors are then fed to the same network: an MLP with one hidden layer, i.e. two fully connected layers. To reduce the extra parameter overhead, the hidden size is set to $C/r$, where $r$ is the reduction ratio, and the second FC layer restores the dimension back to $C$. The two outputs are merged by element-wise summation, and a sigmoid activation produces the channel-branch output $M_c$. It is computed as:

$$M_c(F) = \sigma\big(\mathrm{MLP}(\mathrm{AvgPool}(F)) + \mathrm{MLP}(\mathrm{MaxPool}(F))\big) = \sigma\big(W_1(W_0(F^c_{avg})) + W_1(W_0(F^c_{max}))\big)$$

where $\sigma$ denotes the sigmoid function. Note that $W_0$ and $W_1$ are shared between the two sub-branches, and $W_0$ is followed by a ReLU activation.
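A minimal sketch of this channel branch (class and variable names here are illustrative, not from the paper; the full official implementation is given further below):

import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    """Sketch of channel attention: shared two-layer MLP over avg- and max-pooled descriptors."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        # Shared MLP: W0 (C -> C/r, followed by ReLU) and W1 (C/r -> C)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels),
        )

    def forward(self, x):                                   # x: (B, C, H, W)
        avg = x.mean(dim=(2, 3))                            # F_avg^c: (B, C)
        mx = x.amax(dim=(2, 3))                             # F_max^c: (B, C)
        att = torch.sigmoid(self.mlp(avg) + self.mlp(mx))   # M_c(F): (B, C)
        return x * att[:, :, None, None]                    # broadcast over H and W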
Spatial attention module
For an input feature map, average-pooling and max-pooling are applied along the channel axis to obtain $F^s_{avg}$ and $F^s_{max}$; the two maps are concatenated along the channel dimension, passed through a 7×7 convolution, and a sigmoid produces the spatial-branch output $M_s$. It is computed as:

$$M_s(F) = \sigma\big(f^{7\times 7}([\mathrm{AvgPool}(F); \mathrm{MaxPool}(F)])\big) = \sigma\big(f^{7\times 7}([F^s_{avg}; F^s_{max}])\big)$$
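A minimal sketch of the spatial branch (illustrative names; the official SpatialGate class appears below):

import torch
import torch.nn as nn

class SpatialAttention(nn.Module):
    """Sketch of spatial attention: channel-wise avg/max pooling, 7x7 conv, sigmoid."""
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):                                             # x: (B, C, H, W)
        avg = x.mean(dim=1, keepdim=True)                             # F_avg^s: (B, 1, H, W)
        mx = x.amax(dim=1, keepdim=True)                              # F_max^s: (B, 1, H, W)
        att = torch.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))   # M_s(F): (B, 1, H, W)
        return x * att                                                # broadcast over channels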
Arrangement of attention modules
The authors determined experimentally that arranging the two attention modules sequentially works better than in parallel, and that placing the channel attention module before the spatial attention module gives better results. The final arrangement is therefore: channel attention first, then spatial attention, applied sequentially to the input feature.
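In the paper's notation, for an input feature map $F$:

$$F' = M_c(F) \otimes F, \qquad F'' = M_s(F') \otimes F'$$

where $\otimes$ denotes element-wise multiplication, with the attention maps broadcast along the missing dimensions.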
Ablation studies
Channel attention
The experiments compare channel attention built with AvgPool only, MaxPool only, and AvgPool & MaxPool together; combining the two pooling operations works best. "We argue that max-pooled features which encode the degree of the most salient part can compensate the average-pooled features which encode global statistics softly."
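The pool_types argument of the ChannelGate class in the official code below maps directly onto these ablation variants, e.g.:

ChannelGate(gate_channels=64, pool_types=['avg'])           # AvgPool only
ChannelGate(gate_channels=64, pool_types=['max'])           # MaxPool only
ChannelGate(gate_channels=64, pool_types=['avg', 'max'])    # AvgPool & MaxPool (best in the paper)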
Spatial attention
The spatial-attention ablation compares channel-wise avg & max pooling against a 1×1 convolution for aggregating channel information, as well as 3×3 vs. 7×7 convolution kernels; channel-wise pooling with a 7×7 kernel works best.
Arrangement of the channel and spatial attention
This part compares three ways of combining the channel and spatial branches: sequential channel-spatial, sequential spatial-channel, and parallel. The results show that sequential channel-spatial works best.
Official code
import torch
import math
import torch.nn as nn
import torch.nn.functional as F


class BasicConv(nn.Module):
    """Conv2d + (optional) BatchNorm + (optional) ReLU."""
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
                 bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn:
            x = self.bn(x)
        if self.relu:
            x = self.relu(x)
        return x


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class ChannelGate(nn.Module):
    """Channel attention: shared MLP over globally avg- and max-pooled features."""
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == 'max':
                max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(max_pool)
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
        scale = F.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale


class ChannelPool(nn.Module):
    """Concatenate channel-wise max- and mean-pooled maps into a 2-channel map."""
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)


class SpatialGate(nn.Module):
    """Spatial attention: channel pooling followed by a 7x7 convolution."""
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = F.sigmoid(x_out)  # broadcasting
        return x * scale


class CBAM(nn.Module):
    """Channel attention followed sequentially by spatial attention."""
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out
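A quick usage sketch of the module above (shapes chosen arbitrarily for illustration):

cbam = CBAM(gate_channels=64, reduction_ratio=16)
x = torch.randn(8, 64, 32, 32)   # (batch, channels, height, width)
y = cbam(x)                      # attention-refined features, same shape as the input
print(y.shape)                   # torch.Size([8, 64, 32, 32])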
Differences from BAM
- BAM applies its channel and spatial branches in parallel, while CBAM applies them sequentially
- For channel attention, BAM uses only average pooling, while CBAM uses both average pooling and max pooling