比大更大:Pathways上实现的大语言模型PaLM
摘要
不久前Google推出了多模态AI构架Pathways[1],试图一次性处理文本、图像、语音等多种形式信息,同时以更稀疏、高效的方式表达模型,以达到更敏锐、更准确的效果。从实现上来看,Pathways采用了控制平面并行执行的异步分布式数据流设计,更容易表达复杂的新并行模式。上周Google又推出了基于Pathways训练的第一个大模型PaLM[2],该模型含有5400亿参数,1-shot的训练结果在新数据集BIG-bench上达到人类平均水平。与现有NLP大模型相比,PaLM从多个角度对模型进行优化。本文深入剖析PaLM模型,旨在探讨新型AI框架下的大模型训练。
Pathways的很多重要思想来源于现有系统,包括用于表达和执行TPU计算的XLA、用于表征和执行分布式CPU计算的TensorFlow图和执行器、基于Python编程框架的JAX以及TensorFlow API。通过有效地使用这些模块,Pathways不需要对现有模型进行很多改动就能运行。
Pathways包含1)一套后台加速器,通过紧密耦合的岛(island)对主机进行分组,有向图表示为图1(a);2)资源管理器负责中心调度不同岛上的设备, 如图1(b)。客户端可以请求与之通信相匹配的具体2D或者3D网格形状的虚拟切片,每个虚拟切片允许每个客户端来决定如何在网格上进行计算,如图1(c)。而资源管理器会动态给虚拟设备按照互联拓扑结构、内存容量等来分配物理设备。
与GPT系列模型一样,PaLM模型也是通过堆叠Transformer中的Decoder部分而成。总体来看,PaLM吸引人眼球的是该模型具有5400亿参数以及采用新一代AI框架Pathways训练。此外,模型结构也给出了很多方面优化,这些技术优化工作汲取了现有突出的研究成果,具体包括SwiGLU激活函数代替ReLU、层并行技术(Parallel Layers)、多查询注意力(Multi-Query Attention)、旋转位置编码(RoPE)、共享输入和输出词嵌入、去掉偏置参数(No Biases)等。值得注意的是,SwiGLU激活函数需要三个矩阵乘操作,虽然增加了GEMM操作,但有助于提升计算性能;旋转位置编码也已经出现在之前的NLP大模型。在这些改进方案中,本文重点剖析计算方面的层并行、多查询注意力机制的优化策略。
为了实现模型并行,很多研究工作围绕Transformer中的组成模块(Attention+MLP)展开,比如大模型典范Megatron[4]将MLP、Attention中的矩阵按照行、列进行拆分,实现相应行列的并行计算。在一个前向计算过程中,需要完成两次的数据整合操作(Attention、MLP层各一次)。然而,每层计算仍然采用传统的串行方式,即先计算Attention,然后将Attention的计算结果作为输入输送给MLP层,如图2中X标注的数据流向。
与Megatron不同,PaLM从层间计算的并行角度考虑,通过将MLP和Attention共享输入实现MLP、Attention的并行计算,如图2中粗红线标注。文[2]指出这种方法能在大规模训练中获得15%的提速,而且当模型达到62B之后没有性能损失。
标准Transformer采用多注意力机制,每个时间点的输入张量与三个相同大小的权重矩阵相乘,得到同样大小的Q、K、V。PaLM保留了多头注意力机制,但对K、V在注意力头之间实现了参数共享。从实现上来看,针对n个注意力头,Q的向量维度为[n, m],其中m为每个注意力头的大小。在传统Transformer中,K、V的维度与Q相同,在PaLM中,这个维度由[n, m]缩小到[1,m],也就是说,K、V张量在不同注意力头上是不变的,如图3所示。
这种改进虽然没有带来训练提速和性能提升,但大大缩短了自回归解码时间。主要是因为标准多头注意力在自回归解码中在硬件加速器上非常低效,因为K、V张量在样本中没有共享,一次只能解码一个token。
层并行技术、共享K、V多注意力机制在文[5]已经实现,是目前PaLM首个开源实现。如果要最大程度地利用现有大算力平台,实现上也要考虑并行计算策略,如DeepSpeed,相关实现方法可参考文[6]。
PaLM模型采用了两个TPU v4 Pod来完成540B参数的训练,如图4所示。每个Pod中含有3072个TPU v4芯片,整个模型共用了6144个芯片,两个Pod间通过DCN实现数据并行。
每个Pod复制一份模型参数,每个权重张量通过12路模型并行、256路分片数据并行划分到3072个芯片上。在前向计算中,权重在数据并行的维度进行聚合,每层保存一个完整的分片激活张量。在后向计算中,其他激活函数被复制,与其他重新计算方法相比,在使用较大的batch size时,这种方法能带来更大的吞吐量。
得益于巨大的硬件配置和Pathways架构,整个训练过程中没有用其它大模型训练的流水线策略。这样避免了训练流程中的bubble,提升了整体效率。
通过上述的训练方式,PaLM在新的BIG-bench基准上达到了平均人类水平,在0-shot、few-shot上也超过了GPT3。在具体下游任务上,PaLM在复杂逻辑推理任务上取得了突破性的语言熟练能力,如图5所示的笑话解释和逻辑推理。
图5:推理任务(来源:文[2]中的图1)
作为新AI框架下的NLP大模型PaLM,充分利用了现有模型的各种改进策略,尤其是Transformer中的层并行、多头K、V共享等方法。同时,也启用了迄今为止最大的基于2-pod 的TPU配置,规模之大从而可以不采用流水线训练策略。PaLM为Pathways交出了第一份答卷,这份答卷由Jeff Dean亲自带领的近百人团队费时一年多时间完成,尽管采用了风能等可再生能源和近80%的无碳能源,仍然产生了GPT3一半的CO2排放量。模型规模之大、训练配置之精、环境影响之大再一次说明了大模型训练目前还是土豪游戏,离普适性还有非常长的一段路要走,更不用说其所谓的人类平均水平的参考标准也只是从众筹中筛选出来的10样本,样本代表性也存在质疑。但不管怎么说,PaLM毕竟为噱头上的Pathways提交了一份答卷,考评就交给各位看官了。
由于水平有限,文中存在不足的地方,请各位读者批评指正,也欢迎大家参与我们的讨论。
[1] Austin Derrow-Pinion, Jennifer She, David Wong, et al. ETA Predictionwith Graph Neural Networks in Google Maps. 2021
[1] Introducing Pathways: A next-generation AI architecture:https://blog.google/technology/ai/introducing-pathways-next-generation-ai-architecture/
[2] Chowdhery, Aakanksha, et al. "PaLM: Scaling Language Modeling with Pathways." arXiv:2204.02311 (2022).
[3] https://jalammar.github.io/illustrated-transformer/
[4] Shoeybi, Mohammad, et al. "Megatron-lm: Training multi-billion parameter language models using model parallelism." arXiv:1909.08053 (2019).
[5] https://github.com/lucidrains/Palm-pytorch
[6] https://github.com/kingoflolz/mesh-transformer-jax
往期推荐
壁仞科技研究院作为壁仞科技的前沿研究部门,旨在研究新型智能计算系统的关键技术,重点关注新型架构,先进编译技术和设计方法学,并将逐渐拓展研究方向,探索未来智能系统的各种可能。壁仞科技研究院秉持开放的原则,将积极投入各类产学研合作并参与开源社区的建设,为相关领域的技术进步做出自己的贡献。