NeurIPS 2022 | 稀疏且鲁棒的预训练语言模型
收录会议:
论文链接:
代码链接:
背景及动机
对于模型参数量问题,一些工作尝试用稀疏子网络代替 PLM。[1,2,3] 将微调后的 PLM 剪枝为稀疏子网络。[4,5,6,7] 采用了彩票假设 [8] 的设定,直接剪枝未经微调的 PLM,并把得到的子网络在下游任务进行微调。更进一步,[9] 发现 PLM 中本身就包括一些子网络,它们可以直接用于下游任务测试,而无须对模型权重进行任何微调。图 1 展示了这三种微调-剪枝流程。
▲ 图1 通过不同微调-剪枝流程得到的PLM子网络,在in-distribution和out-of-distribution两种场景下测试。
虽然近期工作在上述两个问题上都取得了不错的进展,但是还很少有工作对 PLM 的高效性和鲁棒性同时进行探究。然而,为了促进 PLM 在真实场景中的应用,这两个问题是需要被同时解决的。因此,本文将 PLM 剪枝研究扩展到了 OOD 场景,在上述三种微调-剪枝流程下,探究是否存在既稀疏又对 dataset bias 鲁棒的 PLM 子网络(Sparse and Robust SubNetworks, SRNets)?
BERT剪枝及去偏
2.1 BERT子网络
2.2 剪枝方法
基于权重的剪枝 [8,10] 移除绝对值最小的模型参数。通常,剪枝和训练是交替进行的,这整个流程也叫做迭代权重剪枝(Iterative Magnitude Pruning, IMP):
将完整模型训练至收敛。 将一部分绝对值最小的模型参数移除。 训练剪枝后的子网络。 重复 1-3,直到子网络达到目标稀疏程度。
2.2.2 掩码训练
标准交叉熵损失:计算主模型预测概率 和正确类别独热分布(one-hot) 之间的交叉熵 。 Product-of-Experts(PoE)[13]:先结合主模型和偏见模型的预测概率 和 ,再计算交叉熵 。 样本重权重 [14]:直接用偏见程度调整每个样本损失函数值的权重,给予偏见程度高的样本低权重 。
置信度正则化 [15]:这是一种基于知识蒸馏的方法,需要一个用标准交叉熵训练好的教师模型。教师模型的预测分布 给学生模型(主模型)提供监督信号,同时用偏见程度调整每个样本损失值的权重:
搜索稀疏且鲁棒的BERT子网络
3.1 实验设置
模型:我们主要以 BERT-base 为研究对象。同时,为了验证结论的普适性,我们还在 RoBERTa-base 和 BERT-large 上进行了部分实验,相关结果请参见论文。
任务及数据集:本文在三个自然语言理解任务上进行了实验:自然语言推断(Natural Language Inference, NLI),释义识别(Paraphrase Identification)和事实验证(Fact Verification)。每个任务都有一个 in-distribution(ID)数据集和一个 out-of-distribution(OOD)数据集。其中 ID 数据集中存在 dataset bias,而 OOD 数据集在构建时去除了这些 bias。
NLI:MNLI 作为 ID 数据集,HANS 作为 OOD 数据集。 释义识别:QQP 作为 ID 数据集,PAWS-qqp, PAWS-wiki 作为 OOD 数据集。 事实验证:FEVER 为 ID 数据集,FEVER-symmetric(v1,v2)为 OOD 数据集。
3.2 微调后搜索BERT子网络
▲ 图2 标准交叉熵微调后搜索BERT子网络的效果
我们用四种不同的剪枝-损失函数组合对微调后的 BERT 进行压缩。这里我们只展示标准交叉熵和 PoE 两种损失函数,关于样本重权重和置信度正则化请参见我们的论文。
如果在剪枝过程中采用标准交叉熵损失,即 mask train(std)和 imp(std),得到的子网络相比完整 BERT 在 HANS 和 PAWS 上总体有略微的提升。这可能是因为部分和表层特征相关的参数被剪枝了。
如果在剪枝过程中采用 PoE 去偏方法,即 mask train(poe)和 imp(poe),我们可以获得 70% 稀疏的子网络(保留 30% 参数),它们在 OOD 数据集上比完整 BERT 有显著的提升,并且在 ID 数据集上保持了95%以上的性能。这说明在剪枝过程中采用去偏损失函数对于同时实现压缩和去偏的目标是很有效的。另外,由于掩码训练没有改变模型参数值,mask train(poe)的效果意味着带有偏见的 BERT 模型中本身就存在鲁棒子网络。 在两种训练损失下,掩码训练的总体效果都优于 IMP。
3.2.2 PoE去偏方法微调BERT
现在我们对 PoE 微调的 BERT 进行剪枝,看看子网络的效果如何。因为前一小节已经发现在剪枝过程中使用 PoE 比标准交叉熵效果好,此处我们仅考虑 PoE。
从图 3 的结果中,我们可以看出:
和带有偏见的 full bert(std)不同,在已经较为鲁棒的 full bert(poe)中,子网络的 OOD 效果没有显著超越完整 BERT。但是在较高的 70% 稀疏程度下,子网络的 ID 和 OOD 效果仍然没有很多下降,保持了完整 BERT 95% 以上的性能。
从带有偏见的 BERT 中搜索到的子网络(图 3 中橙色曲线)和较为鲁棒 BERT 中搜索到的子网络(图 3 中蓝色曲线)效果没有很大差距。这意味着只要在剪枝过程中引入了去偏方法,对完整BERT的去偏就不是必须的。
3.3 单独微调的BERT子网络
根据图 4 中的结果,我们发现:
如果采用标准交叉熵微调,在 20%~50% 稀疏程度内,用掩码训练搜索到的子网络 mask train(poe)的 OOD 效果要优于完整 BERT。这说明拥有较为鲁棒结构的子网络在训练过程中,相比完整 BERT 更不容易受到 dataset bias 的影响。 如果采用 PoE 微调,在 70% 稀疏程度内,掩码训练和 IMP 剪枝得到的子网络和完整 BERT 的效果相当。 结合以上两点,我们可以吧 BERT 中的彩票假设 [4,5] 推广到 OOD 场景:预训练 BERT 中包含了一些子网络,它们可以在下游任务上用标准交叉熵或 PoE 去偏方法单独微调,并且在 ID 和 OOD 场景下取得和微调完整 BERT 相当的效果。 对比标准交叉熵微调和 PoE 微调,在所有情况下后者的 OOD 性能都有明显优势。这说明即使子网络已经学习到了较为鲁棒的结构,对于参数值的去偏训练仍然是重要的。
3.4 不微调参数的BERT子网络
在本小节中,我们直接在预训练 BERT 参数(包括随机初始化的下游分类器参数)上进行掩码训练而不对参数值进行微调(对应图1,(c))。在掩码训练中,我们采用了标准交叉熵和 PoE 两种损失函数。
从图 5 中,我们发现:
在掩码训练过程中采用交叉熵得到的子网络(50% 稀疏程度以下)和同样用交叉熵训练出的完整 BERT 效果相当。 对于 PoE 损失函数也是同样的结论。这说明未经微调的预训练 BERT 中本身已经存在适用于特定下游任务的 SRNets。 对比预训练 BERT 中的子网络(图 5 中绿色曲线)和微调后 BERT 中的子网络(图 5 中橙色曲线),我们发现后者总体效果略优于前者,但是二者差距较小。这引出了一个比较有意思的问题:是否有必要首先把完整 BERT 微调至收敛,再开始掩码训练?在 3.6.1 小节中将对这个问题进行进一步探究。
3.5 利用OOD数据搜索无偏的oracle子网络
通过 3.2-3.4 小节的实验分析,我们已经发现在不同的微调-剪枝流程下都存在稀疏且鲁棒的 BERT 子网络。本小节希望探究这些子网络 OOD 性能的上界。为此我们利用 OOD 训练集进行掩码训练搜索 oracle 子网络,而用和训练集没有重合数据的 OOD 测试集进行测试。和之前的实验一样,我们探究了三种微调-剪枝流程下的效果。为了反映子网络结构本身的去偏能力,在对子网络参数进行微调时(对应图1,(b)的设置),我们采用的是对 dataset bias 较为敏感的标准交叉熵损失。
在 BERT 微调前后搜索到的 oracle 子网络(分别对应图 6 中 bert-pt subnet 和 bert-ft subnet)在一定稀疏范围内可以取得很高的 OOD 性能。特别地,20%~70% 稀疏的 bert-ft subnet 在 HANS 上都取得了 100% 的准确率。 如果用标准交叉熵对 oracle 子网络参数进行微调,它们的 OOD 效果会有一定的下降。然而相比完整 BERT,这些 oracle 子网络在训练过程中对 dataset bias 明显更加鲁棒。 以上发现说明,BERT 子网络对 dataset bias 鲁棒性的上界很高,在不同微调-剪枝流程下理论上都存在几乎无偏的 BERT 子网络。
3.6 改进掩码训练方法
相比于在微调后的 BERT 中搜索子网络,在微调前进行搜索的总体训练开销更小。在 3.4 小节中我们发现,二者的最终效果存在一些差距。那么,我们是否可以在训练开销和子网络效果之间找到一个更合适的 trade-off?为此我们进行了一系列的实验,从不同微调程度的完整 BERT checkpoints 开始掩码训练。
3.6.2 逐渐提升稀疏程度
▲ 图8 掩码训练中固定子网络稀疏程度和逐渐提升稀疏程度对比。
在本文中,对于预训练语言模型 BERT,我们探究能否同时实现其子网络的稀疏性和鲁棒性。通过在三种自然语言处理任务上进行大量的实验,我们发现在三种常见的微调-剪枝流程下,的确存在稀疏且鲁棒的 BERT 子网络(SRNets)。进一步利用 OOD 训练集,我们发现 BERT 中存在对特定 dataset bias 几乎无偏的子网络。最后,针对掩码训练剪枝方法,我们从开始剪枝的时刻和掩码训练过程中子网络稀疏程度的控制两个角度,对子网络搜索的效率和效果提出了改进的思路。
在我们工作的基础上,仍然有几个方向值得继续改进和探究:
本文只探究了 BERT 类型的 PLM 和自然语言理解任务。在其他类型的 PLM(例如 GPT)和 NLP 任务(例如自然语言生成)中也可能存在 dataset bias 问题。在这些场景下实现 PLM 的压缩和去偏也是很重要的。 本文采用的整个微调-剪枝流程仍然有很大的优化空间。例如 3.6.1 小节中提到的,事先对开始掩码训练的时刻进行精确预测也是一个有意思的研究方向。
参考文献
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:hr@paperweekly.site
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」