查看原文
其他

开发者经验分享!手把手教你为飞桨贡献Gumbel API

飞桨 百度AI 2023-03-16


在飞桨第三期黑客松活动中,湖北大学计算机与信息工程学院的研究生韩凡宇(队长)和王勇森组建了“源力觉醒”小队,为飞桨新增了 Gumbel API。


本文将由王勇森分享为飞桨新增 Gumbel API 的经验。



 01 

 任务介绍 


▎任务背景


耿贝尔(Gumbel)分布是一种极值型分布。Gumbel 分布理论认为,最大值分布的潜在适用性与极值理论有关,如果基础样本数据的分布是正态或者指数类型,Gumbel 分布就是有用的。Gumbel 分布适用于海洋、水文、气象领域,用来计算不同重现期的极端高(低)潮位。而在概率论和统计学中,Gumbel 分布常被用来模拟不同分布的样本的最大(或最小)分布。


Gumbel 概率密度图像


飞桨框架目前还未集成 Gumbel 分布,本任务的目标是对飞桨框架中现有概率分布方案进行扩展,新增 Gumbel API。增加此 API 能够扩大飞桨的应用处理范围,对飞桨来说是非常必要的。


▎设计思路


Gumbel API 的设计需要做多个方面的知识储备:


  • 详细了解 Gumbel 分布背后的数学原理以及应用场景;

  • 深刻了解飞桨和业界概率分布设计实现的方法和技巧。


注:“充分学习 Gumbel 背后的数学原理”这一点很重要,避免在开发时违背定理。印象最深的是在开发之初,我们没有了解到分布的 mean 如何计算,直接使用了一个不清楚的计算方式,导致在最开始的版本中,mean 的计算方式错误。最后,通过查询资料得知标准 Gumbel 分布的 mean 为负欧拉常数,于是纠正。


接下来,我将详细描述 Gumbel API 的设计思路。


  • 命名与参数设计


API 的名称直接使用 Gumbel 分布的名称,参数保持 Gumbel 分布最原生的参数,包括“位置参数 loc”以及“尺度参数 scale”。预期 Gumbel API 的形式为:

paddle.distribution.gumbel.Gumbel ( loc, scale )


  • Gumbel 分布类的初始化方法


类初始化过程中,一方面要严格控制参数 loc 和 scale 的形状和数据类型。另一方面还要借助基础分布 Uniform 以及 transforms 初始化父类  TransformedDistribution。


  • Gumbel API 的功能


该 API 部分功能继承于 TransformedDistribution,包括 mean 均值、variance 方差、sample 随机采样、rsample 重参数化采样、prob 概率密度、log_prob 对数概率密度、entropy 熵计算等。除了官方任务要求外,我们还添加了一些其他的方法,比如 stddev 标准差和 cdf 累积分布函数等。



▎类初始化方法


■ 数据类型


首先我们需要判断 loc 和 scale 的数据类型是否是飞桨支持的标量数据类型。如果是飞桨支持的标量,需要将其转为飞桨支持的 tensor 类型。

  • 上述判断实现如下:

if not isinstance(loc, (numbers.Real, framework.Variable)):
    (抛出数据类型错误)
if not isinstance(scale, (numbers.Real, framework.Variable)):
    (抛出数据类型错误)
if isinstance(loc, numbers.Real):
    (转为 paddle 类型的 tensor)
if isinstance(scale, numbers.Real):
    (转为 paddle 类型的 tensor)

此外,还要统一 loc 和 scale 的形状类型,我们选择使用 paddle.broadcast_tensors() 广播机制来进行统一。


if loc.shape != scale.shape:
        self.loc, self.scale = paddle.broadcast_tensors([loc, scale])
else:
        self.loc, self.scale = loc, scale


■ 父类调用

因为初始化方法中对基础分布进行一系列 transform 的操作,我们选择继承父类 TransformedDistribution。但在实际开发过程中,由于遇到了一些问题,我们并未在 TransformedDistribution 类中对 Uniform 进行变换,而是选择在 rsample 中进行变换。

▎API 伪代码实现

在经过以上准备,确定设计思路后,我们给出 paddle.distribution.gumbel.Gumbel 中实现的属性、方法的伪代码。

  • mean:均值

loc + scale * γ

  • variance:方差

pow( scale, 2 ) * pi * pi / 6

  • stddev:标准差

sqrt( variance )

  • cdf(value):累积分布函数

exp(-exp(-(value - loc) / scale))

  • rsample(shape):重参数化采样

ExpTransform()
AffineTransform(0, -ones_like(scale))
AffineTransform(loc, -scale)
chain = ChainTransform(ExpTransform(), AffineTransform(0, -ones_like(scale)), AffineTransform(loc, -scale))
chain.forward(base_distribute)


 02 

 代码开发 


本节介绍代码开发的过程,着重介绍在开发中遇到困难的两个部分,包括类初始化方法(__init__)和重参数化采样方法(rsample)。最后,再介绍开发过程中遇到的问题以及如何解决该问题。其他属性方法仅是将1.4节中的伪代码使用飞桨框架实现,我们将在本节末尾给出各个方法属性的实现。

▎类初始化方法:__init__

在进行此方法的开发中,我们着重关注两个方面:
  • 参数(loc, scale)的形状,数据类型的判断;
  • 使用基础分布 Uniform 以及 transforms 调用父类 TransformedDistribution。

