查看原文
其他

【源头活水】Transformer in CV—— Vision Transformer

“问渠那得清如许,为有源头活水来”,通过前沿领域知识的学习,从其他研究领域得到启发,对研究问题的本质有更清晰的认识和理解,是自我提高的不竭源泉。为此,我们特别精选论文阅读笔记,开辟“源头活水”专栏,帮助你广泛而深入的阅读科研文献,敬请关注。

作者:知乎—limzero

地址:https://www.zhihu.com/people/lim0-34

之前一直做CV,今年出现的两篇(detr ,vit)将Transformer应用到cv的文章将cv和transformer结合带火了,所以最近学习了一下该方面的知识。之前也看过不少讲Transformer的文章,但没有看源码,很难理解其中奥义,这里还是从源代码角度来记录下自己的理解,防止后面遗忘。本篇是Vision Transformer的学习笔记,接下来还会学习DETR(End-to-End Object Detection with Transformers)相关的内容。
关于vision Transformer,文章地址:https://arxiv.org/abs/2010.11929
下面将以Vision Transformer的pytorch源码实现为例来理解:
  • 图片数据是如何被编码为Transformer能够处理的形式的
  • Transformer以及其中核心的self attention机制是如何work的
代码地址:https://github.com/lucidrains/vit-pytorch

01

图片数据如何处理为序列
这是来自于论文的图,输入图片被切分为固定尺寸的patch,然后连接起来,然后对每一个patch做一次线性变换降维后再添加位置编码等信息,再将其送入Transformer的编码器(其实这里的patch操作感觉和CNN中的卷积操作有点相似,只不过是实现的方式不一样)
下面是实现的代码(每一行代码都添加类相应的注释):
class ViT(nn.Module): def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3): super().__init__() assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size' assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' num_patches = (image_size // patch_size) ** 2 #切片数量(2048//32)**2==64**2==4096 patch_dim = channels * patch_size ** 2 #一张2048x2048的图被分为32x32大小的4096块,每一块3通道,将每一块展平:32x32x3=3072 所以patch_dim维度为:3072
self.patch_size = patch_size #patch_size:16
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))#位置编码:[1,4096+1,dim=512] self.patch_to_embedding = nn.Linear(patch_dim, dim)#将3072维度(像素点)embeding到512维度的空间 self.cls_token = nn.Parameter(torch.randn(1, 1, dim))#每一个维度都有一个类别的标志位 self.transformer = transformer
self.pool = pool self.to_latent = nn.Identity()#占位符
self.mlp_head = nn.Sequential(# 分类头 nn.LayerNorm(dim), nn.Linear(dim, num_classes) )
def forward(self, img): p = self.patch_size#32 ''' #img:[batch, 3, 2048, 2048] #'batch 3 (h 32) (w 32)'->'batch (h,w) (32 32 3)' 将图像分块,且每块展平(像素为单位连接起来) ''' x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)#[batch, 4096, 3072] 4096块,每一块展开为3072维向量 x = self.patch_to_embedding(x)#[batch, 4096, 512] 将3072维度的像素嵌入到512的空间 b, n, _ = x.shape cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)#[1,1,512]->[b,1,512] x = torch.cat((cls_tokens, x), dim=1)#[batch,4096+1,512] x += self.pos_embedding[:, :(n + 1)]#加上位置编码信息 ''' 以上步骤干的事情: - 输入图片分块->展平:[batch,c,h,w]->[batch,num_patch,c*patch_size*patch_size] - 原始的像素嵌入到指定维度(dim):[batch,num_patch,c*patch_size*patch_size]->[batch,num_patch,dim] - 每一个样本的每一个维度都加入类别token,给分片的图像多加一片,专门用来表示类别 - [batch,num_patch,dim]->[batch,num_patch+1,dim] - 给所有的"片(patch)"加入位置编码信息.这里的位置编码初始化为随机数,是通过网络学习出来的 以上步骤产生的输出结果即可送入到Transformer里面进行编码 [batch,num_patch+1,dim]经过transformer的编码将会出来一个[batch,num_patch+1,dim]的向量 ''' x = self.transformer(x)#[batch,num_patch+1,dim]->[batch,num_patch+1,dim]
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]#[batch,dim]
x = self.to_latent(x) return self.mlp_head(x)
整体而言,上面的代码所做的工作如下:
  • 输入图片分块->展平:[batch,c,h,w]->[batch,num_patch,c*patch_size*patch_size]
  • 原始的像素嵌入到指定维度(dim):[batch,num_patch,c*patch_size*patch_size]->[batch,num_patch,dim]
  • 每一个样本的每一个维度都加入类别token,给分片的图像多加一片,专门用来表示类别 [batch,num_patch,dim]->[batch,num_patch+1,dim]
  • 给所有的"片(patch)"加入位置编码信息.这里的位置编码初始化为随机数,是通过网络学习出来的
  • 以上步骤产生的输出结果即可送入到Transformer里面进行编码 [batch,num_patch+1,dim]经过transformer的编码将会出来一个[batch,num_patch+1,dim]的向量

