深入理解注意力机制
作者:晟沚
编辑:龚赛
前 言
注意力机制和人类的视觉注意力很相似,人类的注意力是人类视觉所特有的大脑信号处理机制。人类通过快速扫描全局图像,获得需要重点关注的目标区域,得到注意力焦点,而后对这一区域投入更多注意力,以获取更多所需要关注目标的细节信息,从而抑制其他无用信息。这是人类利用有限的注意力资源从大量信息中快速筛选出高价值信息的手段,是人类在长期进化中形成的一种生存机制,极大地提高了视觉信息处理的效率与准确性。比如给一张印有图片的报纸,那人会先去看报纸的标题,然后会看显目的图片。
深度学习中的注意力机制从本质上讲和人类的选择性视觉注意力机制类似,目的也是从众多信息中选择出对当前任务目标更关键的信息。
01
channel-wise attention
本文从SCA-CNN中提到的channel-wise的角度来理解注意力机制,paper地址:SCA-CNN,首先我们从几个问题出发来理解.
02
为什么要引入channel wise attention
举个例子:当你要预测一张图片中的帅哥时,那么channel wise attention就会使得提取到帅哥特征的feature map的权重加大,这样最后output结果就会准确不少.
为什么要引入multi-layer呢?
因为高层的feature map的生成是依赖低层的feature map的,比如你要预测图片中的帅哥,我们知道,底层网络提取到的更多是底层的细节,而高层网络才能提取到全局的语义信息,那么只有低层kernel提取到更多帅哥边缘特征,高层才能更好地抽象出帅哥来。另外如果只在最后一个卷积层做attention,其feature map的receptive field已经很大了(几乎覆盖整张图像),那么feature map之间的差异就比较小,不可避免地限制了attention的效果,所以对multi-layer的feature map做attention是非常重要的。
03
为什么还需要spatial attention
前面channel-wise attention 只会关注到图像的一个小部分,而spatial attention的作用为关键部分配更大的权重,让模型的注意力更集中于这部分内容。
channel wise attention是在回答“是什么”,而spatial attention是在回答“在哪儿”.
spatial attention是以feature map的每个像素点为单位,对feature map的每个像素点都配一个权重值,因此这个权重值应该是一个矩阵,大小是图片的大小;channel wise attention则是以feature map为单位,对每个channel都配一个权重值,因此这个权重值应该是一个向量。
04
具体细节
首先,我们可以通过下图看到整个attention添加的过程,作用在multi-layer 的feature map上.
网络中第(l-1)层的输出feature map,该feature map经过channel wise attention 函数得到权重 βl。然后这个βl和Vl相乘就得到中间结果的feature map。接下来feature map经过spatial attention函数得到权重 αl。这个αl和前面生成的feature map相乘得到最终的Xl.
本文主要讲channel-wise attention,首先我们将feature map V reshape to U, and U = [u1,u2.....uC], U是R^(W*H),分别代表每个channel;C代表channel的总数;然后我们用mean pooling得到 channel feature v;然后我们再实现下图所示的公式.
最后,channel-wise attention的具体实现pytorch代码:
class AttentionLayer(nn.Module):
def __init__(self, channel, reduction=64, multiply=True):
super(AttentionLayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel),
nn.Sigmoid()
)
self.multiply = multiply
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
if self.multiply == True:
return x * y
else:
return y
END
往期回顾之作者夏敏
fine-gained image classification
文章不多,都是精华!!!
机器学习算法工程师
一个用心的公众号