开发者经验分享!手把手教你为飞桨贡献Gumbel API
在飞桨第三期黑客松活动中,湖北大学计算机与信息工程学院的研究生韩凡宇(队长)和王勇森组建了“源力觉醒”小队,为飞桨新增了 Gumbel API。
本文将由王勇森分享为飞桨新增 Gumbel API 的经验。
01
任务介绍
▎任务背景
耿贝尔(Gumbel)分布是一种极值型分布。Gumbel 分布理论认为,最大值分布的潜在适用性与极值理论有关,如果基础样本数据的分布是正态或者指数类型,Gumbel 分布就是有用的。Gumbel 分布适用于海洋、水文、气象领域,用来计算不同重现期的极端高(低)潮位。而在概率论和统计学中,Gumbel 分布常被用来模拟不同分布的样本的最大(或最小)分布。
Gumbel 概率密度图像
飞桨框架目前还未集成 Gumbel 分布,本任务的目标是对飞桨框架中现有概率分布方案进行扩展,新增 Gumbel API。增加此 API 能够扩大飞桨的应用处理范围,对飞桨来说是非常必要的。
▎设计思路
Gumbel API 的设计需要做多个方面的知识储备:
详细了解 Gumbel 分布背后的数学原理以及应用场景;
深刻了解飞桨和业界概率分布设计实现的方法和技巧。
注:“充分学习 Gumbel 背后的数学原理”这一点很重要,避免在开发时违背定理。印象最深的是在开发之初,我们没有了解到分布的 mean 如何计算,直接使用了一个不清楚的计算方式,导致在最开始的版本中,mean 的计算方式错误。最后,通过查询资料得知标准 Gumbel 分布的 mean 为负欧拉常数,于是纠正。
接下来,我将详细描述 Gumbel API 的设计思路。
命名与参数设计
API 的名称直接使用 Gumbel 分布的名称,参数保持 Gumbel 分布最原生的参数,包括“位置参数 loc”以及“尺度参数 scale”。预期 Gumbel API 的形式为:
paddle.distribution.gumbel.Gumbel ( loc, scale )
Gumbel 分布类的初始化方法
类初始化过程中,一方面要严格控制参数 loc 和 scale 的形状和数据类型。另一方面还要借助基础分布 Uniform 以及 transforms 初始化父类 TransformedDistribution。
Gumbel API 的功能
该 API 部分功能继承于 TransformedDistribution,包括 mean 均值、variance 方差、sample 随机采样、rsample 重参数化采样、prob 概率密度、log_prob 对数概率密度、entropy 熵计算等。除了官方任务要求外,我们还添加了一些其他的方法,比如 stddev 标准差和 cdf 累积分布函数等。
▎类初始化方法
■ 数据类型
上述判断实现如下:
if not isinstance(loc, (numbers.Real, framework.Variable)):
(抛出数据类型错误)
if not isinstance(scale, (numbers.Real, framework.Variable)):
(抛出数据类型错误)
if isinstance(loc, numbers.Real):
(转为 paddle 类型的 tensor)
if isinstance(scale, numbers.Real):
(转为 paddle 类型的 tensor)
此外,还要统一 loc 和 scale 的形状类型,我们选择使用 paddle.broadcast_tensors() 广播机制来进行统一。
if loc.shape != scale.shape:
self.loc, self.scale = paddle.broadcast_tensors([loc, scale])
else:
self.loc, self.scale = loc, scale
mean:均值
loc + scale * γ
variance:方差
pow( scale, 2 ) * pi * pi / 6
stddev:标准差
sqrt( variance )
cdf(value):累积分布函数
exp(-exp(-(value - loc) / scale))
rsample(shape):重参数化采样
ExpTransform()
AffineTransform(0, -ones_like(scale))
AffineTransform(loc, -scale)
chain = ChainTransform(ExpTransform(), AffineTransform(0, -ones_like(scale)), AffineTransform(loc, -scale))
chain.forward(base_distribute)
02
代码开发
参数(loc, scale)的形状,数据类型的判断; 使用基础分布 Uniform 以及 transforms 调用父类 TransformedDistribution。
if not isinstance(loc, (numbers.Real, framework.Variable)):
raise TypeError(
f"Expected type of loc is Real|Variable, but got {type(loc)}")
if not isinstance(scale, (numbers.Real, framework.Variable)):
raise TypeError(
f"Expected type of scale is Real|Variable, but got {type(scale)}"
)
if isinstance(loc, numbers.Real):
loc = paddle.full(shape=(), fill_value=loc)
if isinstance(scale, numbers.Real):
scale = paddle.full(shape=(), fill_value=scale)
if loc.shape != scale.shape:
self.loc, self.scale = paddle.broadcast_tensors([loc, scale])
else:
self.loc, self.scale = loc, scale
finfo = np.finfo(dtype='float32')
self.base_dist = paddle.distribution.Uniform(
paddle.full_like(self.loc, float(finfo.tiny)),
paddle.full_like(self.loc, float(1 - finfo.eps)))
self.transforms = ()
super(Gumbel, self).__init__(self.base_dist, self.transforms)
def rsample(self, shape):
exp_trans = paddle.distribution.ExpTransform()
affine_trans_1 = paddle.distribution.AffineTransform(
paddle.full(shape=self.scale.shape,
fill_value=0,
dtype=self.loc.dtype), -paddle.ones_like(self.scale))
affine_trans_2 = paddle.distribution.AffineTransform(
self.loc, -self.scale)
return affine_trans_2.forward(
exp_trans.inverse(
affine_trans_1.forward(
exp_trans.inverse(self._base.sample(shape)))))
▎其他属性、方法实现
多次使用仅支持动态图的 tensor,最终使用 paddle.full 替代; 无法在初始化方法中将多个 transforms 作用到基础分布 Uniform 上。最终选择在 rsample 方法中直接将多个 transforms 作用到 Uniform 上; 测试类中,使用 Numpy 实现方法的数据类型和飞桨数据类型没有统一。
导入及初始化 Gumbel 类
from paddle.distribution.gumbel import Gumbel
loc = paddle.full([1], 0.0)
scale = paddle.full([1], 1.0)
dist = Gumbel(loc, scale)
使用 rsample(shape) / sample(shape)进行随机采样/重参数化采样,需要指定采样的 shape。
dist.sample(shape)
# Tensor(shape=[2, 1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [[-0.27544352], [-0.64499271]])
dist.rsample(shape)
# Tensor(shape=[2, 1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [[0.80463481], [0.91893655]])
如何使用 GitHub 进行代码协作开发; 更加熟悉国内顶尖的飞桨深度学习框架; 深刻掌握了包括 Gumbel 分布在内的多个数学分布; 锻炼了使用 Python 开发的能力。
rsample 方法中使用多个 transforms,操作繁琐; 测试方法时,随机采样和重参数化采样的精度未能降低到1e-3以下。未来将尝试使用其他方式来实现属性方法,从而降低采样精度; API 的__init__方法存在优化的余地,未来可以将 rsample 中的 transforms 整合到初始化方法中; 未能实现 InverseTransform; 未能将多个参数的验证抽取出来,使用特定的参数验证方法进行验证,比如 _validate_value(value); 未来可以尝试增加 kl_divergence 计算两个 Gumbel 分布的差异。