查看原文
其他

学术前沿 | TorchOpt: PyTorch原生的高性能高阶可微优化库

刘博 北京大学人工智能研究院
2024-09-16

 导读 

 

本文介绍由北京大学人工智能研究院杨耀东课题组主导的开源项目项目TorchOpt,该项目已经被PyTorch官宣纳入生态(https://x.com/pytorch/status/1676616200468455424?s=46&t=5QijikOFbVVUQv2TenMWOw),并被机器学习顶级期刊Journal of Machine Learning Research (JMLR) 接收发表。


PyTorch官宣将TorchOpt纳入生态

TorchOpt是一个基于 PyTorch 的库,其统一的编程抽象、高性能的分布式执行运行时以及对多种微分模式的支持,为可微优化带来了革命性的变革。

您可以在 GitHub 上找到 TorchOpt,地址为 :

https://github.com/metaopt/torchopt(点击下方阅读原文跳转)



01

介绍 TorchOpt

可微编程已经改变了机器学习(ML)领域,它在高级语言中实现了自动计算导数。从神经网络的反向传播到贝叶斯推断和概率编程,可微编程的广泛应用极大地推动了 ML 及其应用的进步。它实现了高效且可组合的自动微分(AD)工具,为可微优化[1, 2]、模拟器[3, 4]、工程[5]和科学[6]的发展铺平了道路。不断涌现的可微优化算法凸显了可微编程的核心地位。


欢迎使用 TorchOpt —— 一款基于 PyTorch 的高效可微优化库。您可以在 GitHub 上找到 TorchOpt,地址为 https://github.com/metaopt/torchopt。


TorchOpt 提供了:

● 多样性:TorchOpt 包含三种微分模式 —— 显式微分、隐式微分和零阶微分,满足各种可微优化需求。


●灵活性:TorchOpt 提供功能和面向对象的 API,以满足不同用户的需求。您可以使用类似于 JAX 或 PyTorch 的风格实现可微优化。


●高效性:TorchOpt 提供 CPU/GPU 加速的可微优化器、基于 RPC 的分布式训练框架以及快速树操作,极大地提高了双层优化问题的训练效率。


02

为什么选择 TorchOpt?

TorchOpt 融合了两个关键方面 —— 统一且富有表现力的可微优化编程抽象和高速分布式执行运行时。


统一且富有表现力的可微优化编程抽象


TorchOpt 提供了一种抽象,可以高效地定义和分析可微优化程序,适用于显式、隐式和零阶梯度。


TorchOpt 的微分模式。通过将问题表述为可微问题,TorchOpt 为反向传递(虚线)提供 Autograd 支持。


TorchOpt 提供了一系列低级、高级、功能和面向对象(OO)API,使用户能够将可微优化纳入 PyTorch 生成的计算图。具体来说,TorchOpt 支持处理可微优化问题的三种微分模式:

(i)显式梯度用于展开优化;

(ii)隐式梯度用于基于解决方案的迭代优化;

(iii)零阶梯度估计用于非光滑/非可微函数。


高性能和分布式执行运行时


TorchOpt 提供了高性能和分布式执行运行时,包含了几种加速解决方案,支持 GPU 和 CPU 上的快速微分,并具有多节点多 GPU 的分布式训练功能。下图显示了 TorchOpt 与其他基线在 CPU/GPU 加速 op 和分布式训练方面的比较。


TorchOpt 的性能,(a)和(b)是不同参数大小下 TorchOpt 和 PyTorch 的前向/后向时间(Adam 优化器),(c)是使用 RPC 的多 GPU 加速比与顺序实现相比。


对于 PyTorch 研究人员和开发者,TorchOpt 的特性使其能够高效地声明和分析各种可微优化程序,实现计算密集型微分操作的完全并行化,并将计算自动分配给分布式设备。


03

使用示例

让我们深入了解 TorchOpt 的两个具体使用示例。我们将引导您完成每个步骤,提供视觉或代码示例以便更好地理解。


可微优化器的热身示例



让我们从一个热身示例开始:

Assume a tensor   is a meta-parameter and   is a normal parameters (such as network parameters). We have inner loss   and we update   use the gradient   and . Then we compute the outer loss . So the gradient of outer loss to   would be:


给定上述分析解,让我们使用 TorchOpt 中的 MetaOptimizer 对其进行验证。MetaOptimizer 是我们可微优化器的主类。它与功能优化器 torchopt.sgd 和 torchopt.adam 结合,定义了我们的高级 API torchopt.MetaSGD 和 torchopt.MetaAdam。

让我们开始。首先,定义网络:


(点击放大查看)


然后我们声明网络(由 a 参数化)和元参数 x。不要忘了为 x 设置标志 requires_grad=True。


(点击放大查看)


接下来我们声明元优化器。这里我们展示了定义元优化器的两种等效方法。

(点击放大查看)


元优化器将网络作为输入,并使用方法 step 更新网络(由 a 参数化)。最后,我们展示了一个双层过程如何工作。




使用 TorchOpt 实现模型无关元学习

(MAML)


让我们从 Model-Agnostic Meta-Learning(MAML)算法的核心思想开始。MAML 是一种元学习算法,它在模型上是无关的,因为它与用梯度下降训练的任何模型兼容,并且适用于各种不同的学习问题,包括分类、回归和强化学习。元学习的目标是在各种学习任务上训练模型,使其能够仅使用少量训练样本解决新的学习任务。


在 MAML 方法中,模型在各种任务上进行训练,然后在新任务的少量数据上进行一次或几次梯度步骤的微调。MAML 的关键见解是训练初始模型,使得这些微调步骤在新任务上具有良好的泛化性能。


MAML 中的更新规则定义如下:


给定微调步骤的学习率 alpha,theta 应最小化


(左右滑动查看完整公式)


优化此目标是元训练过程的目标。给定任务分布 p(T) 中的任务 i,使用任务 i 的损失 L_i 对模型参数 theta 进行一个或多个梯度下降步骤,得到任务特定参数 theta_i'。更新规则写为 theta_i' = theta - alpha * grad(L_i(theta))其中 alpha是学习率,grad表示梯度。


在批次中的每个任务进行此更新后,使用批次中所有任务 i 的损失 L_i的和更新模型参数 theta其中使用任务特定参数 theta_i'计算损失L_i。此更新规则写为 theta = theta - beta * grad(sum_i(L_i(theta_i'))其中 beta是学习率,grad 表示梯度。


