查看原文
其他

【源头活水】一文教你彻底理解Google MLP-Mixer(附代码)

“问渠那得清如许,为有源头活水来”,通过前沿领域知识的学习,从其他研究领域得到启发,对研究问题的本质有更清晰的认识和理解,是自我提高的不竭源泉。为此,我们特别精选论文阅读笔记,开辟“源头活水”专栏,帮助你广泛而深入的阅读科研文献,敬请关注。

作者:知乎—月球上的人

地址:https://zhuanlan.zhihu.com/p/372692759

随着深度神经网络发展至今,网络结构优化的瓶颈也慢慢显现出来。由此,文艺复兴随之出现,MLPs(Multi-layer Perceptrons)这种古老的结构也开始被重新拉上舞台。本文深入浅出介绍Google新坑,MLP-Mixer。

参考代码地址:https://github.com/lucidrains/mlp-mixer-pytorch

Google最近又挖了一个新坑,MLP-Mixer。原文提到,CNN以及self-attention这种相对复杂的网络结构在视觉任务上已经取得很好的表现了,但是我们真的需要这么复杂的网络结构吗?MLP这种简单的结构是否也能够取得SOTA的表现呢?MLP-Mixer给出了答案。

convolutions and attention are both sufficient for good performance, neither of them are necessary.--引自原文


01

为什么叫MLP-Mixer?

MLP好理解,这个网络结构没有采用convolution以及attention的网络结构,纯粹使用MLP作为主要架构。

那为什么叫Mixer呢?举个例子就明白了,现在很多视觉任务的网络架构,其实就是mix不同的特征,找出各个特征之间的关系来获取有用的信息。从CNN的网络结构来理解就很简单了,拿一个NxNxC的kernel来举例,

(1)NxN这两个维度其实就是来mix不同位置像素点的mixer

(2)而C这个维度则是来mix一个像素点不同通道特征的mixer。

MLP-Mixer将CNN这两个任务切割开来,用两个MLP网络来处理,分别为
(1)不同位置的mix叫做token-mixing
(2)同一位置不同通道的mix叫做channel-mixing。
知道了名字含义,那接下来的网络结构也就好理解了。

02

一步一步拆解网络结构
MLP-Mixer总体结构

MLP-Mixer总体结构

def MLPMixer(*, image_size, patch_size, dim, depth, num_classes, expansion_factor = 4, dropout = 0.): assert (image_size % patch_size) == 0, 'image must be divisible by patch size' num_patches = (image_size // patch_size) ** 2 chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear
return nn.Sequential( # 1. 将图片拆成多个patches Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size), # 2. 用一个全连接网络对所有patch进行处理,提取出tokens nn.Linear((patch_size ** 2) * 3, dim), # 3. 经过N个Mixer层,混合提炼特征信息 *[nn.Sequential( PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)), PreNormResidual(dim, FeedForward(dim, expansion_factor, dropout, chan_last)) ) for _ in range(depth)], nn.LayerNorm(dim), Reduce('b n c -> b c', 'mean'), # 4. 最后一个全连接层进行类别预测 nn.Linear(dim, num_classes) )

Mixer结构

MLP结构
mixer分为token-mixer以及channel-mixer,简单来说就是分别对输入特征平面进行
(1)沿列方向的特征提炼,这个可以用kernel_size为1的conv1d实现全连接的效果
(2)沿行方向的特征提炼,简单的linear就可以了。
chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear
MLP的实现也很简单,两个全连接层中间以GELU(Gaussian Error Linear Units)作为激活函数。代码实现如下,
class PreNormResidual(nn.Module): def __init__(self, dim, fn): super().__init__() self.fn = fn self.norm = nn.LayerNorm(dim)
def forward(self, x): return self.fn(self.norm(x)) + x
def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear): return nn.Sequential( dense(dim, dim * expansion_factor), nn.GELU(), nn.Dropout(dropout), dense(dim * expansion_factor, dim), nn.Dropout(dropout)    )


03

实验结果
具体实现结果可以参考原文,这里只是简单贴了一张图出来展示。图中比较了Mixer,Transformer,ResNet在JFT-300M数据集上预训练,在224x224图片上fine-tuning后,在imageNet验证集上的Top-1 accuracy,可以看到Mixer取得了比ResNet稍好,和ViT transformer基本上差不多的表现,并且训练速度(img/sec/core) Mixer会比其他两个都要快,证明了MLP这种简单结构的潜力。

04

思考:为什么MLP突然复兴了?
MLP-Mixer论文中的实现数据证明了MLP这种古老结构的能力,似乎CV任务主流网络结构完成了MLP->CNN->Transformer->MLP这样一次轮回,学术界又有一个大坑可以填了。
那为什么MLP在以前不行,现在又行了呢?个人觉得主要有以下两个原因,
1. 算力的提升以及数据完善;
2. 现代网络结构技巧,说是只是用了MLP,其实现在SOTA模型里面的一些结构,比如说layernorm, GELU,skip connection (ResNet) ,也在MLP-mixer中起到了至关重要的作用。
这里再说一下ResNet,私以为,不管是transformer还是这里的mixer,其实本质上都是ResNet-like的网络结构,skip connection的存在,可以让网络往深了堆,并且所有权重都能充分训练。所以说,Residual 才是 Is All Your Need。

参考资料

https://arxiv.org/abs/2105.01601

https://github.com/lucidrains/mlp-mixer-pytorch#citations

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“源头活水”历史文章


更多源头活水专栏文章,

请点击文章底部“阅读原文”查看



分享、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存