深入理解生成模型VAE
设为星标,干货直达!
“What I cannot create, I do not understand.” -- Richard Feynman
说起生成模型,大家最容易想到的就是GAN,GAN是通过对抗训练实现的一种隐式生成模型。虽然GAN很强大,但其实还有很多与GAN不同的生成模型,最常见的就是基于最大化似然的模型,**变分自动编码器(Variational Autoencoder,VAE)**就属于这种类型。这篇文章将介绍VAE的原理和实现。
自动编码器(Autoencoder,AE)
再讲VAE之前,有必要先简单介绍一下自动编码器AE,自动编码器是一种无监督学习方法,它的原理很简单:先将高维的原始数据映射到一个低维特征空间,然后从低维特征学习重建原始的数据。一个AE模型包含两部分网络:
**Encoder:**将原始的高维数据映射到低维特征空间,这个特征维度一般比原始数据维度要小,这样就起到压缩或者降维的目的,这个低维特征也往往成为中间隐含特征(latent representation); Decoder:基于压缩后的低维特征来重建原始数据;
由于训练AE并不需要对数据进行标注,所以AE是一种无监督学习方法。由于压缩后的特征能对原始数据进行重建,所以我们可以用AE的encoder对高维数据进行压缩,这和PCA非常类似,当然得到的隐含特征也可以用来做一些其它工作,比如相似性搜索等。
AE有很多变种,比如经典的去噪自编码器(Denoising Autoencoder,DAE),与原始AE不同的是,在训练过程先对输入进行一定的扰动,比如增加噪音或者随机mask掉一部分特征。相比AE,DAE的重建难度增加,这也使得encoder学习到的隐含特征更具有代表性。
作为一种无监督学习方法,AE除了可以对数据降维,还可以用来对深度网络进行预训练。在深度学习早期,由于存在数据和算力限制,训练深度模型是比较困难的,所以常常采用无监督学习方法先对网络进行预训练,然后在具体的任务上进行有监督finetune,经典的工作如基于DAE的堆叠去噪自编码器(Stacked Denoising Autoencoder,SDA)和基于RBM的深度信念网络(Deep Belief Network,DBN)。
变分自动编码器(Variational Autoencoder,VAE)
VAE虽然名字里也带有自动编码器,但这主要是因为VAE和AE有着类似的结构,即encoder和decoder这样的架构设计。实际上,VAE和AE在建模方面存在很大的区别,从本质上讲,VAE是一种基于变分推断(Variational Inference, Variational Bayesian methods)的概率模型(Probabilistic Model),它属于生成模型(当然也是无监督模型)。在变分推断中,除了已知的数据(观测数据,训练数据)外还存在一个隐含变量,这里已知的数据集记为由个连续变量或者离散变量组成,而未观测的随机变量记为,那么数据的产生包含两个过程:
从一个先验分布中采样一个; 根据条件分布,用生成。
这里的指的是分布的参数,比如对于高斯分布就是均值和标准差。我们希望找到一个参数来最大化生成真实数据的概率:
这里可以通过对积分得到:
而实际上要根据上述积分是不现实的,一方面先验分布是未知的,而且如果分布比较复杂,对穷举计算也是极其耗时的。为了解决这个难题,变分推断引入后验分布来联合建模,根据贝叶斯公式,后验等于:
建模已经完成,下面我们来推导一下VAE的优化目标。对于估计的后验,我们希望它接近真实的后验分布,评估两个分布差异最常用的方式就是计算KL散度(Kullback-Leibler divergence)。对和计算KL散度,如下所示:
最终可以得到:
这里我们适当调整一下上述等式中各个项的位置,可以得到:
这里是生成真实数据的对数似然,对于生成模型,我们希望最大化这个对数似然,而是估计的后验分布和真实分布的KL散度,我们希望最小化该KL散度(KL散度为0时两个分布没有差异),所以上述等式的左边就是联合建模的最大化优化目标,这等价于最大化等式的右边。这个等式的右边又称为Evidence lower bound,简称为ELBO,这主要是因为一般称为evidence,而由于KL散度的非负性,所以有下述不等式:
所以ELBO是evidence的下限,ELBO是变分推断中经常用到的优化目标。对于VAE,ELBO取负就是其要最小化的训练目标:
对于优化目标的第二项,即计算和的KL散度,首先我们必须要对两个分布做一定的假设:
即为各分量独立的多元高斯分布(协方差矩阵为对角矩阵),那么encoder网络预测的就是高斯分布的均值和方差(实际处理时预测,因为该值是无约束的)。而先验为标准正态分布,这样就变成了计算两个多元高斯分布的KL散度。对于多元高斯分布,其概率密度函数为:
对于两个多元高斯分布,其KL散度计算推导如下:
上述公式的推导涉及到一些线性代数的知识,如矩阵的迹运算(tr),如果不明白可以参考这篇文章。根据上述公式,就可以计算出和的KL散度:
这里指的多元高斯分布分量的总数,或者说是隐变量的元素数量。实际上由于为各分量独立的多元高斯分布,这个计算可以简化为先计算单独计算各分量的的KL散度(即一元正态分布),然后对各分量的KL散度求和,因为一元正态分布的KL散度相对容易推导:
综上,对于训练数据的一个样本,其KL散度项的优化目标为:
现在我们来分析优化目标的第一项,它一般被称为重建误差(reconstruction error),因为正是给定下生成真实数据的似然(Likelihood)。对于一个给定的训练样本,我们可以采蒙特卡洛方法(Monte Carlo method)来估计这个数学期望,即从多次采样来估计:
这里的为采样的总次数,实际上在具体实现上往往,即只随机采样一次(VAE论文中说当训练的mini-batch size足够大时,采样一次是有效的)。另外一个困难的地方,从采样这个操作是无法计算梯度的,VAE采用一种重参数化(reparameterization)技巧来解决这个问题,具体地,通过引入一个额外的独立随机变量来将随机变量转变成确定变量:。由于已经假定为多元高斯分布,使用重采样技巧后则为:
直观上讲,就是首先从标准正态分布随机采样一个样本,然后乘以encoder预测的标准差,再加上encoder预测的均值,这样就能计算该损失对encoder网络参数的梯度了。
对于这个高斯分布的标准差,我们往往假定它是一个常量,而均值是由decoder预测得出:。那么则有:
这里和均是常量,而是变量的维度大小。如果忽略常量的话,那么重建误差其实就是L2损失。上面我们是假定分布是一个高斯分布,如果是一个伯努利分布即0-1分布的话,此时decoder直接预测概率值(sigmoid激活函数),重建误差就是交叉熵,:
根据上述分析,对给定的一个训练样本,其训练损失(假定是高斯分布)为:
如果把KL散度项看到一个正则化的话,那么VAE的损失函数就是重建误差+正则化,这样VAE就可以看成是一个加了约束的AE。VAE的整个训练流程如下所示:输入,encoder首先计算出后验分布的均值和标准差,然后通过重采样方法采样得到隐变量,然后送入decoder得到重建的数据。
训练完成后,我们就得到生成模型,其中就是decoder网络,而先验为标准正态分布,我们从随机采样一个,送入decoder网络,就能生成与训练数据类似的样本。
CVAE
条件变分自编码器(Conditional Variational Autoencoder,CVAE)是VAE的一个变种,相比VAE,CVAE要估计的是一个条件分布,同样地,我们引入隐变量来进行变分推断。此时,给定一个输入,从先验分布中采样一个,然后根据分布生成一个样本,因而这里要求解的生成模型是。这个生成模型可以用两个网络来学习,其中一个网络来学习先验分布,另外一个网络来学习条件分布。在VAE,我们假定先验为标准正态分布,因为不需要单独的网路来学习;而在CVAE中,先验分布是一种条件先验,如果假定是独立与的话,那么此时先验分布,更进一步地也可以简化认为先验为标准正态分布。同样地,我们另外采用一个网络来估计后验分布。同样地,我们可以推导出ELBO:
那么对于CVAE,其优化目标为:
对于上述优化目标的处理,同样地可以采用和VAE一样的分析过程,这里不再详细展开,具体见CVAE论文。下图为CVAE的一种实现方式(这里先验简化为标准正态分布):
对于VAE和CVAE,它们最重要的区别是数据是如何生成的,对于VAE,数据的产生认为是,而对于CVAE,其数据的产生是,不同的数据产生方式导致了不同的建模方式和ELBO,但两者用的变分推断理论是一致的。
VAE的代码实现
这里以MNIST数据集为例用PyTorch实现一个简单的VAE生成模型,由于MNIST数据集为灰度图,而且大部分像素点为0(黑色背景)或者白色(255,前景),所以这里可以将像素值除以255归一化到[0, 1],并认为像素值属于伯努利分布,重建误差采用交叉熵。首先是构建encoder,这里用简单的两层卷积和一个全连接层来实现,encoder给出隐变量的mu和log_var:
class Encoder(nn.Module):
"""The encoder for VAE"""
def __init__(self, image_size, input_dim, conv_dims, fc_dim, latent_dim):
super().__init__()
convs = []
prev_dim = input_dim
for conv_dim in conv_dims:
convs.append(nn.Sequential(
nn.Conv2d(prev_dim, conv_dim, kernel_size=3, stride=2, padding=1),
nn.ReLU()
))
prev_dim = conv_dim
self.convs = nn.Sequential(*convs)
prev_dim = (image_size // (2 ** len(conv_dims))) ** 2 * conv_dims[-1]
self.fc = nn.Sequential(
nn.Linear(prev_dim, fc_dim),
nn.ReLU(),
)
self.fc_mu = nn.Linear(fc_dim, latent_dim)
self.fc_log_var = nn.Linear(fc_dim, latent_dim)
def forward(self, x):
x = self.convs(x)
x = torch.flatten(x, start_dim=1)
x = self.fc(x)
mu = self.fc_mu(x)
log_var = self.fc_log_var(x)
return mu, log_var
对于decoder,基本采用对称的结构,这里用反卷积来实现上采样,decoder根据隐变量重构样本或者生成样本:
class Decoder(nn.Module):
"""The decoder for VAE"""
def __init__(self, latent_dim, image_size, conv_dims, output_dim):
super().__init__()
fc_dim = (image_size // (2 ** len(conv_dims))) ** 2 * conv_dims[-1]
self.fc = nn.Sequential(
nn.Linear(latent_dim, fc_dim),
nn.ReLU()
)
self.conv_size = image_size // (2 ** len(conv_dims))
de_convs = []
prev_dim = conv_dims[-1]
for conv_dim in conv_dims[::-1]:
de_convs.append(nn.Sequential(
nn.ConvTranspose2d(prev_dim, conv_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.ReLU()
))
prev_dim = conv_dim
self.de_convs = nn.Sequential(*de_convs)
self.pred_layer = nn.Sequential(
nn.Conv2d(prev_dim, output_dim, kernel_size=3, stride=1, padding=1),
nn.Sigmoid()
)
def forward(self, x):
x = self.fc(x)
x = x.reshape(x.size(0), -1, self.conv_size, self.conv_size)
x = self.de_convs(x)
x = self.pred_layer(x)
return x
有了encoder和decoder,然后就可以构建VAE模型了,这里的实现只对隐变量通过重采样方式采样一次,训练损失为KL散度和重建误差(交叉熵)之和:
class VAE(nn.Module):
"""VAE"""
def __init__(self, image_size, input_dim, conv_dims, fc_dim, latent_dim):
super().__init__()
self.encoder = Encoder(image_size, input_dim, conv_dims, fc_dim, latent_dim)
self.decoder = Decoder(latent_dim, image_size, conv_dims, input_dim)
def sample_z(self, mu, log_var):
"""sample z by reparameterization trick"""
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, x):
mu, log_var = self.encoder(x)
z = self.sample_z(mu, log_var)
recon = self.decoder(z)
return recon, mu, log_var
def compute_loss(self, x, recon, mu, log_var):
"""compute loss of VAE"""
# KL loss
kl_loss = (0.5*(log_var.exp() + mu ** 2 - 1 - log_var)).sum(1).mean()
# recon loss
recon_loss = F.binary_cross_entropy(recon, x, reduction="none").sum([1, 2, 3]).mean()
return kl_loss + recon_loss
模型训练完成,可以从标准正态分布随机采样,然后生成新的样本,下图为一些模型生成的样本:
代码实现见 https://github.com/xiaohu2015/nngen
总结
这篇文章简单讲述了自动编码器的原理,并重点介绍了VAE模型的原理以及它和AE之间的联系,最后给出了一个具体的VAE代码实例。VAE模型涉及比较复杂的数学建模,理解它需要花费一定的精力,这里特别感谢一些优秀的文章(见参考)。
参考
Lilian Weng blog: From Autoencoder to Beta-VAE Auto-encoding variational bayes Tutorial on variational autoencoders 变分自编码器(一):原来是这么一回事 变分自编码器(二):从贝叶斯观点出发 变分自编码器(五):VAE + BN = 更好的VAE A Beginner's Guide to Variational Methods: Mean-Field Approximation Understanding Variational Autoencoders (VAEs) CVAE: Learning Structured Output Representation using Deep Conditional Generative Models PyTorch-VAE https://keras.io/examples/generative/vae/ https://github.com/jojonki/AutoEncoders/blob/master/vae.ipynb
推荐阅读
PyTorch1.10发布:ZeroRedundancyOptimizer和Join
谷歌AI用30亿数据训练了一个20亿参数Vision Transformer模型,在ImageNet上达到新的SOTA!
"未来"的经典之作ViT:transformer is all you need!
PVT:可用于密集任务backbone的金字塔视觉transformer!
涨点神器FixRes:两次超越ImageNet数据集上的SOTA
不妨试试MoCo,来替换ImageNet上pretrain模型!
机器学习算法工程师
一个用心的公众号