查看原文
其他

不到 200 行代码就能微调 Llama-2!

HAOCHENYE OpenMMLab 2024-04-23

上周 Meta AI 又搞了个大新闻,发布了他们的第二代大语言模型:Llama-2,并在第一时间开源了训练、推理代码,甚至还提供了官方版和 huggingface 版,在良心开源的同时还考虑到了普通用户,真的是太酷啦。在收到消息的第一时间,我就冲进了官方仓库:llama2-recipes 打算体验一下 llama-2 的训练流程。

第一次体验时,可以看出来 code release 还是有一些仓促的,也遇到了一些小意外,没跑几个 iteration 模型就不再收敛了,在细品了一番代码之后,发现了一个小错误。他们把 epoch based 的 scheduler 按照 step based 方式去更新了,导致学习率下降的过快,没过几个迭代,学习率就趋向于 0 了。于是我就赶紧向官方反馈了这个问题:

https://github.com/facebookresearch/llama-recipes/issues/27


官方也回应神速,立马就在第二天修复了这个问题:

https://github.com/facebookresearch/llama-recipes/pull/28


之前遇到类似问题的小伙伴可以更新一下代码,保证药到病除。在解决这个小问题之后,Llama-2 就能正常训练了,再次给 Meta AI 的开源精神点赞(有图有真相)!



经历了这个小插曲后,我也差不多摸清楚了 Llama-2 的训练流程,正如论文里所说,它是用 FSDP 进行训练的,诶?FSDP,MMEngine v0.8.0 不也支持了 FSDP 训练么,于是乎,我就基于 MMEngine 的新特性实现了 Llama-2 的训练流程。


完整的训练示例见:


https://github.com/open-mmlab/mmengine/tree/main/examples/llama2

(文末点击阅读原文可直达)



实现数据类


ctrl c+ ctrl v 大法好,直接参考 llama-recipe 里 alpaca dataset 的实现。

参考代码请见:https://github.com/facebookresearch/llama-recipes/blob/1e0f8a1fb77b9ddccf649970f632dd606a22bd06/ft_datasets/alpaca_dataset.py#L28



构建 FSDPStrategy


FSDPStrategy 的构造函数会初始化分布式环境、随机种子等环境变量,因此需要放在第一步来做。Strategy 是 MMEngine v0.8.0 新引入的特性,旨在解决大模型训练的一些问题。关于 Strategy 的具体解读,大家可以期待一下后续的文章~

strategy = FSDPStrategy(    model_wrapper=dict(        auto_wrap_policy=partial(            transformer_auto_wrap_policy,            transformer_layer_cls={LlamaDecoderLayer})),    state_dict_cfg='full',    env_kwargs=dict(randomness=dict(seed=42)))



构建 dataloader、model


配置完全照搬官方 repo。需要注意的是,官方 repo 默认启用全量参数的 bf16 训练,无需混合精度训练。


# Prepare modeltokenizer = LlamaTokenizer.from_pretrained(args.checkpoint)tokenizer.add_special_tokens({'pad_token': '<PAD>'})model = LlamaForCausalLM.from_pretrained(args.checkpoint)model.to(torch.bfloat16)model.train()
# Prepare datasettrain_dataset = AlpacaDataset(    tokenizer=tokenizer, data_path=args.data_root)train_dataloader = DataLoader(    train_dataset,    batch_size=args.batch_size,    sampler=DefaultSampler(train_dataset, seed=0),    collate_fn=default_data_collator,    drop_last=True)



准备 optimizer 和 scheduler


配置与官方 repo 对齐,使用 AdamW 和 StepLR。然后把 model、scheduler、optimizer 传给 strategy,由他来处理 FSDP 相关的逻辑。


optim_cfg = dict(    optimizer=dict(type=AdamW, lr=1e-4, weight_decay=0.0),    accumulative_counts=ORI_BATCH_SIZE / args.batch_size)scheduler_cfgs = [dict(type=StepLR, step_size=1, gamma=0.85)]model, optimizer, schedulers = strategy.prepare(    model,    optim_wrapper=optim_cfg,    param_scheduler=scheduler_cfgs,    dispatch_kwargs=dict(max_iters=max_iters, max_epochs=args.max_epoch))



实现 train-loop


用了 strategy,我们就能脱离 Runner,自由自在地实现训练逻辑啦,和原生的 PyTorch 很像有木有?


for epoch in range(args.max_epoch):    for idx, inputs in enumerate(train_dataloader):        # Convert inputs to target device.        inputs = apply_to(inputs, lambda m: isinstance(m, torch.Tensor),                          lambda m: m.cuda())
       loss = model(**inputs).loss        optimizer.update_params(loss)
       max_memory = torch.cuda.max_memory_allocated()        strategy.logger.info(f'Epoch: {epoch+1}/{args.max_epoch}, '                             f'Iter: {idx+1}/{epoch_length}, '                             f'Loss: {loss.item():.3f}, '                             f'Lr: {optimizer.get_lr()["lr"][0]:.6f} '                             f'Memory: {max_memory/1e9:.3f}G')        visualizer.add_scalars({'loss': loss.item()})
       torch.cuda.reset_peak_memory_stats()
   for scheduler in schedulers:        scheduler.step()
   save_dir = f'{args.output_dir}/epoch_{epoch+1}'    state_dict = model.state_dict()
   if is_main_process():        model.save_pretrained(save_dir, state_dict=state_dict)        tokenizer.save_pretrained(save_dir)


不过脱离 Runner 也有一些麻烦,我们就得手动地更新学习率、打印日志、记录日志、保存权重了。



总结


感兴趣的同学可以来 MMEngine 体验一下训练的 example,欢迎大家多多反馈。要是大家 star 给力,我们也会尽快提供使用 DeepSpeed 和 ColossalAI 微调的示例,以及中文数据集训练的示例~


https://github.com/open-mmlab/mmengine

(欢迎大家使用,觉得好用欢迎点亮小星星)



大模型社区再掀波澜,Meta重磅开源LLAMA-2,性能升级可商用

2023-07-19

2023 年了,大模型训练还要不要用 PyTorch 的 FSDP ?

2023-07-17

超级视客营再度启航,展现编程实力,共同塑造开源未来!

2023-07-21


继续滑动看下一个
向上滑动看下一个

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存