查看原文
其他

线性Transformer应该不是你要等的那个模型

苏剑林 PaperWeekly 2022-07-04



©PaperWeekly 原创 · 作者 | 苏剑林
单位 | 追一科技
研究方向 | NLP、神经网络

笔者之前已经多次讨论过线性 Attention 的相关内容。介绍线性 Attention 的逻辑大体上都是:标准 Attention 具有 的平方复杂度,是其主要的“硬伤”之一,于是我们 复杂度的改进模型,也就是线性 Attention。有些读者看到线性 Attention 的介绍后,就一直很期待我们发布基于线性 Attention 的预训练模型,以缓解他们被 BERT 的算力消耗所折腾的“死去活来”之苦。
然而,本文要说的是:抱有这种念头的读者可能要失望了,标准 Attention 到线性 Attention 的转换应该远远达不到你的预期,而 BERT 那么慢的原因也并不是因为标准 Attention 的平方复杂度。


BERT之反思

按照直观理解,平方复杂度换成线性复杂度不应该要“突飞猛进”才对嘛?怎么反而“远远达不到预期”?出现这个疑惑的主要原因,是我们一直以来都没有仔细评估一下常规的 Transformer 模型(如BERT)的整体计算量。

很多读者都已经知道,Transformer 的结构大体上是 Embedding 层加若干个 Transformer 层,Embedding 层的计算量很少,我们主要关心 Transformer 层。忽略残差、Layer Normalization 等计算量比较小的层不计,每个 Transformer 层主要组成就是两个子层:Self Attention(简称 SA)和 FeedForward Network(简称 FFN)。

虽然 Transformer 的开山之作声称“Attention is all you need” [1],但是也有不少工作论证了残差、FFN 等模块的必要性了,比如《Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth》[2]

现在问大家一个问题:

你觉得是SA计算量大还是FFN计算量大?


评估计算量

毋庸置疑,SA 的复杂度是 ,而FFN的复杂度则是 ,如果你直接凭此就想当然地说 SA 计算量比 FFN 大,那就错了!
我们知道加法比乘法快很多,所以在估计计算量的时候我们主要计算要做多少次乘法,神经网络里边,主要的运算是矩阵相乘,不难估计按照定义一个 的矩阵乘以一个 的矩阵要做 abc 次乘法,所以 abc 就是两个矩阵相乘的复杂度了,这是我们估算 Transformer 复杂度的依据。
设 n 为序列长度,d 为 head_size(base 版是 64),h 为 head 的数目(base 版是 12),那么 hd 就是我们通常说的“hidden_size”(base 版是 768)。对于 SA 来说,一开始是 的投影变换,即 的矩阵乘以 的矩阵做3次,因此计算量是 ;然后是 h 个 Attention 头的运算,每个头先是 相乘得到 的 Attention 矩阵(softmax 和归一化的计算量暂且忽略),然后 的矩阵与 相乘得到 的矩阵,这两步的计算量都是 ,所以总计算量是 ;最后的输出还有一个投影变换,也是 的矩阵乘以 的矩阵,计算量是 。所以,SA 的总计算量是

至于 FFN 就比较简单了,它就是两个全连接层,也就是两个矩阵变换(激活函数的计算量也忽略不计),一般的参数设置是:第一层是 的矩阵乘以 的矩阵,第二层就是 的矩阵乘以 的矩阵。所以总计算量是
这样一来,如果 SA 的计算量比 FFN 大,就意味着

对于 base 版来说,这意味着 !也就是说,只有当序列长度超过 1536 时,SA 的计算量才大于 FFN,在这之前,都是线性复杂度的 FFN 占主导!
这还不止,由上面的结果我们可以得到 Transformer 层总的计算量为

它是关于 n 的一次项和二次项的求和,当 n 足够大时,复杂度自然是 ,然而二次项占主导的条件是

对于 base 版来说,这意味着 !也就是说,当序列长度接近 5000 时,Transformer 的复杂度才真正体现出二次性!


综合的结论

综合上述结果,我们可以得到结论:对于 base 版来说,当序列长度不超过 1536 时,Transformer 的复杂度都是近乎线性的;当序列长度超过 1536 时,Transformer 的计算量逐渐以 Attention 为主,复杂度慢慢趋于二次方,直到长度超过 4608,才真正以二次项为主。当然这个边界只是一个估计,实际情况可能有所偏差,大家就此感知一下范围和数量级就好。

笔者以前也建议过很多读者,对于不超过 2000 长度的“长文本”任务,直接用 NEZHA [3] 或者 RoFormer 这种不限长度的模型试试,不要想太多的技巧,原因也是如此。你想再多技巧,也顶多是降到线性复杂度,而在这个长度范围内模型本身就是近乎线性的,各种技巧也省不了多少。