在这里,alpha beta 是决定梯度下降更新步长的超参数。学习率 alpha通常选择较小,以便模型可以快速适应每个任务,而学习率 beta 通常选择较大,以便模型可以有效地从任务分布中学习。


现在,让我们解释使用 TorchOpt 提供的 MAML 算法在强化学习中实现的代码示例。


我们首先定义与任务、轨迹、状态、操作和迭代相关的一些参数。

(点击放大查看)


接下来,我们定义一个名为 Traj 的类来表示轨迹,其中包括观察到的状态、采取的操作、采取操作后观察到的状态、获得的奖励以及折现未来奖励的 gamma 值。

(点击放大查看)


然后,我们定义一个名为 sample_traj 的函数,用于在给定环境、任务、策略和参数的情况下生成轨迹。该函数模拟策略和环境之间在 TRAJ_LEN 步内的交互。

(点击放大查看)


a2c_loss 函数用于计算 Actor-Critic(A2C)算法的损失。A2C 算法是一种策略梯度方法,使用价值函数(评论家)来减小策略梯度(演员)的方差。


(点击放大查看)


evaluate函数用于评估策略在不同任务上的性能。它使用内部优化器在每个任务上对策略进行微调,然后计算微调前后的奖励。

(点击放大查看)


main 函数中,我们初始化环境、策略和优化器。策略是一个简单的 MLP,它在操作上输出一个分类分布。内部优化器用于在微调阶段更新策略参数,外部优化器用于在元训练阶段更新策略参数。通过微调前后的奖励来评估性能。每个外部迭代的训练过程都会被记录并打印。

(点击放大查看)


总之,本代码示例展示了如何使用 TorchOpt 为强化学习任务实现 MAML 算法。MAML 算法以与梯度下降训练的任何模型兼容的灵活方式实现,使其成为元学习任务的强大工具。


04

前瞻性声明


TorchOpt 是一个新颖且高效的基于 PyTorch 的可微优化库。我们的实验结果突显了 TorchOpt 作为支持 PyTorch 中具有挑战性的梯度计算的用户友好、高性能和可扩展库的潜力。我们计划未来支持更复杂的差分模式,并涵盖更多非平凡的梯度计算问题。TorchOpt 已经证明对元梯度研究非常有用,我们有信心它可以作为更广泛范围的可微优化问题的关键自动微分工具。

