关于医疗LLM的随笔
本周在思考基于医疗LLM的应用侧的工作,限于与司内业务关系较强,暂不通过文章讨论。应用侧的工作基本理清之后,接下来需要思考模型侧的问题,简而言之,如何得到一个中文医疗领域的LLM?
为了回答这个问题,首先简要梳理大模型训练的一些关键问题,之后梳理目前能够看到的中文通用大模型和英文医疗大模型以及特定领域的大模型的相关工作,最后给出得到一个中文医疗领域LLM的三步走策略。
1.大模型训练的关键问题
在2019年7月份,写过一篇博客《Pytorch用于大模型训练》(https://zhpmatrix.github.io/2019/07/18/speed-up-pytorch/),时间一晃来到2023年的4月份,“大模型训练”已经比大概3年多之前更加地被人思考和讨论。
大模型训练要解决的两个关键问题分别是把一个大的模型能够装进你的GPU显存和提升训练吞吐。
针对提升训练吞吐,主要以3D并行为主,分别是数据平行,Pipeline并行和Tensor并行。
数据并行。假设有M台机器,将相同的模型复制在其中的N台机器上,训练数据分为N份,每台机器同时处理(前向和反向计算)1/N的数据。在其他的机器上做N台机器的参数的聚合和分发,本质上是MapReduce的思路。Tensorflow和Pytorch中均有实现。
Pipeline并行。数据并行是split数据,Pipeline是split模型。由于模型的layer之间具备天然的顺序性,也就是下一步计算的开始必须等待上一步计算的结束,这样会导致GPU的利用率较低(bubble overhead cost)。为了提升Pipeline的效率,一些工作,比如GPipe和PipeDream,主要围绕padding多个batch的数据和异步梯度更新来做。
Tensor并行。Tensor是split参数矩阵。假设模型某层的表达是Y=AX,可以转化为Y=[A_1X,A_2X],通过把A_1和A_2放在不同的GPU上,可以实现矩阵计算的并行化,之后基于GPU通信实现不同计算结果的聚合,所以NVLink很需要。Tensor并行在现在的主流的大模型训练框架Megatron-LM和Colossal-AI中均有实现。
针对显存优化,主要以ZeRO和混合精度训练为主。
ZeRO。数据并行时,每个GPU上主要存储三类数据,分别是模型参数,模型梯度和优化器参数。但是实际上每个GPU上可以只存储部分必需的数据,需要其他数据的时候可以去其他GPU上做检索,本质上是以通信时间换显存空间。Pytorch中的对应实现是FSDP(DP->DDP->FSDP)。
混合精度训练。相比标准的32bit浮点数计算,在显存占用和通信开销上都比16bit的高,因此采用混合精度训练的方式,能够实现一定程度上的显存优化,通过特殊的显卡比如A100的支持,能够一定程度上降低低精度训练带来的模型效果损耗。
如果模型的部分能够放在CPU上完成计算,也是一种显存优化的技术。
从另外一个角度看,假设我们在fine-tuning大模型的时候,只去fine-tuning一部分,如果能够在效果上comparable全部fine-tuning模型参数,也是一个不错的解题思路。这类技术称为PEFT(Parameter-Efficient Fine-Tuning),典型包括LoRA,Prefix-Tuning,P-Tuning和Prompt-Tuning。其中LoRA是关于压缩感知的艺术,而后三个方向的工作则都是关于Prompt的技术(好吧,这个方向上的工作有点挤......)。
LoRA。直接采用原始论文的一句话作为解释,"LoRA allows us to train some dense layers in a neural network indirectly by optimizing rank decomposition matrices of the dense layers’ change during adaptation instead, while keeping the pre-trained weights frozen.",具体如下图所示:
Prefix-Tuning。直接采用原始论文的一句话作为解释。"keeps language model parameters frozen and instead optimizes a sequence of continuous task-specific vectors, which we call the prefix. Prefix-tuning draws inspiration from prompting for language models, allowing subsequent tokens to attend to this prefix as if it were “virtual tokens”.",具体如下图所示:
Prefix-Tuning经过两个版本的迭代,区别如下所示:
P-tuning。借用唐杰老师的话,“we show that GPTs can be better than or comparable to similar-sized BERTs on NLU tasks with a novel method P-tuning -- which employs trainable continuous prompt embeddings。.”,如下图所示:
Prompt-Tuning。采用原始文章中的话,"In this work, we explore "prompt tuning", a simple yet effective mechanism for learning "soft prompts" to condition frozen language models to perform specific downstream tasks. Unlike the discrete text prompts used by GPT-3, soft prompts are learned through backpropagation and can be tuned to incorporate signal from any number of labeled examples.",如下图所示:
假设显存问题和吞吐问题得到一定程度的解决,我们是不是可以训练更大的模型了呢?MoE隆重出场。简而言之,MoE通过Gating机制实现forward计算的路由。假设有N个专家,模型总量虽然扩大了N倍,能力扩大了N倍,但是实际推理路径还是1个专家而非N个专家。
2.训练一个医疗LLM
为了训练一个医疗LLM,先看下别人是怎么做的?
ChatDoctor。英文版,一个基于LLaMA的在医疗领域的大模型。按照Meta的文章,LLaMA-13B在大多数benchmark上都优于GPT3-175B(可以认为等价于ChatGPT),LLaMA-65B能够媲美最好的模型,Chinchilla-70B和PaLM-540B。另外一个基于LLaMA的模型是Alpaca,Alpaca主要构建了一个instruction数据集(基于self-instruct),基于LLaMA做了微调。具体形式如下:
{
"instruction": "Render a 3D model of a house",
"input": "",
"output": "<nooutput> This type of instruction cannot be fulfilled by a GPT model."
},
{
"instruction": "Evaluate this sentence for spelling and grammar mistakes",
"input": "He finnished his meal and left the resturant",
"output": "He finished his meal and left the restaurant."
},
{
"instruction": "How did Julius Caesar die?",
"input": "",
"output": "Julius Caesar was assassinated by a group of up to 60 conspirators, led by Gaius Cassius Longinus and Marcus Junius Brutus, in the Senate House on the Ides of March (15 March) of 44 BC."
}
ChatDoctor首先基于Alpaca的数据集微调了LLaMA,因此也可以认为是将Alpaca作为基础模型,另外构建了200k+15k+5k的患者和医生的对话数据作为微调数据。在国内,一种是人工构造数据,另外一种非人工的对话数据获取可以使用好大夫等网站的问诊数据,如果对标ChatDoctor,大概率能够获取相同数量级的数据。
LMFlow。英文版,基于PubMedQA和MedMCQA构建了instruction数据集,使用8*A100,微调约16个小时得到LLaMA-33B(LoRA)模型,在医疗问题上的效果优于ChatGPT。基于对话构建数据集的时候,需要思考当前数据集是否适合构建instruction,比如第一届智能对话诊疗评测比赛的数据是否能用?如何使用?作为对比,CBLUE中的医学段落检索的任务数据在使用的时候似乎更加地直接,其他的数据集需要进一步思考。
Baize。英文版。模型调优的数据来自Quora和Medical对话,基于LLaMA-7B微调得到。
Med-PaLM。英文版,Google的工作。在PaLM的基础上加了instruction prompt tuning,具体工作见《Large Language Models Encode Clinical Knowledge》,不是很常见。
上述是英文版的医疗大模型,如果要做中文版的医疗大模型,可选的一些中文基础模型包括:
ChatGLM-6B。智源的工作,中英双语模型,训练数据由自己构建。
BELLE。链家的工作,模型调优的数据仅仅来自ChatGPT,提供两个模型,分别是BLOOM和LLaMA。
此外,非医疗垂直领域的模型,包括金融领域的LLM,如下:
BloombergGPT。英文版,50B,363Btoken的金融数据,345Btoken的通用数据。
所以,训练一个中文医疗LLM的Final Answer是什么?分为三步走,如下:
(1)基于一个中文的通用LLM(7B<)+SFT(基于高质量对话数据)+不做RLHF
(2)自己训练一个中文医疗通用LLM(7B<)+SFT(基于高质量对话数据)+不做RLHF
(3)自己训练一个中文医疗通用LLM(7B<)+SFT(基于高质量对话数据)+做RLHF
整体的资源上,8*A100可能是一个比较理想的硬件配置。
扫码加笔者好友,茶已备好,等你来聊,