自回归与非自回归模型不可兼得?预训练模型BANG全都要!
编者按:近两年,预训练技术的发展极大地提高了自然语言生成的效果,但随着数据量和模型大小的增加,模型在使用时的推断耗时也随之变大。为了降低自回归生成的时延,并行预测目标语句所有单词的非自回归模型被提出。然而,非自回归和半非自回归的依赖关系学习和生成难度较大,它们的生成质量往往弱于自回归模型。针对上述问题,微软亚洲研究院的研究员们提出了新的自然语言生成预训练 BANG。
近两年,预训练技术为自然语言生成的效果带来了极大的改善。基于 Transformer,更大的模型,更大的预训练语料往往可以在下游任务上提供更好的结果。与此同时,模型在使用时的推断耗时也随之变大。这些预训练工作往往针对自回归语言生成模型设计,自回归每次会使用已生成的序列作为已知信息预测未来的一个单词,最终再把每个时间步生成的单词拼成一个完整的序列输出。这其中的时延成为了线上使用或者实时使用这些预训练的自然语言生成模型的瓶颈。
非自回归模型的提出缓解了自回归模型的高时延问题。在非自回归模型中,每个单词之间没有依赖关系,整个输出序列的每个单词被并行地同步预测。虽然其推断速度得到了很大改善,但是生成质量却往往弱于自回归模型。为了平衡推断速度和生成质量,半非自回归的模型被提出和研究。半非自回归的经典做法是把非自回归生成的结果进行多次迭代,但不同半非自回归模型的算法差异比较大。由于和自回归相比,非自回归和半非自回归的依赖关系学习和生成难度较大,所以它们往往在文本-文本翻译,或者语音-文本翻译,文本-语音翻译等输入输出较为对齐的任务上可以提供不错的生成效果,但是很少在问答、对话、摘要等任务上进行研究,而这些领域被自回归生成验证可以拥有不错的生成质量且在预训练下得到提升。
针对上述问题,微软亚洲研究院的研究员们提出了新的自然语言生成预训练 BANG,并指出自回归和非自回归生成可以被统一地理解为,有多大比例的上文信息可以被使用。BANG 的贡献主要有:
1)BANG 在大规模预训练中,通过考虑遮盖任意长度的前文来沟通自回归和非自回归生成;
2)提出跨流可见的多流注意力机制来实现高效的预训练,所有单词在考虑到任意长度前文被遮盖的前提下都可被并行预测;
3)对于不同的需求状况,BANG 支持自回归微调,非自回归微调和半非自回归微调。BANG 第一次把不同的生成方案在同一个预训练模型里进行支持;
4)研究员们在 16GB 的英语语料上进行了预训练,在摘要、对话、问题生成上,BANG 对自回归效果和半非自回归效果带来了显著的提升,并达到了与非预训练的 Transformer 自回归模型相似的评测结果。对于自回归生成的微调,BANG 也可以和当前主流的自回归预训练模型达到相似的结果。
总体结构
基于 Transformer 编码器-解码器的序列生成框架,BANG 由多层堆叠的使用自注意力机制的 Transformer 编码器和多层堆叠的使用跨流可见多流自注意力机制的 Transformer 解码器组成。研究员们考虑了使用输入序列 X={x_1,x_2,…,x_(|X|)},生成预测目标序列 Y={y_1,y_2,…,y_(|Y|)} 的过程。
首先,编码器将输入序列解码编码为隐状态 H_enc。
在解码器端,对于 Y 中的每个单词 y_t,解码器都会产生将前文中的任意长度前缀遮盖后的预测概率:
而 BANG 目标序列的条件生成概率和优化的语言模型则可描述为:
BANG 会优化 Y ̂ 而非原始的输出序列 Y。对 Y 中的每个单词 y_t,Y ̂ 都会考虑对任意 i<t,并用 [MASK] 遮盖掉 y_t 上文的前 i 个单词。可以看到,BANG 的优化目标由三部分组成,自回归部分,非自回归部分和沟通部分。自回归部分和非自回归部分直接优化了下游任务,而沟通部分则设计了一个从自回归到非自回归的课程学习路径。
跨流可见多流自注意力
为了实现上述的优化目标,且高效并行化计算,研究员们提出了跨流可见多流自注意力机制。以预测 y_4 为例,如图1:
图1:BANG 预训练中的信息流
在图1 BANG 预训练中的信息流中,M-S 指主要流(main stream),喂入真实的字符;P-S 指预测流(predicting stream),喂入 [M]([MASK])。P-S 中的 [MASK] 向 M-S 和它之前的 P-S 进行注意力计算来获取前文的真实单词 +[MASK] 字符的信息。
图1最上面的一行展示了主要流和第一个预测流。预测 y_4 使用的 [M] 向主要流中的 y_1,y_2,y_3 进行注意力计算,即 y_4 以条件概率 P(y_4 |y_1,y_2,y_3) 进行预测,其效果如左侧所示。第一个预测流中的所有字符以完整的前文信息进行了自回归的预测。
图1中的第二行则展示了 y_4 在第二个预测流中的效果。第二个预测流中,每个被预测的单词所看到的前文信息都被遮盖住了一个字符,即如左侧所示,y_4 看到真实的 y_1 和 y_2,但是 y_3 被 [M] 遮盖。其实现如右侧的主要流和两个预测流所示。第二个预测流中的 [M] 向主要流的 y_1,y_2 以及第一个预测流中 y_3 的 [M] 进行注意力计算。第一个预测流 y_3 的 [M] 与第二个预测流中的 y_4 则组成了条件概率 P(y_3,y_4 |y_1,y_2)。比较第一行和第二行,可以看到,随着注意力流的增大,前面的上文信息被遮盖,生成方式也从自回归向非自回归移动。
图1中最后一行展示了 y_4 在第四个预测流中,最终以非自回归的方式进行预测。此时第四预测流中预测 y_4 的 [M] 向第一个预测流中 y_1 的 [M],第二个预测流中 y_2 的 [M] 和第三个预测流中 y_3 的 [M] 进行注意力计算,此时没有任何真实的上文信息被使用。
可以看到,第一个预测流中,每个单词都以自回归进行预测;每个预测流中的第一个单词以非自回归进行预测;其他位置则以介于自回归和非自回归之间的方式进行预测。假设目标序列长度 |Y|=n,则 BANG 设置 n 个预测流,此时每个词的任意长度前缀被 [M] 替换的情形都在同一个时间步中被进行并行的预测。
为了优化 GPU 的显存占用和计算量,BANG 采用了成块的计算方案。因为每个位置只会看到它之前的预测流信息,所以 BANG 从第一个预测流向最后一个预测流进行计算,将重复计算的 K 和 V 向量缓存下来。在第 l 层的工作流程如下:
其中,Linear 是从隐状态中获取 Q,K, V 向量的三个线性计算函数,⊕ 代表拼接操作,Attn 函数则可以描述为:
其中,L 为相对位置偏差和控制哪些位置可以被看到的遮盖矩阵。
微调策略
继续以预测 y_4 为例,来看一下针对自回归、非自回归、半自回归的微调策略。在 BANG 自回归生成微调中,预测流中的 [M] 可以从主要流中获取完整的前文信息。其训练方式同 XLNet 的双流机制。
图2:BANG 自回归微调
在 BANG 的非自回归微调中,只有一个预测流,并放置若干个 [M],使用单向信息流,与预训练一致。最后以第一个结束符 [SEP] 代表生成作为结束。
图3:BANG 非自回归微调
而在 BANG 的半非自回归微调中,训练过程同预训练方案,推断过程如图4所示,可以进行任意步数的自回归生成,作为高质量的上文线索,然后将剩余部分并行生成。
图4:BANG 半非自回归生成
主实验
BANG 使用了 Wikipedia 加 BookCorpus 的 16GB 英语语料,使用 MASS 的连续字段掩盖预测任务进行了 BANG_base 的预训练。对于每个连续的64个单词的片段,会掩盖其中连续的15%即9个单词,用预测其掩盖的部分作为输出。BANG_base 使用了6层编码器、6层解码器、隐状态768和9个预测流进行了35轮的预训练。并使用了 SQuAD 1.1 问题生成、XSum 摘要和 PersonaChat 对话生成作为评测集,进行了自回归、非自回归、半自回归的对比,结果如下:
表1:SQuAD 1.1 问题生成的实验结果
表2:XSum 摘要任务的实验结果
表3:PersonaChat 对话生成的实验结果
可以看到,BANG 对于非自回归和半非自回归的效果提升非常明显,推断速度基本相似,而对于自回归模型的效果与当前主流的自回归预训练模型也达到了相似的水准。BANG 非自回归的结果达到了未预训练 Transformer 的相似水平,并带来了约十倍的推断速度提升,这表明,通过预训练,非自回归也可以在普通的自然语言生成任务上得到不错的生成质量。
与非自回归预训练对比
因为 BANG 是非自回归的第一个大规模语料的预训练工作,所以在表1-表3中的非自回归和半非自回归的对比模型是没有经过预训练的。为了验证 BANG 对于非自回归生成预训练的有效性,研究员们使用了非自回归的方案进行了预训练并与 BANG 进行对比:
表4:SQuAD 1.1 问题生成上,没有预训练、非自回归预训练和 BANG 预训练的对比
表5:Xsum 摘要任务上,没有预训练、非自回归预训练和 BANG 预训练的对比
可以看出,预训练可以显著提升非自回归的生成结果,而经过相同的非自回归微调,BANG 一致地超过了纯非自回归预训练结果。这表明,BANG 所提出的沟通自回归和非自回归的预训练方案是取得更好结果的原因。
案例分析
本文作者:齐炜祯、宫叶云、段楠
论文链接:(将于近日更新)
BANG: Bridging Autoregressive and Non-autoregressive Generation with Large Scale Pretraining
https://arxiv.org/abs/2012.15525
近期,研究员还将开源代码,敬请关注:
https://github.com/microsoft/BANG
你也许还想看: