【中文医疗大模型】训练全流程源码剖析
笔者中文医疗大模型系列文章目录:
围绕中文医疗大模型,按照ChatGPT的训练流程做本地化适配,整体的流程如下:
从上述流程来看,整体上可以分为四个阶段,分别如下:
预训练(pre-training,pt),上述流程中,基于ChatGLM-6B的初始模型,经过海量中文医疗语料训练,得到领域适配的ChatGLM-6B
监督微调(supervised finetuning, sft),通过基于知识图谱,在线问诊等数据,构建训练数据完成指令微调
RM模型构建(reward modeling, rm),人工对预测答案排序,训练一个打分模型
强化学习阶段(reinforcement learning, rl),基于PPO算法,采用RL的方式,完成fine-tuned ChatGLM-6B模型的优化
这篇文章基于一个开源的中文医疗大模型的项目,完成全流程建模的代码阅读。项目地址为:
https://github.com/shibing624/MedicalGPT
关于为什么选择这个项目?理由如下:
流程完备,包括4个流程。虽然仍旧存在各种Bug和问题,但是按照最低配置能够跑通完整流程,对于理解原理细节,应该是够用的。但是笔者并不是基于该工作去完成实际的模型训练工作,至于是否有坑,需要自己去踩
代码实现借鉴了其他一些开源工作,保证了代码实现上的主流,在阅读代码的时候,不会觉得实现上很突兀,比如很接近transformers的风格。整体上,代码架构非常清晰干净
项目在快速进化,持续迭代
整体上按照pt,sft,rm,rl四个阶段依次简要过一遍源码。
第一阶段: PT
pt阶段的核心代码路径如下:
https://github.com/shibing624/MedicalGPT/blob/main/scripts/pretraining.py
该阶段的训练数据格式如下。其实就是非结构化的自然语言文本,通过设定max_seq_len和block_size等方式,实现文本数据的chunk,batch化,作为模型的训练数据,处理完的单条数据包含input_ids,attention_mask和labels。
核心代码结构如下,包括peft库的引入,model/data/peft核心类的定义,主要训练流程描述。其中,训练细节通过transformer[1]的trainer实现了封装。
针对大模型的高效参数微调方法,通过peft库[2]实现了统一的封装。lora作为一种经典的peft方法,不仅适用于sft阶段,同样适用于pt/rm/rl阶段。区别在于各个阶段对于lora的需求有多强,比如rm阶段,可以learn from scratch一个reward模型。
model = PeftModel.from_pretrained(model, training_args.peft_path, is_trainable=True)
peft库中与lora相关的类和方法包括LoraConfig,TaskType,PeftModel,get_peft_model和prepare_model_for_int8_training等,在使用体验上和transformers能够拉齐。比如通过上述代码,可以实现原始model的peft化。
trainer封装了大量的模型train和eval的基本实现,在代码中,可以通过两行代码实现train的过程,如下:
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = trainer.evaluate()
通过transformer的trainer打通了transformers和MedicalGPT之间的联系,该trainer写了约4000行。trainer中为了实现加速,需要accelerate[5]的支持。
https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py
transformer中包含了非常多模型的dsl,含forward函数,如bloom等。对于bloom而言,forward计算中的主要模块含transformer的标准模块,用于语言建模的lm_head,其实是一个linear层,以及计算逻辑。在这里通过shift logit,实现给定前(n-1)个token,预测第n个token的loss计算,也即auto-regressive的计算方式。代码如下所示:
实际场景下,pt阶段的初始模型也是一个带有对话能力的大模型,比如ChatGLM-6B这样的模型,而非从头开始完成pt的过程。因此,这样的条件下,如果按照上述流程,是否会存在对话能力的遗忘?如果会,如何在技术上能够减轻或者避免?换个角度,这个阶段如何能够高效实现模型能力的领域适配?
第二阶段: SFT
sft阶段的核心代码路径如下:
https://github.com/shibing624/MedicalGPT/blob/main/scripts/supervised_finetuning.py
训练数据格式如下:
整体的结构采用instruction/input/output,这里的训练数据中只包含了问答数据,所以input缺省。整体上看,sft和pt阶段并不存在本质上的差异性,均是通过auto-regressive的方式来训练模型。因此,具体代码实现和pt阶段类似。需要特别注意的是,真正用于训练的数据的构建方式,如下:
def preprocess_function(examples):
sources = []
targets = []
for instruction, input, output in zip(examples['instruction'], examples['input'], examples['output']):
if input:
instruction = instruction + '\n' + input
source = PROMPT_TEMPLATE.format_map({'instruction': instruction})
target = f"{output}{tokenizer.eos_token}"
sources.append(source)
targets.append(target)
tokenized_sources = tokenizer(sources, truncation=True, max_length=max_source_length)
tokenized_targets = tokenizer(targets, add_special_tokens=False, truncation=True, max_length=max_target_length)
all_input_ids = []
all_labels = []
for s, t in zip(tokenized_sources['input_ids'], tokenized_targets['input_ids']):
input_ids = torch.LongTensor(s + t)
labels = torch.LongTensor([IGNORE_INDEX] * (max_source_length + max_target_length - len(t)) + t)
all_input_ids.append(input_ids)
all_labels.append(labels)
results = {'input_ids': all_input_ids, 'labels': all_labels}
return results
上述代码中,将instruction和input共同形成source,output作为target,对于target的一个小细节是,需要加上eos_token,也就是target的输出结束标志符。最后通过拼接source和target,共同得到input_ids和labels(对的,这里没有attention_mask,其实即使pt阶段的attention_mask的值均为1)。在代码实现中,同时添加了一部分额外的prompt说明,如下:
PROMPT_TEMPLATE = (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response: "
)
这里其实也是可选的。在多数条件下的团队是首先完成在指令微调阶段的工作。如果是基于一个比如ChatGLM-6B的模型来构建,经过该阶段的微调,大多数条件下会存在模型对话能力的丧失,这里并没有量化过模型是否存在其他能力的丧失。因此,这里会面临pt阶段一样的挑战。
针对这个问题,考虑到pt和sft阶段在训练方式上的一致性,可以将pt和sft的训练数据shuf之后去完成模型的微调,也许在一定程度上能够缓解灾难性遗忘的现象。
总之,看似简单的sft阶段,在实际微调过程中,却存在远比上述讨论到的多很多的问题,我们也积累了一些发现和心得。
第三阶段: RM
rlhf过程分为两个阶段,分别是rm阶段和rl阶段,前者的主要目的是训练得到一个rm模型用于rl阶段的模型打分。rm阶段的代码如下:
https://github.com/shibing624/MedicalGPT/blob/main/scripts/reward_modeling.py
整体上的过程会和pt/sft阶段不同。pt/sft阶段的本质是做自回归任务,训练得到一个lm。而rm阶段的本质是做一个回归/分类模型,用于打分。训练数据如下:
每个样本中包含一个question和对应于该question的response_chosen和response_rejected,其中前者表示接受的response,后者表示拒绝的response。数据的预处理代码如下:
def preprocess_reward_function(examples):
"""
Turn the dataset into pairs of Question + Answer, where input_ids_chosen is the preferred question + answer
and text_rejected is the other.
"""
new_examples = {
"input_ids_chosen": [],
"attention_mask_chosen": [],
"input_ids_rejected": [],
"attention_mask_rejected": [],
}
for question, chosen, rejected in zip(examples["question"], examples["response_chosen"],
examples["response_rejected"]):
tokenized_chosen = tokenizer("Question: " + question + "\n\nAnswer: " + chosen)
tokenized_rejected = tokenizer("Question: " + question + "\n\nAnswer: " + rejected)
new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
return new_examples
实际的训练数据中,会将question和每个reponse分别拼接。基于构建的数据,采用的trainer是重写之后的,如下:
class RewardTrainer(Trainer):
"""
Trainer for reward models
Define how to compute the reward loss. Use the InstructGPT pairwise logloss: https://arxiv.org/abs/2203.02155
"""
def compute_loss(self, model, inputs, return_outputs=False):
rewards_chosen = model(input_ids=inputs["input_ids_chosen"],
attention_mask=inputs["attention_mask_chosen"])[0]
rewards_rejected = model(input_ids=inputs["input_ids_rejected"],
attention_mask=inputs["attention_mask_rejected"])[0]
loss = -torch.nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
if return_outputs:
return loss, {"rewards_chosen": rewards_chosen, "rewards_rejected": rewards_rejected}
return loss
在具体使用的模型上,pt/sft阶段均为CausalLM,rm阶段为SequenceClassification。采用负对数似然损失作为对比损失函数(pair-wise)。
实际上,这个项目中的工作是实现了InstructGPT中的做法,采用learning2rank的类似思路,还有很多方法可以得到这样的rm模型。截止目前,开源的rm并不是很多,ChatGPT也许是一个不错的rm候选。ChatGPT训练系统中的rm模型,大概在6B左右。
第四阶段: RL
rl阶段的训练代码如下所示:
https://github.com/shibing624/MedicalGPT/blob/main/scripts/rl_training.py
该阶段由于是优化一个语言模型,故训练数据同sft阶段。代码实现中除了正常引入peft,同时引入trl[3],trl是一个基于transformer的强化学习库,封装了PPO相关的实现,包括PPOConfig和PPOTrainer。基于上一阶段得到的rm模型,可以完成打分的动作,实现逻辑如下:
def get_reward_score(reward_model, reward_tokenizer, question, answer, device):
"""
Get the reward score for a given question and answer pair.
"""
inputs = reward_tokenizer(question, answer, return_tensors='pt').to(device)
score = reward_model(**inputs).logits[0].cpu().detach()
return score
rl阶段的核心训练流程如下:
for step, batch in tqdm(enumerate(trainer.dataloader)):
if step >= total_steps:
break
question_tensors = batch["input_ids"]
question_tensors = [torch.LongTensor(i).to(device).squeeze(0) for i in question_tensors]
responses = []
response_tensors = []
for q_tensor in question_tensors:
response_tensor = trainer.generate(
q_tensor,
return_prompt=False,
**generation_kwargs,
)
r = tokenizer.batch_decode(response_tensor, skip_special_tokens=True)[0]
responses.append(r)
response_tensors.append(response_tensor.squeeze(0))
batch["response"] = responses
# Compute reward score
score_outputs = [
get_reward_score(reward_model, reward_tokenizer, q, r, device) for q, r in
zip(batch["query"], batch["response"])
]
rewards = [torch.tensor(float(score) - args.reward_baseline) for score in score_outputs]
# Run PPO step
try:
stats = trainer.step(question_tensors, response_tensors, rewards)
trainer.log_stats(stats, batch, rewards)
logger.debug(f"Step {step}/{total_steps}: reward score:{score_outputs}")
except ValueError as e:
logger.warning(f"Failed to log stats for step {step}, because of {e}")
给定一个batch,首先利用rm模型计算score,该阶段的输入是真实 question以及模型预测的response。其次,将question和预测response的tensor,以及score,作为PPOTrainer的输入完成PPO的训练,主要是语言模型的参数更新。在输入之前,会有一个技术细节(代码中的reward容易引起异议,可以改为score):
rewards = [torch.tensor(float(score) - args.reward_baseline) for score in score_outputs]
进入PPOTrainer,主要的流程如下:
(1)基于self.mode计算ref_logprob,self.model也就是该阶段要去优化的模型
all_logprobs, _, values, masks = self.batched_forward_pass(self.model, queries, responses, model_inputs)
(2)基于self.ref_model计算ref_logprob,在某些条件下,self.ref_model=self.model
ref_logprobs, _, _, _ = self.batched_forward_pass(self.ref_model, queries, responses, model_inputs)
(3)基于两类logprob和传入的score,计算reward
rewards, non_score_reward = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
(4)train_minibatch
该阶段的输入为:
train_stats = self.train_minibatch(
batch["logprobs"],
batch["values"],
batch["rewards"],
logprobs,
logits,
vpreds,
batch["masks"],
)
核心计算逻辑如下所示:
def train_minibatch(
self,
old_logprobs: torch.FloatTensor,
values: torch.FloatTensor,
rewards: torch.FloatTensor,
logprobs: torch.FloatTensor,
logits: torch.FloatTensor,
vpreds: torch.FloatTensor,
mask: torch.LongTensor,
):
loss_p, loss_v, train_stats = self.loss(old_logprobs, values, rewards, logits, vpreds, logprobs, mask)
loss = loss_p + loss_v
self.optimizer.zero_grad()
self.accelerator.backward(loss)
if self.config.max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(
filter(lambda p: p.requires_grad, self.model.parameters()), self.config.max_grad_norm
)
t = time.time()
self.optimizer.step()
train_stats["time/ppo/optimizer_step"] = torch.Tensor([time.time() - t]).to(self.current_device)
return train_stats
总结: 思考
在读完代码的最后,借用一张Andrej Karpathy的PPT[4],再次回顾训练的全流程。除了上述项目,ColossalAI[7]和DeepSpeed[8]同样实现了完整的训练流程。
经过近一段时间的实践,也形成一些心得和发现,分享如下:
如果你准备以严肃且认真的方式做一个行业垂直模型,请先认真思考,什么是行业垂直模型?比如,什么是中文医疗大模型?什么是金融大模型?这个问题不仅仅与为什么需要行业大模型有关,又能回答边界在哪里,我们将走向何方
模型迭代一次的成本可能会比预想的高许多,即使是在<=13B这样的量级,包括显卡资源,能够完整跑通一次有效流程的人力资源,时间资源等,这启发我们要大胆假设的同时,小心求证,保持谨慎乐观
能接触到的各种大模型工作,能打的基本没有。这个问题可能与技术无关,与“大模型能打”的标准定义有关。这种缺乏标准的问题,与模型训练过程中各种现象的出现有类似之处,要能够在不确定性中寻找确定,在没有标准中找到模糊的尺子
幸运地是,在中文医疗大模型的what和why上,我们形成了自己的一些观点,同时做了一把能够驱动我们自身快速迭代和演进的尺子,也正在围绕how通过不断地实践在积累更多的认知。
相关参考
[1]https://github.com/huggingface/transformers/tree/main
[2]https://github.com/huggingface/peft
[3]https://github.com/lvwerra/trl
[4]https://karpathy.ai/stateofgpt.pdf
[5]https://github.com/huggingface/accelerate
[6]https://github.com/TimDettmers/bitsandbytes
[7]https://github.com/hpcaitech/ColossalAI/tree/main/applications/ChatlossalAI/tree/main/applications/Chat
[8]https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat
扫码加笔者好友,茶已备好,等你来聊,