训练精度媲美AlphaFold2、速度翻倍,飞桨螺旋桨HelixFold训练和推理代码全面开源
2021年7月15日,DeepMind 公司在 Nature 杂志上发表了题为“Highly accurate protein structure prediction with AlphaFold”的文章,系统介绍了一种端到端的从蛋白质序列预测蛋白质三维结构的神经网络算法——AlphaFold2。该算法预测的蛋白质结构能达到原子水平的准确度,被 Science 评选为2021年十大科学突破之首。虽然 DeepMind 公司开源了 AlphaFold2 推理代码,但是其训练代码一直未开源。从 DeepMind 公司发表的 AlphaFold2 论文看,完整从头训练 AlphaFold2 需要使用128张 TPUv3 训练11天,对计算资源的消耗是巨大的。科研机构和普通公司想要基于 AlphaFold2 探索解决蛋白领域的更多问题,例如蛋白质设计,新靶点发现等,也更加困难。因此,如何搭建一套性能更优、更加节省算力资源、支持适配国产硬件的蛋白结构预测模型,就成为亟待解决的问题。
在飞桨强大的高性能并行计算能力支持下,飞桨螺旋桨 PaddleHelix 生物计算团队发布了蛋白结构预测模型 HelixFold,围绕着显存峰值、训练速度、分布式策略进行了全面性能优化。通过与原版 AlphaFold2 模型和哥伦比亚大学 Mohammed AlQuraishi 教授团队基于 PyTorch 复现的 OpenFold 模型的性能对比测试显示,HelixFold 模型的训练性能相比 AlphaFold2 提升106.97%,相比 OpenFold 提升104.86%。
HelixFold 之所以能够得到如此大的性能提升,源于如下几项技术创新:
分支并行与混合并行策略
AlphaFold2 在使用 TPUv3 训练模型时,每张卡上的 batch size 只设置为1,限制了数据样本维度扩卡加速训练的可能性。HelixFold 创新性的提出分支并行(Branch Parallelism, BP)策略,将不同的网络模型分支放在不同的卡上并行计算,从而在 initial training 阶段大幅提高了模型并行效率和训练速度。并且,分支并行与已有的动态轴并行(Dynamic Axial Parallelism, DAP)和数据并行(Data Parallelism,DP)结合使用,通过 BP-DAP-DP 三维混合并行,进一步加快了模型的整体训练速度。
算子融合优化技术和张量融合低频次访存技术
针对 AlphaFold2 中 Gated Self-Attention 小算子组合 CPU 调度开销大、模型参数小、参数个数多的问题,HelixFold 将 Gated Self-Attention 整个模块融合用一个算子实现,将CPU 调度开销优化到极致。同时,将数千个小张量融合成一个连续的大张量,模型参数的梯度、优化器状态都相应更新,大幅减少了访存次数、CPU 调度开销和显存碎片,从而提升了训练速度。
多维度显存优化方案
采用 Recompute、BFloat16、显存复用、Subbatch(Chunking)等技术,将显存峰值降低到 40G 以内,同时支持 MSA 长度为512、ExtraMSA 长度为5120、残基序列长度为384的最大模型配置的微调训练,从而解决了模型结构深,中间结果计算量大,ExtraMSAStack 输入过长等导致无法训练的问题。
在性能大幅度提升的同时,HelixFold 从头端到端完整训练可以达到 AlphaFold2 论文媲美的精度。在包含87个蛋白的 CASP14 数据集和包含371个蛋白的 CAMEO 数据集上,HelixFold 模型 TM-score 指标分别达到0.8771和0.8885,与原版 AlphaFold2 准确率相当甚至更优
HelixFold 是运用飞桨的高性能计算技术,显著提升模型性能的典型案例。不仅如此,飞桨与曙光 AC 智算平台深度合作,将 HelixFold 在曙光 AC 智算平台全面部署上线,通过曙光智算中心对外提供服务。同时,飞桨螺旋桨也正在全力支持“先导杯” AI for science 赛道的比赛,希望能对参赛选手们有所启发。激发大家在 AI for science 领域的更多探索。也欢迎大家在曙光 AC 智算平台调用 HelixFold 模型。
HelixFold 端到端训练和推理代码现已全面向社区开源。
GitHub 地址:
https://github.com/PaddlePaddle/PaddleHelix/tree/dev/apps/protein_folding/helixfold
更多性能优化细节和数据分析参考技术报告:
HelixFold: An Efficient Implementation of AlphaFold2 using PaddlePaddle
https://arxiv.org/abs/2207.05477