查看原文
其他

你的大模型为什么训不快?大模型预训练技术精要

InternLM 2024-04-23

本文来源:https://zhuanlan.zhihu.com/p/647395142

作者:盐梅,经授权后发布


大模型训练用时可以拆解成两方面:一个是模型运行时间,这部分要想快就需要买更好的机器,研究fused kernel来加速,一般时间周期长,难度高,这部分后面文章会有专门讨论。第二部分是通讯时间,大模型训练过程中需要不断从其他GPU,其他节点收集梯度信息更新模型。千万不可小看第二部分通讯时间,笔者在默认设置下在自己环境上跑大模型并行,发现通讯时间竟然高达90%,模型训练异常低效,优化后占比降低到20%左右,显著提升了模型训练效率,并且吞吐率到能对齐甚至略高于llama2官方的吞吐率。



训模型常见思路


为啥要做并行方式验证和调优?


是基于你的训练GPU,训练集群的inter-gpu带宽,inter-node带宽,模型大小,训练数据,选择一个最优吞吐率的并行方案(人工加粗加黑,集群不同并行方案就不同)。别人在别人机器上的验证并行方式不一定适用你自己的机器。一般做好并行方式验证调优的话吞吐率可以提升30%-50%。通过并行方式验证和调优,我们就是找到适合自己机器通讯环境的最优模型并行方式,这样可以逼近集群运算能力的上限。


如何确定使用什么样的并行策略?


如果模型不是特别大7B,13B直接数据并行Data parallel(DP)搞起(zero0,zero1,zero2往上尝试),还塞不下然后考虑加Pipeline parallel(PP)。如果模型特别大70B,130B的时候,这时候考虑加Tensor parallel(TP)到2-4-8(一般先设置为2比较合适,优先增加PP),扩大PP直到能塞下模型,如果已经能塞的下模型了,再增大PP速度会降低速度,因为需要通讯量增加了,在笔者的8节点64 A800机器上增加BS塞满显存并不能抵消通讯量带来的速度损失。如果还是实在塞不下再考虑gradient chekpointing和offload等技术。


注意:如果用了Pipeline parallel(PP),就不能用zero2和zero3。因为zero-2/3就将梯度也切分了,但在做梯度累计的时候,流水线并行的时候的需要从不同的nodes汇集梯度会需要巨量的通信开销,所以deepspeed直接禁用了zero2/3 + PP的组合。


一般单GPU batch size (micro batch size)设置为1就行了,大的batch size通过梯度累计实现,预训练过程是一个GPU计算密集型任务,GPU利用率一般都能跑满,所以加大batch size提速不大。更重要的是,显存应该用于存模型训练相关参数,比如模型参数,梯度,优化器参数等。这样可以减少整体通信开销,加快训练速度。有一种场景是可以增大micro batch size的,就是模型本来是塞不下的,我的PP从4增大到8以后可以塞下了,但空出来20%的显存,这时候我可以将micro batch size 从1变成2把多出来的20%显存占满,实验发现也可以提速5%左右,原理是GPU计算更大矩阵乘法效率会更高。


Zero3只适合机器比较少的时候玩一下,扩展到多机就很慢,因为zero3本质上其实是DP+TP,如果我们有64个GPU,就相当于64路TP了,通讯成本太高了限制了训练速度。


为啥优先用PP而不是TP(2以上)呢?


我们知道做Ring-AllReduce通讯量是和参数的通讯节点数量线性相关的。

具体可参考文章:https://zhuanlan.zhihu.com/p/617133971


这个问题和通讯量有关,PP简单来说就是横着切模型,模型切成很多块(chunk),但层还是完整的。这样每次通讯的话只需要在断点出进行块与块之间的通讯,因此PP无论切多少刀,每次只需要在两个节点之间通讯,并且不同chunk通讯是异步的,同一时间的通讯压力小。而TP是竖着切模型,每一层都切开了,通讯的话需要每一层不同块进行all reduce通讯,如果切得太多,则同一时间需要参与通讯的GPU数量就很多,训练就会慢。如果TP=2切一刀那么通讯量和PP是一样的,如果切成4份那么需要在4个节点之间进行同步通讯,通讯量直接翻倍。甚至TP如果切多了,比GPU个数还多了,就需要inter-node的通讯了,这个一般比Intra-node通讯慢很多,自然是无法承受的。



预训练常见技术:红榜


