在科学研究中,从方法论上来讲,都应“先见森林,再见树木”。当前,人工智能学术研究方兴未艾,技术迅猛发展,可谓万木争荣,日新月异。对于AI从业者来说,在广袤的知识森林中,系统梳理脉络,才能更好地把握趋势。为此,我们精选国内外优秀的综述文章,开辟“综述专栏”,敬请关注。
地址:https://zhuanlan.zhihu.com/p/378241241
01
概率生成模型,简称生成模型,指一系列用于随机生成可观测数据的模型,包括变分自编码器VAE和生成对抗网络GAN。具体来说,假设我们现在有一些来自未知分布 的可观测样本 ,我们想做以下两件事情:
学习一个参数化的模型 来近似未知分布 ;
基于学习到的模型 生成一些样本,使得生成样本和真实样本尽可能接近。
上述两点即对应了生成模型要完成的两件事:概率密度估计和采样。
在高维空间中,直接建模 比较困难,通常通过引入隐变量 来简化模型,如此密度估计问题就转化为估计变量 的两个局部条件概率 和 。得到上述两个局部条件概率后,生成数据 的过程如下:
根据 进行采样,得到样本 ;
根据 进行采样,得到样本 。
问题是在高维空间中, 的估计的采样也是个难题。生成对抗网络GAN借助神经网络,简单粗暴地解决了这个问题,既避免了密度估计,又降低了采样的难度。具体为:
从一个简单分布 (例如标准正态分布)中采样得到 ;
利用神经网络 使得 服从 。
也就是说,GAN并没有显示建模 ,而是建模其生成过程 ( 服从数据分布 ),因此GAN属于隐式密度模型。 成为生成网络,一个完整的GAN还需要一个判别网络 :判别网络的目的是尽量准确地判断一个样本是来自真实数据还是来自生成网络,生成网络的目的则是尽可能地欺骗判别网络,使其无法区分样本的来源。这两个网络正是“对抗”一词的来源。
02
2.1 判别网络
判别网络 的目的是区分一个样本 是来自真实分布 还是生成模型 ,因此它其实就是一个二分类器,训练二分类器最常用的损失函数为交叉熵,令判别网络 的输出表示 来自真实数据分布的概率:
注意区分这里的几个符号: 指真实数据分布; 指生成网络生成的数据服从的数据分布; 指生成网络,参数是 ,作用是将 映射到生成数据 ; 是生成网络输入 所服从的分布; 指判别网络,参数是 ,作用是判别样本 来自真实分布还是生成网络。2.2 生成网络
生成网络的目的是让判别网络将自己生成的样本判别为真实样本,即对于自己生成的样本,判别网络的输出(样本来自真实分布的概率)越大越好:
在式(5)中,判别网络的参数 是一个定值,因此我们可以再添加上一项常数项:
容易发现,待优化的部分就是判别网络的损失函数,而且其参数 在此处已经是定值,那么怎么确定这个定值呢?回忆一下,既然判别网络和生成网络是“对抗”的关系,而“对抗”的前提是双方的实力匹配,这样双方才能互相竞争对抗,共同进步。也就是说,判别网络的判别能力不能太弱,否则生成网络就没有对手,无法进步了。因此,此处参数的定值,应当是使得判别网络判别能力最好的参数值。因此,我们把两个网络结合起来看,即把式(2)、(6)统一起来,那么生成对抗网络的目标函数其实是一个最小最大化游戏:
2.3 训练
先看内层的最大化部分,固定生成网络,找出当前最优的判别网络;再看外层的最小化部分,固定判别网络,找出当前最优的生成网络。在实际训练中,通常需要平衡好两个网络的能力,每次迭代判别网络的能力应该比生成网络强一些,但又不能强太多,因此通常在一次迭代中,判别网络更新 次,生成网络才更新一次,伪代码如下所示:
03
3.1 稳定性差
这里的稳定性差指的是GAN在训练的过程中很难把握好梯度消失和梯度错误之间的平衡。我们先看看为什么会出现梯度消失的问题。先关注判别网络,若 和 已知,令式(2)的导数为零,可解得最优的判别为:
也就是说,当判别网络最优的时候,生成网络的目标是最小化分布 和 之间的 散度。当两个分布相同时 散度为零,即生成网络的最优值 对应的损失为 。然而实际情况是,当我们用诸如梯度下降等方式去最小化目标函数 的时候,生成网络的目标函数关于参数的梯度为零,无法更新。为什么会出现这种情况呢?原因是 散度本身的特性:当两个分布没有重叠的时候,它们之间的 散度恒为 。容易发现此时目标函数为0,意味着最优判别器的判别全部正确,对所有生成数据的输出均为0,因此对目标参数求导仍为0,带来了梯度消失的问题。因此在实际中,我们往往不降判别网络训练到最优,只进行 次梯度下降,以保证生成网络的梯度仍然存在。但是如果因为训练次数太少导致判别网络判别能力太差,则生成网络的梯度为错误的梯度。如何确定 这个超参数,平衡好梯度消失和梯度错误之间的平衡是个难题,这也是为什么说GAN在训练时稳定性差的原因。3.2 模型坍塌
除了稳定性差,GAN在训练的时候还容易出现模型坍塌的问题。模型坍塌指生成网络倾向于生成更“安全”的样本,即生成数据的分布聚集在原始数据分布的局部。下面我们看看为什么会出现这个问题。将最优判别网络 代入式(4),得到生成网络的目标函数为:
此时, 。其中 属于有界函数,因此生成网络的最优值更多受逆向KL散度 的影响。
什么是前向和逆向KL散度?以它们为目标进行优化会带来什么结果?我们先看看第一个问题:KL散度是一种非对称的散度,在计算真实分布 和生成分布 之间的KL散度的时候,按照顺序不同,分为前向KL散度和逆向KL散度:
当 而 时, 。意味着 的时候, 无论怎么取值都可以,都不会对前向KL散度的计算产生影响,因此拟合的时候不用回避 的点;当 而 时, 。意味着要减小前向KL散度, 必须尽可能覆盖 的点。因此,当以前向KL散度为目标函数进行优化的时候,模型分布 会尽可能覆盖所有真实分布 的点,而不用回避 的点。当 而 时, 。意味着要减小逆向KL散度, 必须回避所有 的点;当 时,无论 取什么值, 。意味着 不需要考虑考虑是否需要尽可能覆盖所有真实分布 的点。因此,当以逆向KL散度为目标函数进行优化的时候,模型分布 会尽可能避开所有真实分布 的点,而不需要考虑是否覆盖所有真实分布 的点。下图给出了当真实分布为高斯混合分布,模型分布为单高斯分布的时候,用前向KL散度和逆向KL散度进行模型优化的结果,可以发现使用逆向KL散度进行优化会带来模型坍缩的问题。因此,基于上述两个问题,GAN难训练的问题是出了名的。为了解决这些问题,后续又有人提出了各式各样的GAN,例如W-GAN,通过用Wasserstein距离代替JS散度,改善了GAN稳定性差的问题,同时一定程度上缓解了模型坍缩的问题,有兴趣的可以自己阅读paper。参考资料
1. 一文搞懂交叉熵在机器学习中的使用,透彻理解交叉熵背后的直觉
2. 《神经网络与深度学习》,邱锡鹏
本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。
“综述专栏”历史文章
更多综述专栏文章,
请点击文章底部“阅读原文”查看
分享、点赞、在看,给个三连击呗!