查看原文
其他

变分自编码器:球面上的VAE(vMF-VAE)

苏剑林 PaperWeekly 2022-07-04


©PaperWeekly 原创 · 作者|苏剑林

单位|追一科技

研究方向|NLP、神经网络


变分自编码器:VAE + BN = 更好的 VAE 中,我们讲到了 NLP 中训练 VAE 时常见的 KL 散度消失现象,并且提到了通过 BN 来使得 KL 散度项有一个正的下界,从而保证 KL 散度项不会消失。事实上,早在 2018 年的时候,就有类似思想的工作就被提出了,它们是通过在 VAE 中改用新的先验分布和后验分布,来使得 KL 散度项有一个正的下界。

该思路出现在 2018 年的两篇相近的论文中,分别是《Hyperspherical Variational Auto-Encoders》[1] 和《Spherical Latent Spaces for Stable Variational Autoencoders》[2],它们都是用定义在超球面的 von Mises–Fisher(vMF)分布来构建先后验分布。某种程度上来说,该分布比我们常用的高斯分布还更简单和有趣。


KL散度消失

我们知道,VAE 的训练目标是:

其中第一项是重构项,第二项是 KL 散度项,在变分自编码器:原来是这么一回事中我们就说过,这两项某种意义上是“对抗”的,KL 散度项的存在,会加大解码器利用编码信息的难度,如果 KL 散度项为 0,那么说明解码器完全没有利用到编码器的信息。

在 NLP 中,输入和重构的对象是句子,为了保证效果,解码器一般用自回归模型。然而,自回归模型是非常强大的模型,强大到哪怕没有输入,也能完成训练(退化为无条件语言模型),而刚才我们说了,KL 散度项会加大解码器利用编码信息的难度,所以解码器干脆弃之不用,这就出现了 KL 散度消失现象。

早期比较常见的应对方案是逐渐增加 KL 项的权重,以引导解码器去利用编码信息。现在比较流行的方案就是通过某些改动,直接让 KL 散度项有一个正的下界。将先后验分布换为 vMF 分布,就是这种方案的经典例子之一。


vMF分布

vMF 分布是定义在 d-1 维超球面的分布,其样本空间为 ,概率密度函数则为:

其中 是预先给定的参数向量。不难想象,这是 上一个以 为中心的分布,归一化因子写成 的形式,意味着它只依赖于 的模长,这是由于各向同性导致的。由于这个特性,vMF 分布更常见的记法是设 ,从而:

这时候 就是 的夹角余弦,所以说,vMF 分布实际上就是以预先为度量的一种分布。由于我们经常用余弦值来度量两个向量的相似度,因此基于 vMF 分布做出来的模型,通常更能满足我们的这个需求。当 的时候,vMF 分布是球面上的均匀分布。
从归一化因子 的积分形式来看,它实际上也是 vMF 的母函数,从而 vMF 的各阶矩也可以通过 来表达,比如一阶矩为:

可以看到 在方向上跟 一致。 的精确形式可以算出来,但比较复杂,而且很多时候我们也不需要精确知道这个归一化因子,所以这里我们就不算了。
至于参数 \kappa 的含义,或许设 我们更好理解,此时 ,熟悉能量模型的同学都知道,这里的 就是温度参数,如果 越小( 越大),那么分布就越集中在 附近,反之则越分散(越接近球面上的均匀分布)。因此, 也被形象地称为“凝聚度(concentration)”参数。


从vMF采样

对于 vMF 分布来说,需要解决的第一个难题是如何实现从它里边采样出具体的样本来。尤其是如果我们要将它应用到 VAE 中,那么这一步是至关重要的。

3.1 均匀分布

最简单是 的情形,也就是 d-1 维球面上的均匀分布,因为标准正态分布本来就是各向同性的,其概率密度正比于 只依赖于模长,所以我们只需要从 d 为标准正态分布中采样一个 z,然后让 就得到了球面上的均匀采样结果。

3.2 特殊方向

接着,对于 的情形,我们记 ,首先考虑一种特殊的情况:。事实上,由于各向同性的原因,很多时候我们都只需要考虑这个特殊情况,然后就可以平行地推广到一般情形。
此时概率密度正比于 ,然后我们转换到球坐标系:

那么:

这个分解表明,从该 vMF 分布中采样,等价于先从概率密度正比于 的分布采样一个 ,然后从 d-2 维超球面上均匀采样一个 d-1 维向量 ,通过如下方式组合成最终采样结果:

,那么:

所以我们主要研究从概率密度正比于 的分布中采样。
然而,笔者所不理解的是,大多数涉及到 vMF 分布的论文,都采用了 1994 年的论文《Simulation of the von mises fisher distribution》[3] 提出的基于 beta 分布的拒绝采样方案,整个采样流程还是颇为复杂的。但现在都 2021 年了,对于一维分布的采样,居然还需要拒绝采样这么低效的方案?
事实上,对于任意一维分布 ,设它的累积概率函数为 ,那么 就是一个最方便通用的采样方案。可能有读者抗议说“累积概率函数不好算呀”、“它的逆函数更不好算呀”,但是在用代码实现采样的时候,我们压根就不需要知道 长啥样,只要直接数值计算就行了,参考实现如下:
import numpy as np

def sample_from_pw(size, kappa, dims, epsilon=1e-7):
    x = np.arange(-1 + epsilon, 1, epsilon)
    y = kappa * x + np.log(1 - x**2) * (dims - 3) / 2
    y = np.cumsum(np.exp(y - y.max()))
    y = y / y[-1]
    return np.interp(np.random.random(size), y, x)
