查看原文
其他

【源头活水】AdaViT: Adaptive Tokens for Efficient Vision Transformer



“问渠那得清如许,为有源头活水来”,通过前沿领域知识的学习,从其他研究领域得到启发,对研究问题的本质有更清晰的认识和理解,是自我提高的不竭源泉。为此,我们特别精选论文阅读笔记,开辟“源头活水”专栏,帮助你广泛而深入的阅读科研文献,敬请关注。

来源:知乎—煎饼果子不要果子

地址:https://zhuanlan.zhihu.com/p/452330394


01

主要思路和创新点
因为 Transformer 复杂度很高的原因,在卷精度之余大部分创新都是针对缩减其计算成本。动态网络就是其中一个主要方向,可以筛选用于计算的 token / head / block 等等,本文就是逐层过滤掉一些不需要的 token。
先看图片下方,其实这篇文章筛选的效果是很好的,这些都是保留至最后的 token。上面一行讲述的就是模型如何过滤 token,在每层,每个特征都会经过一个超简 MLP 预测是否被截断,保留下来的 token 会输入到下一层继续。
首先是一个简单的整体模型公式定义:

L 为层总数

t 就是 token,K 是 token 总数
F 为每层 Transformer 模块,epsilon 为 patch 生成模块,最后 C 则是通过类别 token 输出最终预测。
sigma 就是 sigmoid,使其值在 0-1
每个 token 都会输入一层简易 MLP 来预测是否保留或截断,作者称 h 为 halting score,当这个分数累计超过一定值时,就说明后面不需要这个 token 了,在最后一层 L 所有 token 的分数都设置为 1。判断是否截断的相关公式如下:
 是一个极小正参数,emmm 具体为什么要设一个 1 我也没有特别明白。N 就代表第 k 个 token 被截断位置的层数,在推理阶段这些 token 在后几次层是直接被去掉的,但在训练阶段只是将这部分 token 置 0,同时 softmax 不考虑。
在截断当层,因为此时很可能加总的 h 要大于 1,因此还为当层定义了一个保留分数,即 1 减去前面所有截断分数,使得所有层加总等于 1:
有意思的是,与其他方法不同,本文还将这个截断分数用于类别 token,并且根据各层分数整合最终用于预测的类别特征。我觉得这时候可以把截断分数当作类别的重要性分数,当累计值超过 1 时,就不需要后续类别特征了,然后定义一个每层类别特征用于参与最终预测的概率:
截断一层概率值就等于剩下的重要性
最后类别特征就等于各层特征乘相应重要性
但是损失函数似乎写错了,应该是和真值的交叉熵损失 emmmm?
这部分对应的结构图参考下面:
上分支是针对类别,但似乎最后特征整合又写错了,第五层右边总和符号上标应该是 4?下分支就是针对正常的图像 token,最后 Nt 应该是 N1。之后文章提出一个 ponder 损失,一方面希望 token 被截断的越前越好,一方面希望截断前分数分布尽可能均匀:
N 是监督彻底截断尽可能早,r 则是让前面分布尽可能平均,不要让截断分数都集中在最后一层。这部分伪代码写的其实很清楚:
作者还额外提出了一个分布损失,希望不重要 token 的截断位置们都能处在差不多的位置上,每层 token 总值作为其分布:
其中会提前设定一个高斯分布,中心点位于期望截断层,使用 KL 散度来监督。最后损失函数就是上述三者加权:


02

实验结果和可视化
实验精度和计算成本
黄色是未使用先验分布监督,蓝色使用了先验分布监督,右侧图表明蓝色精度始终高于黄色。
一些可视化结果,颜色代表了被截断 token 的位置,越白表示越靠后,即保留的 token

论文信息

AdaViT: Adaptive Tokens for Efficient Vision Transformer
https://arxiv.org/pdf/2112.07658.pdf

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“源头活水”历史文章


更多源头活水专栏文章,

请点击文章底部“阅读原文”查看



分享、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存