Transformer多头自注意力机制的本质洞察
作者:Nikolas Adaloglou
编译:王庆法
译者注:本文详述了简单的多头自注意力机制的令人惊讶的丰富涵义和洞察,对理解大模型背后的transformer 如何高效工作的原理提供了丰富的视角。
这篇文章是为那些想要真正了解自注意力为什么以及如何工作的好奇的人准备的。在实施或仅解释一篇带有transformer的新的花哨论文之前,我认为介绍有关注意力机制的各种观点会很有趣。
在研究这个话题几个月之后,我发现了许多隐式的直觉,这些直觉可以给基于内容的注意力机制赋予意义。
我为什么要花时间进一步分析自注意力?
首先是因为我找不到直接的答案来回答为什么多头自注意力会起作用的显然的问题。其次,因为很多顶尖的研究人员,比如谷歌大脑的hadamaru,都认为这是2018年之后最重要的公式:
TL;DR (Too long; Don't Read)
有趣的是,在自注意中隐藏着两种类型的并行计算:
将嵌入向量批处理到查询矩阵中
引入多头注意力。
我们将分析这两者。更重要的是,我将尝试提供不同的视角来解释为什么多头自注意力有效!
自注意作为两个矩阵相乘
数学
我们将考虑没有多头的自点积注意力,以提高可读性。给定我们的输入
和可训练的权重矩阵:
dmodel是我们序列中每个输入元素的嵌入向量的大小。
dk是每个自注意层特有的内在维度。
batch是批量大小
tokens是我们序列具有的元素个数。
我们创建 3 种不同的表示形式(查询、键和值):
然后,我们可以将注意力层定义为:
您可能想知道注意力权重在哪里。首先,让我们澄清一下,注意力是作为点积实现的,且就在这里发生:
点积越高,注意力“权重”就越高。这就是为什么它被认为是一种相似性度量。现在让我们看看数学内部。
直观的说明
首先说明,我们将考虑查询并非来自键和值所在序列的情况。假设查询是一个包含 4 个token的序列,而我们要与之关联的序列包含 5 个token。
两个序列都包含相同嵌入维度的向量,在我们的示例中dmodel =3。那么自注意可以定义为两个矩阵相乘。
花一些时间分析下图:
通过将所有查询放在一起,我们有了一个矩阵乘法,而不是每次都有一个单一的查询向量和矩阵相乘。每个查询的处理完全独立于其他查询。这是我们通过使用矩阵乘法并提供所有输入token/查询免费获得的并行化。
查询-键矩阵乘法
基于内容的注意力具有不同的表示形式。注意力层中的查询矩阵在概念上是数据库中的“搜索”。键将说明我们将要查找的位置,而值实际上将为我们提供所需的内容。将键和值视为我们数据库的组件。
直观地说,键是查询(我们正在寻找什么)和值(我们将实际获得什么)之间的桥梁。
请记住,每个向量到向量的乘法都是点积相似性。我们可以使用键来指导我们的“搜索”,并告诉我们对应于输入元素在哪里查找。
换句话说,键将用于计算如何根据我们的特定查询来权衡值的注意力。
请注意,我没有在图中显示 softmax 操作,也没有显示缩减因子
注意V矩阵乘法
然后权重αij用于获取最终加权值。例如,输出O11, O12, O13将使用第一个查询中的注意力权重,如图所示。
普通transformer的交叉注意
同样的原则也适用于编码器-解码器注意力或交叉注意力,这是完全有意义的:
键和值是通过在多个编码器块之后对最终编码输入表示进行线性投影来计算的。
详解多头注意力如何工作
将注意力分解为多个头是并行和独立计算的第二部分。就个人而言,我喜欢将其视为同一序列的多个“线性视图”。
最初的多头注意力定义为:
独立的注意力“头”通常由线性层连接和相乘,以匹配所需的输出维度。输出维度通常与输入嵌入维度dim相同. 这使得我们更容易地堆叠多个transformer块以及跨连接身份识别。
我从Peltarion的博客文章中发现了一个关于多头注意的精彩插图:
直观地说,多头使我们能够独立地关注序列的(一部分)。
如果您喜欢数学和输入输出图,看这里:
自注意独立计算的并行化
同样,所有表示都是从相同的输入创建的,并合并在一起以生成单个输出。然而Qi, Ki, Vi, 各自都在较低维度dk=dmodel/heads上.计算可以按batch大小独立进行。事实上,我们对每个batch和头部进行相同的计算。
通常,独立计算具有非常简单的并行化过程。尽管这取决于 GPU 线程中的底层低级实现。理想情况下,我们会为每个batch 和每个头部分配一个 GPU 线程。例如,如果我们有 batch=2 和 heads=3,我们可以在 6 个不同的线程中运行计算。即使尺寸是dk=dmodel/heads ,从理论上讲,这将引入零开销。开销来自连接结果并再次乘以WO,也是最小的。
到目前为止,您可能已经了解了这个理论。让我们深入研究一些有趣的观察。
关于注意力机制的洞察和观察
自注意不是对称的!
因为我们倾向于使用相同的输入表示,所以不要陷入自注意是对称的陷阱!当我开始了解transformer时,我犯了这个灾难性的错误。
洞察0:自注意不是对称的!
如果你做数学计算,理解起来就很容易:
更具体地说,如果键和查询具有相同数量的N个token,注意力矩阵N×N就可以解释为有向图:
与权重对应的箭头可以被视为信息路由的一种形式。
为了使自注意对称,我们必须对查询和键使用相同的投影矩阵:WQ=WK。而这将呈现为无向图。
为什么?因为当你将矩阵与其转置相乘时,你会得到一个对称矩阵。但是,请记住,结果矩阵的秩不会增加。
受此启发,有许多论文对键和查询使用一个共享投影矩阵,而不是两个。多头注意力有更多信息。
注意作为多个本地信息的路由
基于“使用显式关系编码增强transformer以解决数学问题”论文:
洞察1:“这(他们的结果)表明,注意力机制不仅包含它所关注的状态的子空间,还包括那些保留几乎全部信息内容的状态的仿射变换。在这种情况下,注意力机制可以解释为将多个局部信息源路由到一个局部表征的全局树结构中“~ Schlag et al.
我们倾向于认为多个头部将允许头部关注输入的不同部分,但本文证明了最初的猜测是错误的。头部保留了几乎所有内容。这会将注意力呈现为查询序列相对于键/值的路由算法。
可以有效地对编码器权重进行分类和修剪
在另一项工作中,Voita等人[4]分析了在他们的工作“分析多头自注意:特化的头部完成繁重的工作,其余的可以修剪”中使用多个头部时会发生什么。他们通过观察注意力矩阵确定了3种类型的重要头部:
1.主要关注邻居的位置头。
2.指向具有特定语法关系的token的句法头。
3.指向句子中生僻词的头。
证明其头部分类重要性的最好方法是修剪其他类别。以下是他们的修剪策略示例,该策略基于普通transformer的 48 个头(8 个头乘以 6 个块)的头分类:
如图所示,通过主要保留被归类为杰出类别的头部,他们设法保留了 48个头部中的 17 个,BLEU得分几乎相同。请注意,这大约相当于编码器头的 2/3。
以下是在两个不同的机器翻译数据集中修剪transformer编码器头的结果:
有趣的是,编码器注意力头最容易修剪,而编码器-解码器注意力头似乎是机器翻译最重要的。
洞察2:基于编码器-解码器注意力头主要保留在最后层的事实,文中强调解码器的第一层用于语言建模,而最后层用于源句子的调节。
头部共享共同的投影
这方面的另一篇有价值的论文是Cordonnier等人的“多头注意力:协作而不是连接”。
累积图描述了预训练键矩阵和查询矩阵的方差总和(按 X 轴的降序排列)。预训练的投影矩阵来自一个名为BERT的著名NLP模型,具有dimhead=64 和 12 个头,这表明需要研究矩阵的64⋅12=768个秩。
观察结果再次基于以下等式:
我们将研究预训练投影乘积
左图分别描绘了每个头乘积的秩(红色),而右图是每层串联头的乘积的秩。
洞察 3:即使每个权重矩阵的单独乘积不是低秩,但它们串联的乘积(如右侧红色所示)是低秩。
这实际上意味着头部共享共同的投影。换句话说,独立的头部实际上难以置信的学会了关注相同的子空间。
编码器-解码器上的多头注意非常重要
Paul Michel et al. [2] 展示了从不同注意力子模型逐步修剪时多个头的重要性。
下图显示,当从编码器-解码器注意力层(交叉注意力)修剪头时,性能下降得更快。如下机器翻译的BLEU分数报告。
作者表明,修剪普通transformer60%以上的交叉注意头将导致性能显著下降。
洞察4:编码器-解码器(交叉)注意力明显更依赖于多头分解表示。
应用softmax后,自注意低秩
最后,Sinong Wang等人[7]的一项研究表明,在应用softmax之后,所有层的自注意都是低秩的。
同样,累积图描述了特征值的总和(按 X 轴的降序排列)。从广义上讲,如果使用很少的特征值,归一化累积总和为 1,则意味着这些是关键维度。
对于该图,他们将预训练模型的层和头做奇异值分解为P,并绘制了10k个句子平均的归一化累积奇异值
洞察5:应用softmax后,自注意力处于低秩。这意味着包含在P中的很大一部分信息可以从第一个最大的奇异值(此处为 128)中恢复。
基于这一观察结果,他们提出了一种简单的线性注意力机制,通过向下投影键和值,称为Linformer注意力。
注意力权重作为快速权重记忆系统
上下文依赖的快速权重生成是在90年代初由Schmidhuber等人于1991年引入的。具有慢权重的慢速网络不断为快速网络生成快速权重,使快速权重有效地依赖于上下文。
通过移除众所周知的注意力机制中的softmax,我们得到类似的行为。
其中值和键的外积可以看作是快速权重。
这或多或少是个数据库,其中:
最后,你会得到一些看起来像90年代描画的快速权重的东西:
基于这一观察结果,他们讨论了多种方法来替代softmax操作的移除,并与已经提出的线性复杂性注意力方法相关联。以下是我喜欢这项工作的一个洞察:
洞察6:“因此,为了防止关联在检索时相互干扰,相应的键需要是正交的。否则,点积将关注多个键并返回值的线性组合。Schlag等.
秩崩溃和token一致性
最近,dong等人[6]发现自注意对token一致性具有归纳偏见。
洞察7:令人惊讶的是,他们注意到,如果没有MLP和跳过连接等其他组件,注意力会呈指数级收敛到秩1的矩阵。
为此,他们研究了负责抵消秩崩溃的机制。简而言之,他们发现了以下内容:
1.跳过连接至关重要:它们可防止transformer输出在网络深度方面以指数级速度下降到秩1。
2.多层感知器将特征投射到更高维度并返回初始维度也有作用
3.层归一化在防止秩崩溃方面没有任何作用。
我敢打赌,您可能想知道层范数有什么用。
层范数:迁移学习的关键要素主要是预训练transformer
首先,归一化方法是当前数据集中稳定训练和更快收敛的关键。然而,它们的可训练参数给迁移学习带来了实际挑战。
在transformer案例中,论文“预训练transformer作为通用计算引擎”[10]提供了一些关于仅微调层范数的洞察,相应于γ和β可训练参数。
直观地说,这些参数对应于重新缩放和迁移注意力信号。
他们对属于低数据系的数据集的最关键组件进行了大规模的消融研究。
洞察8:令人惊讶的是,作者发现,在对大型(高数据系)自然语言任务进行预训练后,层范数可训练参数(参数的0.1%)对于微调transformer最关键[10]。
您可以想象低数据系到获取大量标记数据的领域,就像医学成像一样昂贵且困难。然而,在他们的工作中,他们使用MNIST和CIFAR-10等数据集作为低数据系数据集。将它们与transformer可以预训练的大量文本进行比较。
可以看出,冻结的transformer的性能与完全微调的transformer相当,这表明两件事:
洞察9:在海量自然语言数据集上预训练自注意力会产生出色的计算原语。
计算原语是未分解的构造或组件(在给定上下文中,例如编程语言或语言表达式的原子元素)。换句话说,基元是最小的处理单元。事实证明,这些大型NLP数据集中学习的Q,K,V投影矩阵学习了可迁移的原语。
洞察 10:微调注意力层可能会导致小型数据集的性能发散。
关于二次复杂性:我们到达了吗?
在不提及“寻找二次复杂性的替代方案上花费的大量研究”的前提下,我们还不能给注意力机制下结论。我将简要介绍一下Yi Tay等人2020年提供的下图中发生的事情:
大体上,这里分为两类:
1.使用数学来近似全二次全局注意力(all2all)的方法,例如利用矩阵秩的Linformer。
2.试图限制和稀疏注意力的方法。最原始的例子是“窗口”注意力,它在概念上类似于卷积(下图(b))。最成功的稀疏基方法是BigBird,如下图所示,使用了上述注意力类型的组合。
显然,全局注意力由“特殊”token保持,例如用于分类的 CLS token。
话虽如此,降低二次复杂性的道路远未结束。
我计划在领域变得清晰后提供一篇全新的文章。尽管如此,如果你真的想运行一些大型稀疏注意力模型,请查看Deepspeed。它是微软开发的稀疏transformer最著名和最快速的实现之一。它为Pytorch提供了具有重大加速的GPU实现。
结论
经过这么多的观点和洞察,我希望你在分析基于内容的注意力方面至少获得一个新的洞察。在我看来,如此简单的想法如何产生如此巨大的影响以及如此多的含义和洞察,真是令人惊讶。
References
[1] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you need. arXiv preprint arXiv:1706.03762.
[2] Michel, P., Levy, O., & Neubig, G. (2019). Are sixteen heads really better than one?. arXiv preprint arXiv:1905.10650.
[3] Cordonnier, J. B., Loukas, A., & Jaggi, M. (2020). Multi-Head Attention: Collaborate Instead of Concatenate. arXiv preprint arXiv:2006.16362.
[4] Voita, E., Talbot, D., Moiseev, F., Sennrich, R., & Titov, I. (2019). Analyzing multi-head self-attention: Specialized heads do the heavy lifting, the rest can be pruned. arXiv preprint arXiv:1905.09418.
[5] Schlag, I., Irie, K., & Schmidhuber, J. (2021). Linear Transformers Are Secretly Fast Weight Memory Systems. arXiv preprint arXiv:2102.11174.
[6] Yihe Dong et al. 2021. Attention is not all you need: pure attention loses rank doubly exponentially with depth
[7] Wang, S., Li, B., Khabsa, M., Fang, H., & Ma, H. (2020). Linformer: Self-attention with linear complexity. arXiv preprint arXiv:2006.04768.
[8] Tay, Y., Dehghani, M., Abnar, S., Shen, Y., Bahri, D., Pham, P., ... & Metzler, D. (2020). Long Range Arena: A Benchmark for Efficient Transformers. arXiv preprint arXiv:2011.04006.
[9] Zaheer, M., Guruganesh, G., Dubey, A., Ainslie, J., Alberti, C., Ontanon, S., ... & Ahmed, A. (2020). Big bird: Transformers for longer sequences. arXiv preprint arXiv:2007.14062.
[10] Lu, K., Grover, A., Abbeel, P., & Mordatch, I. (2021). Pretrained Transformers as Universal Computation Engines. arXiv preprint arXiv:2103.05247.