最终__init__方法如下
def __init__(self, loc, scale):
    if not isinstance(loc, (numbers.Real, framework.Variable)):
        raise TypeError(
            f"Expected type of loc is Real|Variable, but got {type(loc)}")
    if not isinstance(scale, (numbers.Real, framework.Variable)):
        raise TypeError(
            f"Expected type of scale is Real|Variable, but got {type(scale)}"
        )
    if isinstance(loc, numbers.Real):
        loc = paddle.full(shape=(), fill_value=loc)
    if isinstance(scale, numbers.Real):
        scale = paddle.full(shape=(), fill_value=scale)
    if loc.shape != scale.shape:
        self.loc, self.scale = paddle.broadcast_tensors([loc, scale])
    else:
        self.loc, self.scale = loc, scale
    finfo = np.finfo(dtype='float32')
    self.base_dist = paddle.distribution.Uniform(
        paddle.full_like(self.loc, float(finfo.tiny)),
        paddle.full_like(self.loc, float(1 - finfo.eps)))
    self.transforms = ()
    super(Gumbel, self).__init__(self.base_dist, self.transforms)

▎重参数化采样:rsample

由于我们对 rsample 的概念模糊,在开发此部分时非常困难。在查阅资料后得知 rsample 的真正用途是重参数化技巧,即从一个分布中进行采样。而该分布是带有参数的,如果直接进行采样(采样动作是离散的,其不可微),没有梯度信息,那么在反向传播的时候就不会对参数梯度进行更新。重参数化技巧可以保证我们从分布中进行采样,同时又能保留梯度信息。

在梳理清楚 rsample 的含义后,我们参考了业界的写法,使用了两个 transform。

即 paddle.distribution.AffineTransform 和 paddle.distribution.AffineTransform,最终代码如下:


def rsample(self, shape):
    exp_trans = paddle.distribution.ExpTransform()
    affine_trans_1 = paddle.distribution.AffineTransform(
        paddle.full(shape=self.scale.shape,
                    fill_value=0,
                    dtype=self.loc.dtype), -paddle.ones_like(self.scale))
    affine_trans_2 = paddle.distribution.AffineTransform(
        self.loc, -self.scale)
    return affine_trans_2.forward(
        exp_trans.inverse(
            affine_trans_1.forward(
                exp_trans.inverse(self._base.sample(shape)))))


▎其他属性、方法实现


其他属性及方法的实现即将1.4中的伪代码使用飞桨框架实现,详细实现可以参考 Paddle/python/paddle/distribution/gumbel.py。

▎问题以及解决方法

在开发 Gumbel API 时,遇到了一些问题。

  • 多次使用仅支持动态图的 tensor,最终使用 paddle.full 替代;
  • 无法在初始化方法中将多个 transforms 作用到基础分布 Uniform 上。最终选择在 rsample 方法中直接将多个 transforms 作用到 Uniform 上;
  • 测试类中,使用 Numpy 实现方法的数据类型和飞桨数据类型没有统一。

最开始的解决办法是在各个测试方法里面将 Numpy 的数据类型转换成飞桨支持的数据类型,但是官方评审人员给出了一种更方便的解决办法:不增加类型转换逻辑,用 xrand 指定 dtype。

 03 
 成果展示 

  • 导入及初始化 Gumbel 类
 import paddle
   from paddle.distribution.gumbel import Gumbel   
   loc = paddle.full([1], 0.0)
    scale = paddle.full([1], 1.0)
    dist = Gumbel(loc,  scale)
  • 使用 rsample(shape) / sample(shape)进行随机采样/重参数化采样,需要指定采样的 shape。
   shape = [2]
    dist.sample(shape)
    # Tensor(shape=[21], dtype=float32, place=Place(gpu:0), stop_gradient=True, [[-0.27544352], [-0.64499271]])
   dist.rsample(shape)
    # Tensor(shape=[21], dtype=float32, place=Place(gpu:0), stop_gradient=True, [[0.80463481], [0.91893655]])

更多用法可在飞桨官网查看:
https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/distribution/Gumbel_cn.html#gumbel

 04 
 总结 

▎收获

本篇项目介绍了我们小组开发 Gumbel API 的历程。从学习飞桨以及业界相关内容、到设计提案,最后到 API 的开发、测试。这期间,我们学习到很多内容,比如:

  • 如何使用 GitHub 进行代码协作开发;
  • 更加熟悉国内顶尖的飞桨深度学习框架;
  • 深刻掌握了包括 Gumbel 分布在内的多个数学分布;
  • 锻炼了使用 Python 开发的能力。

▎不足&后续工作

当然,Gumbel API 还存在着一些缺点和不足,后续需要不断改进:

  • rsample 方法中使用多个 transforms,操作繁琐;
  • 测试方法时,随机采样和重参数化采样的精度未能降低到1e-3以下。未来将尝试使用其他方式来实现属性方法,从而降低采样精度;
  • API 的__init__方法存在优化的余地,未来可以将 rsample 中的 transforms 整合到初始化方法中;
  • 未能实现 InverseTransform;
  • 未能将多个参数的验证抽取出来,使用特定的参数验证方法进行验证,比如 _validate_value(value);
  • 未来可以尝试增加 kl_divergence 计算两个 Gumbel 分布的差异。

以上就是整个 API 开发的过程,虽然充满挑战,但是不断解决问题才是关键,很高兴能为飞桨框架做出一份贡献,也很高兴能有这个机会在这里给各位分享开发过程。

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存