查看原文
其他

【源头活水】DynamicViT: 动态Token稀疏化的高效视觉 Transformer

“问渠那得清如许,为有源头活水来”,通过前沿领域知识的学习,从其他研究领域得到启发,对研究问题的本质有更清晰的认识和理解,是自我提高的不竭源泉。为此,我们特别精选论文阅读笔记,开辟“源头活水”专栏,帮助你广泛而深入的阅读科研文献,敬请关注。

来源:知乎—赵文亮
地址:https://zhuanlan.zhihu.com/p/379126740
我们提出了一种基于动态token稀疏化的高效视觉transformer,通过分层剪枝66%的输入tokens,可减少31%~37%的FLOPs,并将模型运行速度提高了40%以上,保证精度下降在0.5%以内,可应用于各种视觉 transformer 模型中。
论文地址: 

https://arxiv.org/abs/2106.02034

代码仓库: 

https://github.com/raoyongming/DynamicViT

项目主页: 
https://dynamicvit.ivg-research.xyz/
预训练模型和训练/测试代码均已开源,欢迎使用


01

简介
自从去年 ViT 的提出,许多工作开始将 transformer 应用到各种视觉任务中。限制视觉transformer 的一个最大的瓶颈在于自注意力机制的巨大计算量(与token数的平方成正比)。然而,图片中所有的token对分类是同等重要的吗?为了探究这个问题,我们首先利用 transformer 可视化工具 可视化出图片中每个 token 对输出结果的影响:
可以看到,模型的预测结果只与图片中的少部分的token有关。也就是说,我们可以动态去除掉一些重要性较低的token,而不会对模型的准确率带来较大的影响。之前也有一些工作关注在以更低的复杂度去计算或近似自注意力机制(例如 Performer等),但是这些方法中仍然对所有的token计算FFN(Feedforward Network)。相比之下,如果我们可以去除掉大部分不重要的 token,那么自注意力和FFN的计算量都会随之减少。
基于上面的发现,我们提出了动态 token 稀疏化的方法。动态 token 稀疏化可以看成一种下采样的方式,但是每次下采样中保留哪些token是由当前的输入来动态确定的。作为对比,CNN 中的下采样方式是预先定义好的结构化下采样,而动态 token 稀疏化是非结构化的。需要指出的是,由于CNN需要时刻保持feature map的空间结构,动态 token 稀疏化方法在CNN上并不能实现加速。而在 transformer 中,自注意力的计算无需考虑token之间的空间关系,所以我们可以通过动态token稀疏化直接去除不重要的token。


02

方法
下面介绍我们的整体框架,包括一个主干网络和多个预测模块。其中,主干网络可以是各种视觉 transformer 模型,例如 DeiT,LV-ViT;预测模块用来动态预测token 稀疏化策略。
需要注意的是,我们的 token 稀疏化是以层次化的方式进行的。例如,假设 transformer 中有 12 个 block,我们会在第 4 个,第7个,第10个 block 之前分别实施一次稀疏化操作。

预测模块

为了减小额外的计算开销,我们预测模块的具体实现方式是几层轻量化的MLP,其输出结果表示了每个 token 被保留下来的概率。

注意力 masking 策略

在训练的时候,我们利用 Gumbel-Softmax 的技术将上面的概率变成二值化的mask。Gumbel-Softmax 既可以保证期望不变,又可以保证二值化的过程是可导的。为了实现并行化,我们在训练时并不会将无用的 token 剪掉,而是提出了一个注意力 masking 策略 来显式切断这些token和其他token之间的联系。

这种注意力 masking 策略可以保证最终的预测结果只与保留下来的 token 有关。在测试时,我们直接对每个token的保留概率排序,每次保留固定比例的token即可。例如,假设每次的保留比例为,经过三次 token 稀疏化操作后只有的token被保留下来。

03

实验结果
我们首先以 DeiT-S,LV-ViT-S,LV-ViT-M 三种模型作为主干网络分别用 0.7~1.0 的保留比例进行 token 稀疏化。可以看到我们的方法可以减少 31%~37% 的 FLOPs,并实现43%~54% 的实际加速,同时分类准确率几乎不受影响。
通过动态 token 稀疏化,我们可以实现非常有竞争力的准确率/复杂度 tardeoff。
我们的模型超过了包括EfficientNet-B5 在内的许多SOTA的CNN 模型。这个现象从下面的图中也可以很明显地看到。
此外,我们还发现我们的方法比直接在 channel 维度缩放具有更好的性能。下图中,我们的DynamicViT-256/0.7 在相近的计算量下的ImageNet分类准确率比 DeiT-Ti 高了4% 左右。这个现象也说明了我们的方法提供了一个新的模型缩放的方向。
为了验证动态token稀疏化的有效性,我们对比了结构化下采样(与CNN中相同)、使用静态稀疏化策略(即策略与输入无关),发现我们的动态 token 稀疏化的方法能在相同 FLOPs 下取得最高的准确率。
我们将我们的动态token稀疏化的过程可视化出来,其中白色的位置表示该位置的token已经被剪掉。可以看到,在经过3个stage过后,大部分背景的token都被逐渐剪掉,模型关注在与分类有关的物体上面。这一可视化也说明我们的方法具有很好的可解释性。

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“源头活水”历史文章


更多源头活水专栏文章,

请点击文章底部“阅读原文”查看



分享、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存