02

Transformer以及self-Attention
Transformer的编码部分由6个self-attention以及FFN堆叠构成
class Transformer(nn.Module): def __init__(self, dim, depth=6, heads, dim_head, mlp_dim, dropout): super().__init__() self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append(nn.ModuleList([ Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))), Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) ])) def forward(self, x, mask = None): for attn, ff in self.layers: x = attn(x, mask = mask) x = ff(x) return x
其中的Attention部分如下所示:
class Attention(nn.Module): def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): super().__init__() inner_dim = dim_head * heads#64*8=512:8个head,每个head:64维. self.heads = heads self.scale = dim ** -0.5#对应Attention公式里面的分母
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)#输入的维度映射到多头注意力机制的维度,将输入处理成qkv矩阵 self.to_out = nn.Sequential( nn.Linear(inner_dim, dim), nn.Dropout(dropout) )
def forward(self, x, mask = None): #x:(batch,num_patch+1,dim) b, n, _, h = *x.shape, self.heads qkv = self.to_qkv(x).chunk(3, dim = -1)#([batch=1,num_patch+1=65,inner_dim=512]) # 'batch=1 num_patch=65 (head=8 dim_head=64) -> [batch=1 head=8 num_patch=65 dim_head=64]' q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) ''' 此时:q,k,v的维度都是[batch=1 head=8 num_patch=65 dim_head=64] '''
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale ''' (num_patch=65,dim_head=64)*(num_patch=65,dim_head=64)^T->(num_patch=65,num_patch=65), 向量的内积可以理解为相似度,q的一行代表了其中的一个patch,k同理,二者相乘,代表了序列中两个patch之间的相似度. 这其实类似与信息检索之中的query,key匹配过程,这个相似度就可以作为权重.而v初始化为q,k一样的形状, 也代表了输入的每个patch的特征,为了让该特征具有更好的表征能力,每一个patch的特征都应该有其余所有patch(包括 该patch自己加权而来,这里的权重即为q*k^T) ''' mask_value = -torch.finfo(dots.dtype).max#指定mask_value为dots.dtype下的最小值
if mask is not None: mask = F.pad(mask.flatten(1), (1, 0), value = True) assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' mask = mask[:, None, :] * mask[:, :, None] dots.masked_fill_(~mask, mask_value) del mask
attn = dots.softmax(dim=-1)#注意softmax的维度,按照行进行的softmax
out = torch.einsum('bhij,bhjd->bhid', attn, v)#(num_patch=65,num_patch=65)*(num_patch=65,dim_head=64)->(num_patch=65,dim_head=64) ''' 此处得到的out可以理解为通过前面的attention矩阵和v获得了对每一个patch(word)的更好的嵌入表示. ''' out = rearrange(out, 'b h n d -> b n (h d)') out = self.to_out(out) return out
关于self-attention的理解:
向量的内积可以理解为相似度,q的一行代表了其中的一个patch,k同理,二者相乘,代表了序列中两两patch之间的相似度.这其实类似与信息检索之中的query,key匹配过程,这个相似度就可以作为权重.而v初始化为q,k一样的形状,也可以看作代表了输入的每个patch的特征,为了让该特征具有更好的表征能力,每一个patch的特征都应该由其余所有patch(包括该patch自己加权而来,这里的权重即为q*k^T)
FFN部分
class FeedForward(nn.Module): def __init__(self, dim, hidden_dim, dropout = 0.): super().__init__() self.net = nn.Sequential( nn.Linear(dim, hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), nn.Dropout(dropout) ) def forward(self, x): return self.net(x)
这部分代码没有说明需要特殊说明的,这里引用"FFN 相当于将每个位置的Attention结果映射到一个更大维度的特征空间,然后使用ReLU引入非线性进行筛选,最后恢复回原始维度。"[1]

03

编码出来的特征接一个分类head
Transformer的编码器出来的特征可以直接接一个分类head,得到输入对应每一个类别的概率,这和常规的cnn接全连接分类head没有不同
以上就是vit的核心思路,给人的感觉就像的将cnn换成了Transformer的编码器.
虽然整个流程似乎没有用到任何CNN的东西,但是其中的很多操作感觉本质上和CNN是一个道理,虽然没有严格的数学证明,但直觉上看.vit是利用了各种方式在模拟CNN的过程

[1]https://zhuanlan.zhihu.com/p/48508221
[2]https://zhuanlan.zhihu.com/p/106867810

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“源头活水”历史文章


更多源头活水专栏文章,

请点击文章底部“阅读原文”查看



分享、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存