这里的实现中,计算量最大的是变量 y 的计算,而一旦计算好之后,可以缓存下来,之后只需要执行最后一步来完成采样,其速度是非常快的。这样再怎么看,也比从 beta 分布中拒绝采样要简单方便吧。顺便说,实现上这里还用到了一个技巧,即先计算对数值,然后减去最大值,最后才算指数,这样可以防止溢出,哪怕 成千上万,也可以成功计算。
3.3 一般情形
现在我们已经实现了从 的 vMF 分布中采样了,我们可以将采样结果分解为:
同样由于各向同性的原因,对于一般的 ,采样结果依然具有同样的形式:

对于 v 的采样,关键之处是与 正交,这也不难实现,先从标准正态分布中采样一个 d 维向量 z,然后保留与 正交的分量并归一化即可:


vMF-VAE

至此,我们可谓是已经完成了本篇文章最艰难的部分,剩下的构建 vMF-VAE 可谓是水到渠成了。vMF-VAE 选用球面上的均匀分布()作为先验分布 ,并将后验分布选取为 vMF 分布:

简单起见,我们将 设为超参数(也可以理解为通过人工而不是梯度下降来更新这个参数),这样一来, 的唯一参数来源就是 了。此时我们可以计算 KL 散度项:

前面我们已经讨论过,vMF 分布的均值方向跟 一致,模长则只依赖于 d 和 ,所以代入上式后我们可以知道 KL 散度项只依赖于 d 和 ,当这两个参数被选定之后,那么它就是一个常数(根据 KL 散度的性质,当 时,它必然大于 0),绝对不会出现 KL 散度消失现象了。
那么现在就剩下重构项了,我们需要用“重参数(Reparameterization)”来完成采样并保留梯度,在前面我们已经研究了vMF的采样过程,所以也不难实现,综合的流程为:

这里的重构 loss 以 MSE 为例,如果是句子重构,那么换用交叉熵就好。其中 就是编码器,而 就是解码器,由于 KL 散度项为常数,对优化没影响,所 以vMF-VAE 相比于普通的自编码器,只是多了一项稍微有点复杂的重参数操作(以及人工调整 )而已,相比基于高斯分布的标准 VAE 可谓简化了不少了。
此外,从该流程我们也可以看出,除了“简单起见”之外,不将 设为可训练还有一个主要原因,那就是 关系到 w 的采样,而在w的采样过程中要保留 的梯度是比较困难的。


参考实现

vMF-VAE 的实现难度主要是重参数部分,也就还是从 vMF 分布中采样,而关键之处就是 w 的采样。前面我们已经给出了 w 的采样的 numpy 实现,但是在 tf 中未见类似 np.interp 的函数,因此不容易转换为纯 tf 的实现。当然,如果是torch或者 tf2 这种动态图框架,直接跟 numpy 的代码混合使用也无妨,但这里还是想构造一种比较通用的方案。

其实也不难,由于 w 只是一个一维变量,每步训练只需要用到 batch_size 个采样结果,所以我们完全可以事先用 numpy 函数采样好足够多(几十万)个 w 存好,然后训练的时候直接从这批采样好的结果随机抽就行了,参考实现如下:

def sampling(mu):
    """vMF分布重参数操作
    """

    dims = K.int_shape(mu)[-1]
    # 预先计算一批w
    epsilon = 1e-7
    x = np.arange(-1 + epsilon, 1, epsilon)
    y = kappa * x + np.log(1 - x**2) * (dims - 3) / 2
    y = np.cumsum(np.exp(y - y.max()))
    y = y / y[-1]
    W = K.constant(np.interp(np.random.random(10**6), y, x))
    # 实时采样w
    idxs = K.random_uniform(K.shape(mu[:, :1]), 010**6, dtype='int32')
    w = K.gather(W, idxs)
    # 实时采样z
    eps = K.random_normal(K.shape(mu))
    nu = eps - K.sum(eps * mu, axis=1, keepdims=True) * mu
    nu = K.l2_normalize(nu)
    return w * mu + (1 - w**2)**0.5 * nu

一个基于 MNIST 的完整例子可见:

https://github.com/bojone/vae/blob/master/vae_vmf_keras.py
至于 vMF-VAE 用于 NLP 的例子,我们日后有机会再分享。本文主要还是以理论介绍和简单演示为主。


文章小结

本文介绍了基于 vMF 分布的 VAE 实现,其主要难度在于 vMF 分布的采样。总的来说,vMF 分布建立在余弦相似度度量之上,在某些方面的性质更符合我们的直观认知,将其用于 VAE 中,能够使得 KL 散度项为一个常数,从而防止了 KL 散度消失现象,并且简化了 VAE 结构。

参考文献

[1] https://arxiv.org/abs/1804.00891
[2] https://arxiv.org/abs/1808.10805
[3] https://www.tandfonline.com/doi/abs/10.1080/03610919408813161


更多阅读




#投 稿 通 道#

 让你的论文被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得技术干货。我们的目的只有一个,让知识真正流动起来。


📝 来稿标准:

• 稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向) 

• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接 

• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志


📬 投稿邮箱:

• 投稿邮箱:hr@paperweekly.site 

• 所有文章配图,请单独在附件中发送 

• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通



🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧



关于PaperWeekly


PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。



您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存