【Flash Attention】 可以提速+降显存,尤其是llama2这种序列较长的降显存非常明显,在Zero2+flash attention可以在单机8卡上放下7B模型,并且训练速度可以提速大约75%,不使用gradient checkpointing可以提速25%,相比于Zero2+gradient checkpointing的组合快了一倍。

具体可参考文章:https://zhuanlan.zhihu.com/p/645238961


【Zero优化器】显存占用:zero0>zero1>zero2>zero3 显存降低大概每一级降低10%左右(具体看模型,以及卡的个数,总之不会多特别多),速度上来看的话zer0>zero1>zero2>zero3 每一级降低5%左右(速度上感觉不到特别明显的差别)。所以huggingface 推荐 直接可以上zero2,因为速度没有慢多少,但可以降低显存比较多。从tflops来看会比较明显 zero 0 > zero 1 > zero 2 。因为优化器中的参数是最多的,一般在megatron框架中加入--use-distributed-optimizer就是zero1阶段就能升下不少显存,实测速度也不会降低。


【Pipeline Parallel(PP)和Tensor Parallel(TP)】在保证能塞得下模型的前提下,PP * TP 应该越少越好,因为多了每个GPU显存降下去了,但通讯时间延长了。但PP和TP怎么配置也是讲究的,TP应该比较小,如果从2开到4吞吐量下降13%左右,2还行明通讯瓶颈一般到2了。PP开的如果太小的话 bubble就会比较多,所以提速不明显。64卡训练13B模型比如 pp2 + tp 1 和pp1 + tp2速度差不多,但pp4 + tp1 比 pp1 + tp4快20%。所以笔者发现13B模型64卡A800训练最优设置的PP 4 + TP 1,同时PP2 + TP 2 效果也非常接近,这个也是因为降低了PP的bubble。一般建议TP大小为Hidden size / TP > 1024。并且笔者的机器是A800机器,机内带宽是比较低的,PP 和 TP 的先后不存在明显的差异。同时TP一般会搭配sequence-parallel进行使用,可以在降低显存的情况下,不增加任何通讯成本。

具体可参考文章:

https://zhuanlan.zhihu.com/p/646406772

https://zhuanlan.zhihu.com/p/626553071



在使用PP的时候我们会发现各个节点的显存占用不一样,比如PP为4的时候显存占用Rank0>Rank1>Rank2>Rank3>Rank4。这个是因为PP不同Rank激活值存储时间不一样。比如下图中的Rank0-F0节点需要存储激活值的时间要比F1/2/3要长,这部分存储的激活值增加了显存占用。同时embedining层也在Rank0节点,比中间节点多占用了一些显存。



【Selective Activation Recomputation】按照huggingface的描述可以在仅仅增加2.7%的FlOPs overhead的情况下对于GPT3可以降低70%的显存,是一个大杀器。


【梯度累积 Gradient Accumulation】 可以打印出来每次forward时间backward时间,梯度做传播的时间,如果发现做梯度传播的时间占整体训练时间1/5以上,那么可以考虑增加global batch size提升gradient accumulation,这样可以减少梯度在各个节点通讯的次数,从而显著提升吞吐量,笔者在256卡A800测试中发现,global batch size增加四倍后吞吐量增加25%,进一步提升训练效率,同时增加global batch size,需要适应调整学习率,一般可以线性增加,比如global batch size增加4倍,LR也要增加四倍,要不然训练的iteration太少,模型更新太慢的话,走不到最优点。



预训练常见技术:黑榜


【Full Activation Recomputation】 为虽然可以降显存,但是速度会慢25%,能不用的时候尽量不用。降显存完全可以用PP 和 TP,尤其是PP对通讯要求不是特别高,并且设置的得当的话还会加速。这个技术更像是一个显存严重不足场景下不得已为之的方法。


【Offload Optimizer/parameters】使用后训练速度会慢非常多,基本慢了一倍。这个技术更像是一个显存严重不足场景下的不得已为之的方法。



一些调参的takeways


如何选择常见的训练策略



不同并行策略如何选择





如何选择合适的Zero优化器

以及Deepspeed相关训练策略








如何在DB-GPT社区使用InternLM?手把手教程来啦!

2023-10-16

详解大模型评测工作流,以OpenCompass为例

2023-10-13

用AI搞一种很新的创作,用InternLM书写和代码之间的爱恨情仇!

2023-10-11

继续滑动看下一个
向上滑动看下一个

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存