NeurIPS 2022 | 仅需3分钟!开源Transformer快速训练后剪枝框架来了
论文标题:
A Fast Post-Training Pruning Framework for Transformers
https://arxiv.org/pdf/2204.09656.pdf
https://github.com/WoosukKwon/retraining-free-pruning
3. 基于线性最小二乘的掩码微调(Mask Tuning),以找到最优实值掩码,尽可能恢复模型性能。
▲ 图1. 本文剪枝框架示意图
方法
在不进行再训练的基础上,该问题可分解为三个子问题:1)确定各层注意力头和中间层神经元的修剪比例;2)确定各层注意力头和中间层神经元的修剪位置,即获得最优二值掩码;3)确定最优实值掩码,尽可能恢复模型性能。三个子问题分别由 2.2, 2.3, 2.4 三个小节依次解决。
for num_heads in range(1, num_hidden_layers * num_attention_heads + 1): # 遍历所有的可能的头修剪数量
heads_mac = mac_per_head(num_patches, hidden_size, attention_head_size) * num_heads # 计算修剪头减少的 FLOPs
neurons_mac = max_mac - heads_mac # 按照 FLOPs 约束, 计算神经元需要减少的 FLOPs
num_neurons = int(neurons_mac / mac_per_neuron(num_patches, hidden_size)) # 计算最少需要修剪的神经元数量
num_neurons = max(num_neurons, 0)
# 贪心策略修剪重要性得分最小的元素
total_importance = sorted_head_importance[:num_heads].sum() + sorted_neuron_importance[:num_neurons].sum()
if total_importance > max_importance: # 记录全局最优的头/神经元修剪比例
max_importance = total_importance
head_indicies = sorted_head_indicies[:num_heads]
neuron_indicies = sorted_neuron_indicies[:num_neurons]
2.3 掩码重排列
费雪信息矩阵的块对角近似:掩码搜索阶段忽略了不同掩码变量间的相互作用,这虽然简化了问题,但也导致了次优解。例如,同一层的两个注意力头有类似的作用,只修剪其中一个对性能的影响不大,但当两者都被修剪时,模型的性能显著降低。因此作者在这一阶段考虑了每个 MHA 层或 FFN 层内部的相互作用,以找到更好的修剪位置。
# Greedy search
masked_indicies = indicies[:num_pruned] # 单一 MHA/FFN 层已修剪的元素索引
for index in indicies[num_pruned:]: # 遍历未修剪的元素
masked_indicies.append(index) # 将其加入已修剪元素列表
grad_vectors = grads[masked_indicies]
grad_sum = grad_vectors.sum(dim=0)
complement = grad_sum - grad_vectors
grad_sum_length = complement.pow(2).sum(dim=1)
removed = grad_sum_length.argmin() # 选出最重要的元素移出已修剪列表
del masked_indicies[removed]
2.4 掩码微调
1. FLOPs 和准确率比较:如图 3 所示,在没有任何再训练和大幅降低修剪成本的情况下,本文的方法取得了与之前方法相当或更好的结果。
▲ 图3. 与过往方法的压缩性能对比
▲ 表1. 与过往方法的修剪成本对比
▲ 表2. 三阶段消融实验
参考文献
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:hr@paperweekly.site
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