最近,Transformers 被应用于大规模的图像分类任务并取得了高分表现,由此动摇了卷积神经网络长期以来的霸主地位。
然而 transformers 的优化研究至今甚少。在这项工作中,Facebook 联合团队构建并优化了用于图像分类的深度 transformer。
特别是,此项研究研究了这种专用架构和优化方法的相互作用,其对 transformer 架构的改变,明显提高了深度 transformer 的准确性。
即使训练时没有额外数据,新方法仍在 Imagenet 上获得了 86.3% 的最高准确率。在没有额外训练数据和利用重分配标签及 Imagenet-V2 匹配频率的情况下,研究的最佳模型在 Imagenet 上达到了目前的最高水准。
这项研究的目标是在训练 transformer 进行图像分类时提高优化的稳定性,不过,此次研究更具体地考虑了 Dosovitskiy 等人提出的视觉转换器 (ViT) 架构作为参考架构,并采用 Touvron 等人的数据高效图像转换器 (DeiT) 优化程序。
在这两个工作中,都没有证据表明,深度仅在 Imagenet 上训练时能带来好处:更深的 ViT 架构性能较低,而 DeiT 只考虑 12-blocks 的 transformer。
最终,此项研究进行了两次 transformer 的结构更改,提高了深度 transformer 的精度,并通过实验验证了此项研究方法的有效性和互补性。
体系结构更改
丨分层缩放极大地促进了收敛,并提高了深度 transformer 的精度。在训练时仅仅给网络增加了几千个参数(可忽略不计的数量)。丨提出的的特定类关注结构,为 class embedding 提供了更有效的处理。丨团队提出的 CaiT 模型在 Imagenet-Real 和 Imagenet V2 匹配率上达到了最高水平。丨提出的方法在迁移学习方面取得了的竞争力十足的成果。(a) ViT 图像分类器采用了 Child 等人提出的预归一化。(b) rezero/skippinit 和 Fixup 删除了 η 归一化和 warmup 步骤(即,在早期训练阶段降低了学习速率)并各添加了一个可学习的标量来替代他们,分别初始化初始化为 0 和 1。Fixup 步骤还引入了偏差并修改了线性层的初始化。这是因为原先的方法不能与深度视觉 transformer 融合。(c) 此项研究通过重新引入规范前的 η 和 warmup 步骤来进行适应。(d) 此项研究主要提出了如下方法:引入了各通道加权,(即与参数矩阵相乘,该参数矩阵是一个对角阵,diag( ,..., ))其中,此项研究将每一个权重初始化为一个很小的值 λi=ε。双阶段即为 Self-attention + Class-attention,该设计旨在规避 ViT 架构的一个问题,即要求可学习权重同时优化两个冲突的目标:(1)引导各模块间的自关注同时(2)总结对线性分类器有用的信息。此项研究的方案是按照编码器 — 解码器的结构将这两个阶段分开。在双阶段处理之前,插入了所谓的类别标记(class token),在之后的 transformer 中由 CLS 表示。这么做可以消除 transofrmer 第一层的差异这样可以使 transformer 完全致力于在 patch 之间进行 selt-attention。作为不受矛盾的目标困扰的基准,此项研究还考虑了对 transformer 输出端的所有 patch 进行平均池化,这就如卷积操作之后通常会做的那样。研究提出的 CaiT 网络,具体由如下两个不同的处理阶段组成:(a) Self-attention 阶段与 ViT Transformer 相同,但是不包括类别嵌入(CLS)。(b) Class-attention 阶段是一组层设置,将 patch embeddings 编译成类别嵌入 CLS,随后将其反馈给线性分类器。Class-attention 在 multi-head class-attention(CA)和 FFN 层多次交替出现,在这个阶段,只有 class embedding 是被更新的,类似于 transformer 中输入端的 ViT 和 DeiT 中输入的可学习矢量。最主要的区别是在此项研究的架构中,在前向传播的时候,此项研究不会将信息从 class embedding 复制到 patch embeddings。只有 class embedding 是由 CA 和 FFN 处理中的残差来更新的。在 ViT Transformer(左)中,class embedding(CLS)与 patch embedding 一起插入。这种选择是有损效果的,因为相同的权重被用于两个不同的目标:辅助 attention 的过程,并且准备好要送入给分类器的向量。通过展示插入 CLS 后性能会提高说明了这个问题(中)。在 CaiT 架构中(右),团队还建议在插入 CLS 以减少计算量时冻结 patch embeddings,这样,网络的最后一部分(通常是两层)就可完全专注于总结要提供给线性分类器的信息。Multi-heads class attention:CA 层的作用是从一组已处理的 patch 中提取信息。它和 SA 层的作用相同,只是他依赖于这么两点:(1)class embedding xclass(在第一个 CA 中的 CLS 处初始化)(2)其自身加上了冻结的 patch embedding xpatches。考虑一个有着 h 个 heads 和 p 个 patches 的网络,并由 d 这个嵌入尺寸来牵引,此项研究用几个投影矩阵来参数化 multi-heads class attention, 、 、 、 ∈ × ,其相应的偏差为 ∈ × ,使用这种表示方法来计算 CA residual block 的步骤如下,首先扩充 patch embeddings(以矩阵的形式)为 ,然后再进行了投影:Class attention 的权重由下式给出:其中 ∈ 这种注意力机制设计加权求和 来输出残差向量
CA 层从 patch embedding 中提取有用信息到 class embedding 中。在初步实验中,第一个 CA 和 FFN 是性能提升的主要来源,一组由两个 block 组成的层 (2 个 CA 和 2 个 FFN) 足以使性能达到最佳。在实验中,当它由 12 块 SA+FFN 层和 2 块 CA+FFN 层组成时,用 12+2 表示一个 transformer。这些层在 class-attention 和 self-attention 中包含着相同数量的参数,在这方面 CA 和 SA 是一样的。此处也用相同的参数化方法来处理 FFN。然而这些层处理起来要快得多,因为 FNN 只需处理矩阵向量乘法。CA 函数在内存和计算方面也比 SA 要高效率得多,因为只计算 class vector 和 patch embeddings 之间的关注度。在团队的早期实验中,当增大架构规模时,视觉 transformers 变得越来越难训练。深度是不稳定性的主要来源之一。例如,DeiT 程序在不调整超参数的情况下,无法在 18 层以上的深度正常收敛。大型 ViT 模型有 24 层和 32 层,需要用大型训练数据集进行训练,但在 Imagenet 这样并不是很大的数据集上训练时,仅仅是拥有大模型,并不具备竞争力。在接下来,此项研究分析了用不同架构稳定训练的各种方法。在这一阶段,考虑在 300 个 epochs 期间使用 Deit-Small 模型,以便与 Touvron 等人的结果报告进行直接比较。团队测量了 Imagenet1k 分类数据集上的性能作为深度的函数。提高收敛性的第一步是调整与深度交互最多的超参数,特别是随机深度。这种方法在 NLP 中已经很流行,往往被用来训练更深的架构。对于 ViT,它最早是由 Wightman 等人在 Timm 实现中提出的,随后在 DeiT 中被采用。每层下降率线性地取决于层深度,但实验中,与更简单的统一损失率 dr 的选择相比,这种选择并没有什么优势。实验显示 DeiT 的默认随机深度允许训练多达 18 个 SA+FFN 块。再之后,训练就会变得不稳定。通过增加损失率超参数 dr,性能增加到 24 层。在 36 层时达到饱和(测得在 48 层时下降到 80.7%)。此项研究还对标准化方法进行了实证研究。如前所述,Rezero、Fixup 和 T-Fixup 在训练 DeiT 现成的时候并不一致。然而,如果此项研究重新引入 Layer- Norm3 和 warmup,Fixup 和 T-Fixup 与基线 DeiT 相比实现了一致性,甚至提高了训练结果。修改后的方法能够用更多的图层收敛,而不会过早的饱和。至于 ReZero 收敛,研究表明(列 α=ε),最好将 α 初始化为一个小值,而不是 0,就像 LayerScale 那样。所有的方法都对收敛有好处,它们倾向于减少对随机深度的需求,因此此项研究根据每个方法相应地调整这些下降率。研究通过计算残差激活的范数和主分支 的激活的范数之间的比率来评估 Layerscale 对 36-blocks transformer 的影响。结果发现,用 Layerscale 训练模型使这个比率在各层之间更加一致,并且可能防止了一些层对激活产生不相称的影响。此项研究的最佳模型通过重新评估的标签和 Imagenet-V2 匹配频率,在无需额外新数据的情况下,在 Imagenet 上达到了最高水平:https://arxiv.org/pdf/2103.17239v1.pdf丨麻省理工学院学者万字长文:计算机作为一种通用技术的衰落数据实战派希望用真实数据和行业实战案例,帮助读者提升业务能力,共建有趣的大数据社区。