自我蒸馏方法-减轻大模型微调过程中的灾难性遗忘
写在前面
大家好,我是刘聪NLP。
大模型在指定任务上进行微调后,会取得较为不错的效果,但同时可能带来模型原有能力的下降。今天给大家带来一篇通过自我蒸馏减轻大模型微调时的灾难性遗忘的方法-SDFT(Self-Distillation Fine-Tuning)。
Paper: https://arxiv.org/abs/2402.13669
Github: https://github.com/sail-sg/sdft
特定任务微调后导致模型遵循通用指令能力变弱的主要原因是任务数据集的信息分布与原始LLM的信息分布之间存在差距。目前主流解决大模型微调后灾难行遗忘的方法是在微调过程中加入通用的指令数据。
而自我蒸馏方法主要是通过模型本身对任务数据进行生成引导,构建自我蒸馏数据集,改变任务数据的信息分布,减少与原始模型信息分布的差距,如下图所示。
方法介绍
大模型SFT过程,就是将指令和输入的上下文内容,映射到相应的输出上,最小化数据信息分布与语言模型信息分布之间的差异,如下:
其中,表示模型训练参数,表示输入上下文内容,表示指令内容,表示模型输出。
SDFT方法首先根据原始大模型对微调指令数据进行生成回复内容修改,将任务数据的指令回复结果映射到大模型分布内的回复结果,
在重写过程中,减少对大模型的额外要求,仅让其重新回复结果,自我蒸馏提示模板如下图所示,
然后为了确保蒸馏的回复内容质量,采用简单的启发式方法来评估蒸馏的回复内容。例如,在数学推理问题中,如果可以从蒸馏的回复内容中提取出最终答案,则采用蒸馏的回复内容;否则保留原始回复内容。
PS: 这里跟作者交流过,实际上仅数学任务采用了这种责令,其他不好判断的任务默认蒸馏效果准确。
最后,采用蒸馏后的回复内容替换原始回复内容用于大模型微调,
实验结果
所有实验均利用Llama-2-chat-7b模型,采用Lora方法训练,学习率初始为1e-4,按照余弦调度策略衰减到0,训练批次大小为8。
数据集涉及单任务和多任务两种数据:
单任务:OpenFunctions、GSM8K和MagiCoder; 多任务:Alpaca、Dolly和LIMA;
模型在评估过程中,利用Advbench榜单进行安全性评估,利用AlpacaEval榜单进行实用性评估,利用OpenLLM榜单进行知识评估。
如下表所示,普通微调虽然可以增强模型在目标任务上的效果,但也会导致在其他任务上性能显著下降。而SDFT可以有效缓解这种性能下降,甚至会有效果提示。
如下表所示,在Chat模型上进行普通任务微调,会导致模型对齐效果丧失,也就是安全性下降,而SDFT方法可以有效缓解。
有趣的是,虽然微调会对下游任务有影响,但对模型本身的知识能力影响较小。
分析自我蒸馏数据占比对微调的影响,如下图所示,当自我蒸馏数据占比越高时,效果越好。
如果将自我蒸馏数据与原始数据混合进行训练,发现与混合比例为50%时持平或略低,也可以从侧面体现反应数据质量的作用大于数据数量。
对微调后模型与原始模型指令生成结果及嵌入向量进行分析,发现普通微调方法随着数据量增加,偏移越严重,而SDFT方法微调后的模型嵌入偏移更小。
写在最后
自我蒸馏方法在不引入额外数据的情况下,可以极大程度的减轻模型的遗忘现象。后期可以利用外部模型,将自我蒸馏数据保留机制进行完善,说不定会有意想不到的效果。
PS:给公众号添加【星标⭐️】不迷路!您的点赞、在看、关注是我坚持的最大动力!
欢迎多多关注公众号「NLP工作站」,加入交流群,交个朋友吧,一起学习,一起进步!
我们的口号是“生命不止,学习不停”!
往期推荐:
Yi技术报告细节分享 大模型增量预训练新技巧-解决灾难性遗忘 如何提高LLMs的文本表征(Text Embedding)能力? DEITA-大模型指令微调的数据高效筛选方法 大模型微调技巧 | 高质量指令数据筛选方法-MoDS 辟谣!微软撤回声称ChatGPT为20B参数的论文,并给出解释。 如何看待微软论文声称 ChatGPT 是 20B (200亿) 参数量的模型? 大模型微调技巧-在Embeeding上加入噪音提高指令微调效果 如何从数据集中自动识别高质量的指令数据 BaiChuan2技术报告细节分享&个人想法 大模型LLM微调经验总结&项目更新 打造LLM界的Web UI 是我们在训练大模型,还是大模型在训练我们? Llama2技术细节&开源影响 大模型时代-行业落地再思考 垂直领域大模型的一些思考及开源模型汇总 如何评估大模型-LLMs的好坏? 总结|Prompt在NER场景的应用