我们对 TorchOpt 的潜力充满热情,并致力于其持续的开发和改进。我们欢迎社区反馈和贡献,以帮助我们使 TorchOpt 更好。请继续关注未来几个月的更多更新和功能!


05

致谢


项目论文的四位一作作者,分别是来自爱丁堡大学的任杰伦敦大学学院的冯熙栋新加坡国立大学的刘博,以及北京大学的潘学海。该项目的通讯作者分别是爱丁堡大学的麦络助理教授和北京大学的杨耀东助理教授。


1.JAXopt *[7] 库其为隐式梯度微分设计的精美 API,给我们带来了很大启发。它在硬件加速、批处理和可微优化解决方案方面的方法,为我们在有效管理优化问题方面提供了重要见解。


2.Optax *[8],其关注功能编程和梯度处理,已经成为我们工作的基础。它将底层元素组合成自定义优化器的方式,激发了我们设计自己的功能 API,大大提高了我们项目的效率。


3.Betty [9],一种用于泛化元学习和多层次优化的自动微分库,对我们也有很大价值。尽管没有直接集成到我们的项目中,但它的功能为我们提供了有用的见解,并有助于我们自己的 TorchOpt 库功能的概念化和设计。


References


[1] Liu, B., Feng, X., Ren, J., Mai, L., Zhu, R., Zhang, H., … & Yang, Y. (2022). A theoretical understanding of gradient bias in meta-reinforcement learning. Advances in Neural Information Processing Systems, 35, 31059–31072.
[2] Finn, C., Abbeel, P., & Levine, S. (2017, July). Model-agnostic meta-learning for fast adaptation of deep networks. In International conference on machine learning (pp. 1126–1135). PMLR.
[3] Hu, Y., Anderson, L., Li, T. M., Sun, Q., Carr, N., Ragan-Kelley, J., & Durand, F. (2019). Difftaichi: Differentiable programming for physical simulation. arXiv preprint arXiv:1910.00935.
[4] Freeman, C. D., Frey, E., Raichuk, A., Girgin, S., Mordatch, I., & Bachem, O. (2021). Brax — A Differentiable Physics Engine for Large Scale Rigid Body Simulation. arXiv preprint arXiv:2106.13281.
[5] Schoenholz, S., & Cubuk, E. D. (2020). Jax md: a framework for differentiable physics. Advances in Neural Information Processing Systems, 33, 11428–11441.

[6] Raissi, M., Perdikaris, P., & Karniadakis, G. E. (2019).  Physics-informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations. Journal of Computational physics, 378, 686–707.
[7] Blondel, M., Berthet, Q., Cuturi, M., Frostig, R., Hoyer, S., Llinares-López, F., … & Vert, J. P. (2022). Efficient and modular implicit differentiation. Advances in neural information processing systems, 35, 5230–5242.
[8] Babuschkin, I., Baumli, K., Bell, A., Bhupatiraju, S., Bruce, J., Buchlovsky, P., Budden, D., Cai, T., Clark, A., Danihelka, I., Dedieu, A., Fantacci, C., Godwin, J., Jones, C., Hemsley, R., Hennigan, T., Hessel, M., Hou, S., Kapturowski, S., … Viola, F. (2020). The DeepMind JAX Ecosystem. http://github.com/deepmind.*

[9] Choe, S. K., Neiswanger, W., Xie, P., & Xing, E. (2022). Betty: An automatic differentiation library for multilevel optimization. arXiv preprint arXiv:2207.02849.


—   往期发布  —







学术前沿 | 全身交互三维重建:从椅子谈起

点击图片查看原文








学术前沿 | 物理“魔术”的背后

点击图片查看原文








学术前沿 | 单词学习:多模态理解和推理的基石

点击图片查看原文



—   版权声明  —

本微信公众号所有内容,由北京大学人工智能研究院微信自身创作、收集的文字、图片和音视频资料,版权属北京大学人工智能研究院微信所有;从公开渠道收集、整理及授权转载的文字、图片和音视频资料,版权属原作者。本公众号内容原作者如不愿在本号刊登内容,请及时通知本号,予以删除。

继续滑动看下一个
北京大学人工智能研究院
向上滑动看下一个

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存