对于老老实实用 BERT base 的读者来说,maxlen 一般不超过 512,远低于上述界限,因此就不要再说 Attention 的平方复杂度费硬件之类的吐槽了,因为事实是:
BERT 之所以慢,主要是因为它真的大,而不是因为 Attention 的平方复杂度。


“线性”含义

至于对线性 Attention “远远达不到预期”而感到疑惑的另一个原因,则是没有从实际情况分析线性 Attention 的计算量,以至于对线性 Attention 期待过高。

线性 Attention 的介绍可以参考《线性 Attention 的探索:Attention 必须有个 Softmax 吗?》,这里不做重复。简单来说,线性 Attention 就是按照 的顺序算注意力。所以按照前面的估算方法,线性 Attention 每个头运算的计算量就是 ,而标准 Attention 则是 ,因此如果 ,那么线性 Attention 是比标准 Attention 要省计算量的。(注:实现线性效率的 Attention 也不止这一种思路,但总的而言复杂度是相似的,因此下面的结论也有代表性。)
对于 base 版来说,那就是 ,这个界还是很容易达到的,所以有些读者可能会想“能省一点是一点”、“不用白不用”。然而,这是假设了标准 Attention 与线性 Attention 都用同一个 d 的前提下得出的结果。
而认真琢磨过《Performer:用随机投影将 Attention 的复杂度线性化》、《Transformer 升级之路:3、从  Performer 到线性 Attention》的读者都知道,线性 Attention 有着比标准 Attention 更严重的“低秩瓶颈”,所以如果切换为线性 Attention 后还用同一个  d,那么线性 Attention 的效果将会明显下降,而如果要保留大致相同的效果,那么线性 Attention 要用更大的 d(一般是原来的 4 倍左右)。
这样一来,线性 Attention 的计算量应该是 ,如果线性 Attention 要比标准 Attention 快,那么就要 ,对于 base 版来说,就是 ,这也超出了一般读者所能用到的范围了。况且换成线性 Attention 后,前面关于 SA 和 FFN 的计算量结论依然存在,即大部分序列长度下占主导计算量的还是FFN等线性运算,换了线性 Attention 后也无法感觉到明显的速度提升。所以,总的来说

你要不是成千上万的序列长度,就不要想着换线性 Attention 了。


再翻翻论文

事实上,就算不进行上述分析,只要认真读过关于 Attention 效率改进相关工作的读者,从论文中的某些图片就可以得到类似的结论:所谓更“高效”的 Attention,一般都只适用于成千上万的序列长度,只有在这个场景下性能才有明显提升。

比如较早的工作 Sparse Transformers [4],里边有一张图显示出处理的序列长度都是 3000+ 的:
比如大名鼎鼎的 Reformer [5],演示性能的序列长度都是以 K 为单位的:

大家颇多好评的 Longformer [6] 也是如此:

还有 Google 关于线性 Attention 的经典之作 Performer [7],显示出哪怕序列长度是 ,Performer 与 Transformer 的差距也不能说特别显著:

最后是比较新的工作 Luna,提供了一个比较综合的对比表格,同样支持我们的结论:

从已有的各个高效 Attention 的工作中,我们可以得出结论:这些改进工作所关心的序列长度主要都是以千为单位的,有明显计算效率提升的序列长度基本上都要好几千;当然,我们前面的讨论主要针对的还是时间复杂度,对于空间复杂度,也就是显存占用量,降低的幅度一般要比时间复杂度提升的幅度的要大,但总体而言都是长序列才有价值。


换个期待吧

所以,如果你的序列长度还只是一两百,那么就完全不要期望 Attention 本身的改进了,老老实实换个小点的模型就好。你可以期望未来会有更小的模型能达到同样好的效果,但是不要期望同样大的模型通过修改 Attention 来提升效率,因为说白了,就算把 Attention 完全去掉,也提升不了多少性能。

参考文献

[1] https://arxiv.org/abs/1706.03762
[2] https://arxiv.org/abs/2103.03404
[3]https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/NEZHA-TensorFlow
[4] https://arxiv.org/abs/1904.10509
[5] https://arxiv.org/abs/2001.04451
[6] https://arxiv.org/abs/2004.05150
[7] https://arxiv.org/abs/2009.14794


特别鸣谢

感谢 TCCI 天桥脑科学研究院对于 PaperWeekly 的支持。TCCI 关注大脑探知、大脑功能和大脑健康。


更多阅读




#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编




🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧



·

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存