An Introduction to Diffusion Models: DDPM
“What I cannot create, I do not understand.” -- Richard Feynman
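Text-to-image generation with AI has recently become the hottest research direction. After OpenAI proposed the text-to-image model DALL·E in 2021, more and more large companies entered the field. Google, for example, released Imagen and Parti this year (2022). Several mainstream text-to-image models, including DALL·E 2, Stable Diffusion and Imagen, use a diffusion model as their image generator, and this has sparked a wave of research on diffusion models. Compared with GANs, diffusion models train more stably and produce more diverse samples. OpenAI's paper Diffusion Models Beat GANs on Image Synthesis also showed that diffusion models can outperform GANs on image synthesis. In short, a diffusion model consists of two processes: a forward diffusion process and a reverse generation process. The forward process gradually adds Gaussian noise to an image until it becomes random noise. The reverse process is a denoising process: it starts from random noise and removes noise step by step until an image is generated, and this is the part we need to solve for, i.e. train. A comparison of diffusion models with other mainstream generative models is shown below: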
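Most diffusion models in use today derive from the 2020 work DDPM: Denoising Diffusion Probabilistic Models. DDPM simplified the earlier diffusion model (see Deep Unsupervised Learning using Nonequilibrium Thermodynamics) and models it with variational inference. The reason is that a diffusion model is also a latent variable model. Unlike latent variable models such as VAEs, however, the latent variables of a diffusion model have the same dimensionality as the original data, and its inference process (i.e. the diffusion process) is usually fixed. This article explains the principles of diffusion models in detail based on DDPM, and gives a concrete code implementation with analysis.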
Principles of Diffusion Models
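A diffusion model involves two processes: the forward process and the reverse process. The forward process is also called the diffusion process, as shown in the figure below. Both processes are Markov chains; the forward chain has fixed parameters, while the reverse chain is parameterized and learned, and it is the reverse process that generates data. Here we model and solve the problem with variational inference.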
Diffusion process
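The diffusion process gradually adds Gaussian noise to the data until the data becomes random noise. For original data $\mathbf{x}_0 \sim q(\mathbf{x}_0)$, a diffusion process with $T$ steps in total adds Gaussian noise at every step to the data $\mathbf{x}_{t-1}$ from the previous step, as follows:

$$q(\mathbf{x}_t \mid \mathbf{x}_{t-1}) = \mathcal{N}\left(\mathbf{x}_t; \sqrt{1-\beta_t}\,\mathbf{x}_{t-1}, \beta_t\mathbf{I}\right)$$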
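Here $\{\beta_t\}_{t=1}^{T}$ is the variance used at each step, and it lies between 0 and 1. For diffusion models, the set of per-step variances is called the variance schedule or noise schedule. Later steps usually use larger variances, i.e. $\beta_1 < \beta_2 < \dots < \beta_T$. Under a well-designed variance schedule, if the number of diffusion steps $T$ is large enough, the final $\mathbf{x}_T$ loses all information about the original data and becomes pure random noise. Each step of the diffusion process produces a noisy sample $\mathbf{x}_t$, and the whole diffusion process is a Markov chain:

$$q(\mathbf{x}_{1:T} \mid \mathbf{x}_0) = \prod_{t=1}^{T} q(\mathbf{x}_t \mid \mathbf{x}_{t-1})$$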
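As a quick sanity check of the claim that $\mathbf{x}_T$ ends up carrying almost none of the original data, the pure-Python sketch below tracks the coefficient of $\mathbf{x}_0$ as the step $\mathbf{x}_t = \sqrt{1-\beta_t}\,\mathbf{x}_{t-1} + \sqrt{\beta_t}\,\boldsymbol{\epsilon}$ is applied repeatedly. It assumes DDPM's linear schedule from 0.0001 to 0.02 over $T = 1000$ steps (the values used in the code later in this article); the helper names are illustrative, not from any library.

```python
import math


def linear_betas(T=1000, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced variances beta_1..beta_T (DDPM's default range)."""
    return [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]


def signal_coefficient(betas):
    """Coefficient of x_0 inside x_t after each forward step."""
    coef = 1.0
    history = []
    for beta in betas:
        coef *= math.sqrt(1.0 - beta)  # each step scales the previous signal by sqrt(1 - beta_t)
        history.append(coef)
    return history


coefs = signal_coefficient(linear_betas())
print(f"t=1: {coefs[0]:.5f}  t=500: {coefs[499]:.4f}  t=1000: {coefs[-1]:.6f}")
```

The coefficient shrinks monotonically and is below 0.01 at $t = T$, so almost nothing of $\mathbf{x}_0$ survives.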
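Another important property of the diffusion process is that we can sample $\mathbf{x}_t$ at any step $t$ directly from the original data $\mathbf{x}_0$, i.e. $\mathbf{x}_t \sim q(\mathbf{x}_t \mid \mathbf{x}_0)$. Define $\alpha_t = 1-\beta_t$ and $\bar{\alpha}_t = \prod_{i=1}^{t}\alpha_i$. Using the reparameterization trick (as in a VAE), with independent $\boldsymbol{\epsilon}_{t-1}, \boldsymbol{\epsilon}_{t-2}, \dots \sim \mathcal{N}(\mathbf{0},\mathbf{I})$, we have:

$$\begin{aligned}
\mathbf{x}_t &= \sqrt{\alpha_t}\,\mathbf{x}_{t-1} + \sqrt{1-\alpha_t}\,\boldsymbol{\epsilon}_{t-1} \\
&= \sqrt{\alpha_t\alpha_{t-1}}\,\mathbf{x}_{t-2} + \sqrt{\alpha_t(1-\alpha_{t-1})}\,\boldsymbol{\epsilon}_{t-2} + \sqrt{1-\alpha_t}\,\boldsymbol{\epsilon}_{t-1} \\
&= \sqrt{\alpha_t\alpha_{t-1}}\,\mathbf{x}_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\,\bar{\boldsymbol{\epsilon}}_{t-2} \\
&= \dots \\
&= \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon}
\end{aligned}$$

This derivation uses the fact that the sum of two independent Gaussians with different variances, $\mathcal{N}(\mathbf{0},\sigma_1^2\mathbf{I})$ and $\mathcal{N}(\mathbf{0},\sigma_2^2\mathbf{I})$, is a new Gaussian $\mathcal{N}(\mathbf{0},(\sigma_1^2+\sigma_2^2)\mathbf{I})$; here $\alpha_t(1-\alpha_{t-1}) + (1-\alpha_t) = 1-\alpha_t\alpha_{t-1}$. Undoing the reparameterization, we get:

$$q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}\left(\mathbf{x}_t; \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0, (1-\bar{\alpha}_t)\mathbf{I}\right)$$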
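The closed form $q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}(\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0, (1-\bar{\alpha}_t)\mathbf{I})$ can be checked empirically. The sketch below runs the step-by-step chain many times on a scalar and compares the sample mean and variance with the closed form. The short 50-step schedule, the value $x_0 = 2$, and the helper names are made-up assumptions chosen only to keep the check fast.

```python
import math
import random


def forward_chain(x0, betas, rng):
    """Run x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps step by step."""
    x = x0
    for beta in betas:
        x = math.sqrt(1.0 - beta) * x + math.sqrt(beta) * rng.gauss(0.0, 1.0)
    return x


def closed_form_stats(x0, betas):
    """Mean and variance of q(x_t | x_0) = N(sqrt(abar_t) x_0, 1 - abar_t)."""
    alpha_bar = math.prod(1.0 - b for b in betas)
    return math.sqrt(alpha_bar) * x0, 1.0 - alpha_bar


rng = random.Random(0)
betas = [0.01 + 0.04 * i / 49 for i in range(50)]  # made-up schedule, t = 50
x0 = 2.0
samples = [forward_chain(x0, betas, rng) for _ in range(20000)]
mc_mean = sum(samples) / len(samples)
mc_var = sum((s - mc_mean) ** 2 for s in samples) / (len(samples) - 1)
mean, var = closed_form_stats(x0, betas)
print(f"Monte Carlo: mean={mc_mean:.3f}, var={mc_var:.3f}")
print(f"closed form: mean={mean:.3f}, var={var:.3f}")
```

The two agree up to Monte Carlo error. This is why training can jump straight to any $t$ instead of simulating the chain.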
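This property of the diffusion process is important. First, $\mathbf{x}_t$ can be viewed as a linear combination of the original data $\mathbf{x}_0$ and random noise $\boldsymbol{\epsilon}$, with combination coefficients $\sqrt{\bar{\alpha}_t}$ and $\sqrt{1-\bar{\alpha}_t}$ whose squares sum to 1. The two coefficients are also called the signal_rate and the noise_rate (see https://keras.io/examples/generative/ddim/#diffusion-schedule and Variational Diffusion Models). Going further, we can define the noise schedule in terms of $\bar{\alpha}_t$ rather than $\beta_t$ (see the cosine schedule designed in Improved Denoising Diffusion Probabilistic Models), because this is more direct: for example, setting $\bar{\alpha}_T$ to a value close to 0 guarantees that the final $\mathbf{x}_T$ is approximately random noise. Second, the modeling and analysis that follow rely on this property.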
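To illustrate defining the schedule directly through $\bar{\alpha}_t$, here is a minimal pure-Python sketch of the cosine schedule from Improved DDPM, $\bar{\alpha}_t = f(t)/f(0)$ with $f(t) = \cos^2\left(\frac{t/T + s}{1+s}\cdot\frac{\pi}{2}\right)$ and $s = 0.008$. It prints the signal_rate $\sqrt{\bar{\alpha}_t}$ and noise_rate $\sqrt{1-\bar{\alpha}_t}$ at a few steps. The function name is a hypothetical helper; in Improved DDPM the per-step $\beta_t = 1 - \bar{\alpha}_t/\bar{\alpha}_{t-1}$ is additionally clipped at 0.999.

```python
import math


def cosine_alpha_bar(t, T, s=0.008):
    """abar_t of the cosine schedule: f(t) / f(0)."""
    def f(u):
        return math.cos((u / T + s) / (1 + s) * math.pi / 2) ** 2
    return f(t) / f(0)


T = 1000
rates = []
for t in (0, 250, 500, 750, 1000):
    abar = cosine_alpha_bar(t, T)
    signal_rate, noise_rate = math.sqrt(abar), math.sqrt(1.0 - abar)
    rates.append((t, signal_rate, noise_rate))
    print(f"t={t:4d}  signal_rate={signal_rate:.4f}  noise_rate={noise_rate:.4f}")
```

By construction $\bar{\alpha}_0 = 1$ and $\bar{\alpha}_T \approx 0$, so the endpoints are pinned down directly instead of emerging from a product of $\beta_t$ values.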
Reverse process
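The diffusion process turns data into noise, so the reverse process is a denoising process. If we knew the true distribution $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ of every reverse step, we could start from random noise $\mathbf{x}_T \sim \mathcal{N}(\mathbf{0},\mathbf{I})$ and denoise it step by step into a real sample. The reverse process is therefore the data-generating process.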
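Estimating $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ would require the entire training set, so we estimate these distributions with neural networks instead. The reverse process is likewise defined as a Markov chain, made up of Gaussians parameterized by neural networks:

$$p_\theta(\mathbf{x}_{0:T}) = p(\mathbf{x}_T)\prod_{t=1}^{T} p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t), \qquad p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t) = \mathcal{N}\left(\mathbf{x}_{t-1}; \boldsymbol{\mu}_\theta(\mathbf{x}_t, t), \boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t)\right)$$

Here $p(\mathbf{x}_T) = \mathcal{N}(\mathbf{x}_T; \mathbf{0}, \mathbf{I})$, and $p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ is a parameterized Gaussian whose mean and variance are given by the trained networks $\boldsymbol{\mu}_\theta$ and $\boldsymbol{\Sigma}_\theta$. In effect, training a diffusion model means obtaining these networks, since together they form the final generative model. Although $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ is intractable, the posterior conditioned on $\mathbf{x}_0$ is tractable:

$$q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}\left(\mathbf{x}_{t-1}; \tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0), \tilde{\beta}_t\mathbf{I}\right)$$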
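Let us derive this distribution. First, by Bayes' rule:

$$q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) = q(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0)\,\frac{q(\mathbf{x}_{t-1} \mid \mathbf{x}_0)}{q(\mathbf{x}_t \mid \mathbf{x}_0)}$$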
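Because the diffusion process is a Markov chain, $q(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0) = q(\mathbf{x}_t \mid \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{\alpha_t}\,\mathbf{x}_{t-1}, \beta_t\mathbf{I})$ (the condition $\mathbf{x}_0$ is redundant), and from the diffusion-process property derived earlier:

$$q(\mathbf{x}_{t-1} \mid \mathbf{x}_0) = \mathcal{N}\left(\mathbf{x}_{t-1}; \sqrt{\bar{\alpha}_{t-1}}\,\mathbf{x}_0, (1-\bar{\alpha}_{t-1})\mathbf{I}\right), \qquad q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}\left(\mathbf{x}_t; \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0, (1-\bar{\alpha}_t)\mathbf{I}\right)$$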
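Therefore:

$$\begin{aligned}
q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) &\propto \exp\left(-\frac{1}{2}\left(\frac{(\mathbf{x}_t-\sqrt{\alpha_t}\,\mathbf{x}_{t-1})^2}{\beta_t} + \frac{(\mathbf{x}_{t-1}-\sqrt{\bar{\alpha}_{t-1}}\,\mathbf{x}_0)^2}{1-\bar{\alpha}_{t-1}} - \frac{(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0)^2}{1-\bar{\alpha}_t}\right)\right) \\
&= \exp\left(-\frac{1}{2}\left(\left(\frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}}\right)\mathbf{x}_{t-1}^2 - \left(\frac{2\sqrt{\alpha_t}}{\beta_t}\mathbf{x}_t + \frac{2\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}\mathbf{x}_0\right)\mathbf{x}_{t-1} + C(\mathbf{x}_t, \mathbf{x}_0)\right)\right)
\end{aligned}$$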
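Here $C(\mathbf{x}_t, \mathbf{x}_0)$ collects the terms that do not depend on $\mathbf{x}_{t-1}$, so it is omitted. Matching this against the Gaussian density (completing the square) gives the mean and variance of the distribution:

$$\tilde{\beta}_t = 1\Big/\left(\frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}}\right) = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t$$

$$\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0) = \left(\frac{\sqrt{\alpha_t}}{\beta_t}\mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}\mathbf{x}_0\right)\Big/\left(\frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}}\right) = \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\mathbf{x}_0$$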
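The variance $\tilde{\beta}_t$ is a constant (the diffusion parameters are fixed), while the mean is a function of $\mathbf{x}_0$ and $\mathbf{x}_t$. This distribution will be used to derive the optimization objective of the diffusion model.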
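The completing-the-square result can be verified numerically in one dimension. Because Bayes' rule gives a normalized density, $\log q(x_t \mid x_{t-1}) + \log q(x_{t-1} \mid x_0) - \log q(x_t \mid x_0)$ should equal $\log \mathcal{N}(x_{t-1}; \tilde{\mu}_t, \tilde{\beta}_t)$ at every point, not just up to a constant. The values of $\beta_t$, $\bar{\alpha}_{t-1}$, $x_0$ and $x_t$ below are made up for the check, and the helper names are illustrative.

```python
import math


def log_normal(x, mean, var):
    """Log density of a 1-D Gaussian N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def posterior_params(x_t, x0, beta_t, abar_t, abar_prev):
    """Closed-form mean and variance of q(x_{t-1} | x_t, x_0)."""
    alpha_t = 1.0 - beta_t
    var = (1.0 - abar_prev) / (1.0 - abar_t) * beta_t
    mean = (math.sqrt(abar_prev) * beta_t / (1.0 - abar_t)) * x0 \
        + (math.sqrt(alpha_t) * (1.0 - abar_prev) / (1.0 - abar_t)) * x_t
    return mean, var


def bayes_log_posterior(x_prev, x_t, x0, beta_t, abar_t, abar_prev):
    """log q(x_t|x_{t-1}) + log q(x_{t-1}|x_0) - log q(x_t|x_0)."""
    alpha_t = 1.0 - beta_t
    return (log_normal(x_t, math.sqrt(alpha_t) * x_prev, beta_t)
            + log_normal(x_prev, math.sqrt(abar_prev) * x0, 1.0 - abar_prev)
            - log_normal(x_t, math.sqrt(abar_t) * x0, 1.0 - abar_t))


beta_t, abar_prev = 0.02, 0.5           # made-up values for a single step
abar_t = abar_prev * (1.0 - beta_t)     # consistent with abar_t = abar_{t-1} * alpha_t
x0, x_t = 0.7, -0.3
mean, var = posterior_params(x_t, x0, beta_t, abar_t, abar_prev)
gaps = [bayes_log_posterior(x, x_t, x0, beta_t, abar_t, abar_prev) - log_normal(x, mean, var)
        for x in (-1.0, 0.0, 0.5, 2.0)]
print(gaps)  # all (numerically) zero
```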
Optimization objective
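We have covered the diffusion and reverse processes. Now let us look at diffusion models from another angle: if we treat the intermediate variables $\mathbf{x}_{1:T}$ as latent variables, a diffusion model is a latent variable model with $T$ latent variables, and it can be seen as a special hierarchical VAE (see Understanding Diffusion Models: A Unified Perspective):

$$\begin{aligned}
\log p_\theta(\mathbf{x}_0) &= \log\int p_\theta(\mathbf{x}_{0:T})\,d\mathbf{x}_{1:T} = \log\int \frac{p_\theta(\mathbf{x}_{0:T})\,q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}\,d\mathbf{x}_{1:T} \\
&= \log\mathbb{E}_{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}\left[\frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}\right] \geq \mathbb{E}_{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}\left[\log\frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}\right] = L_\text{VLB}
\end{aligned}$$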
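The last step uses Jensen's inequality (for a derivation that does not use it, see the blog post What are Diffusion Models?). For network training, the objective is the negative VLB:

$$L = -L_\text{VLB} = \mathbb{E}_{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}\left[\log\frac{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})}\right]$$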
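Decomposing the training objective further gives:

$$\begin{aligned}
L &= \mathbb{E}_{q}\left[\log\frac{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})}\right] = \mathbb{E}_{q}\left[-\log p(\mathbf{x}_T) + \sum_{t=1}^{T}\log\frac{q(\mathbf{x}_t \mid \mathbf{x}_{t-1})}{p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}\right] \\
&= \mathbb{E}_{q}\left[-\log p(\mathbf{x}_T) + \sum_{t=2}^{T}\log\left(\frac{q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}\cdot\frac{q(\mathbf{x}_t \mid \mathbf{x}_0)}{q(\mathbf{x}_{t-1} \mid \mathbf{x}_0)}\right) + \log\frac{q(\mathbf{x}_1 \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1)}\right] \\
&= \mathbb{E}_{q}\left[\log\frac{q(\mathbf{x}_T \mid \mathbf{x}_0)}{p(\mathbf{x}_T)} + \sum_{t=2}^{T}\log\frac{q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)} - \log p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1)\right] \\
&= \mathbb{E}_{q}\Big[\underbrace{D_\text{KL}\left(q(\mathbf{x}_T \mid \mathbf{x}_0) \,\|\, p(\mathbf{x}_T)\right)}_{L_T} + \sum_{t=2}^{T}\underbrace{D_\text{KL}\left(q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) \,\|\, p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)\right)}_{L_{t-1}}\underbrace{-\log p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1)}_{L_0}\Big]
\end{aligned}$$

The second line applies Bayes' rule together with the Markov property, $q(\mathbf{x}_t \mid \mathbf{x}_{t-1}) = q(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0) = q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)\,q(\mathbf{x}_t \mid \mathbf{x}_0)/q(\mathbf{x}_{t-1} \mid \mathbf{x}_0)$, and in the third line the ratios $q(\mathbf{x}_t \mid \mathbf{x}_0)/q(\mathbf{x}_{t-1} \mid \mathbf{x}_0)$ telescope to $q(\mathbf{x}_T \mid \mathbf{x}_0)/q(\mathbf{x}_1 \mid \mathbf{x}_0)$.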
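The final objective therefore has $T+1$ terms. $L_0$ can be viewed as reconstruction of the original data: it optimizes a negative log-likelihood, which can be computed with a discretized decoder built from the estimated $\mathcal{N}(\mathbf{x}_0; \boldsymbol{\mu}_\theta(\mathbf{x}_1, 1), \boldsymbol{\Sigma}_\theta(\mathbf{x}_1, 1))$ (see Section 3.3 of the DDPM paper). $L_T$ is the KL divergence between the distribution of the final noise and the prior. It has no trainable parameters and is approximately 0, because the prior is $p(\mathbf{x}_T) = \mathcal{N}(\mathbf{0},\mathbf{I})$ and the noise at the end of the diffusion process, $q(\mathbf{x}_T \mid \mathbf{x}_0)$, is also approximately $\mathcal{N}(\mathbf{0},\mathbf{I})$. $L_{t-1}$ is the KL divergence between the estimated distribution $p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ and the true posterior $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$. Here we want our estimated denoising process to match the denoising process that depends on the real data:

$$L_{t-1} = D_\text{KL}\left(q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) \,\|\, p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)\right)$$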
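The KL divergence between two Gaussians has a closed form (for a derivation, see the post "Generative Models: VAE" (生成模型之VAE)):

$$D_\text{KL}\left(\mathcal{N}(\mu_1,\sigma_1^2) \,\|\, \mathcal{N}(\mu_2,\sigma_2^2)\right) = \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$$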
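DDPM further simplifies $p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ by using a fixed variance $\boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t) = \sigma_t^2\mathbf{I}$, where $\sigma_t^2$ is set to $\beta_t$ or $\tilde{\beta}_t$. With fixed variances, all variance terms are constants, and we have:

$$D_\text{KL}\left(q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) \,\|\, p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)\right) = \frac{1}{2\sigma_t^2}\left\|\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0) - \boldsymbol{\mu}_\theta(\mathbf{x}_t, t)\right\|^2 + C$$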
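Dropping the constant $C$, the objective becomes:

$$L_{t-1} = \mathbb{E}_{q(\mathbf{x}_t \mid \mathbf{x}_0)}\left[\frac{1}{2\sigma_t^2}\left\|\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0) - \boldsymbol{\mu}_\theta(\mathbf{x}_t, t)\right\|^2\right]$$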
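The claim that $L_T \approx 0$ can be checked with the standard closed-form KL between two 1-D Gaussians, $\log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$. The sketch below compares $q(x_T \mid x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_T}\,x_0, 1-\bar{\alpha}_T)$ with the prior $\mathcal{N}(0, 1)$ for a single pixel, assuming DDPM's linear schedule with $T = 1000$ and a pixel value $x_0 = 1$ (the edge of the $[-1, 1]$ data range); `gaussian_kl` is an illustrative helper.

```python
import math


def gaussian_kl(mu1, var1, mu2, var2):
    """KL(N(mu1, var1) || N(mu2, var2)) for 1-D Gaussians."""
    return 0.5 * math.log(var2 / var1) + (var1 + (mu1 - mu2) ** 2) / (2.0 * var2) - 0.5


T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
abar_T = math.prod(1.0 - b for b in betas)
x0 = 1.0
l_T = gaussian_kl(math.sqrt(abar_T) * x0, 1.0 - abar_T, 0.0, 1.0)
print(f"abar_T = {abar_T:.2e}, L_T = {l_T:.2e}")
```

Since nothing here depends on $\theta$, $L_T$ can simply be left out of training.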
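According to this formula, we want the mean $\boldsymbol{\mu}_\theta(\mathbf{x}_t, t)$ learned by the network to match the posterior mean $\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0)$. However, DDPM found that predicting the mean is not the best choice. From the diffusion-process property derived earlier:

$$\mathbf{x}_t(\mathbf{x}_0, \boldsymbol{\epsilon}) = \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon}, \qquad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0},\mathbf{I})$$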
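Substituting this into the objective above, i.e. writing $\mathbf{x}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\left(\mathbf{x}_t - \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon}\right)$ inside $\tilde{\boldsymbol{\mu}}_t$, gives:

$$\begin{aligned}
L_{t-1} &= \mathbb{E}_{\mathbf{x}_0, \boldsymbol{\epsilon}}\left[\frac{1}{2\sigma_t^2}\left\|\tilde{\boldsymbol{\mu}}_t\left(\mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}), \frac{1}{\sqrt{\bar{\alpha}_t}}\left(\mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}) - \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon}\right)\right) - \boldsymbol{\mu}_\theta(\mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}), t)\right\|^2\right] \\
&= \mathbb{E}_{\mathbf{x}_0, \boldsymbol{\epsilon}}\left[\frac{1}{2\sigma_t^2}\left\|\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}) - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\boldsymbol{\epsilon}\right) - \boldsymbol{\mu}_\theta(\mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}), t)\right\|^2\right]
\end{aligned}$$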
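The key algebraic step is that, once $\mathbf{x}_0$ is expressed through $\mathbf{x}_t$ and $\boldsymbol{\epsilon}$, the posterior mean $\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}\mathbf{x}_0$ collapses to $\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\boldsymbol{\epsilon}\right)$. A minimal numeric check with made-up scalar values (the function names are illustrative):

```python
import math


def posterior_mean_from_x0(x_t, x0, beta_t, abar_t, abar_prev):
    """mu_tilde written in terms of x_t and x_0."""
    alpha_t = 1.0 - beta_t
    return (math.sqrt(abar_prev) * beta_t / (1.0 - abar_t)) * x0 \
        + (math.sqrt(alpha_t) * (1.0 - abar_prev) / (1.0 - abar_t)) * x_t


def posterior_mean_from_eps(x_t, eps, beta_t, abar_t):
    """mu_tilde written in terms of x_t and the noise eps."""
    alpha_t = 1.0 - beta_t
    return (x_t - beta_t / math.sqrt(1.0 - abar_t) * eps) / math.sqrt(alpha_t)


beta_t, abar_prev = 0.02, 0.5           # made-up values for a single step
abar_t = abar_prev * (1.0 - beta_t)
x0, eps = 0.7, 1.3
x_t = math.sqrt(abar_t) * x0 + math.sqrt(1.0 - abar_t) * eps  # sample via q(x_t | x_0)
a = posterior_mean_from_x0(x_t, x0, beta_t, abar_t, abar_prev)
b = posterior_mean_from_eps(x_t, eps, beta_t, abar_t)
print(a, b)  # identical up to floating-point rounding
```

This is what makes the noise parameterization possible: the only unknown in the target mean is $\boldsymbol{\epsilon}$.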
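Going further, we also reparameterize $\boldsymbol{\mu}_\theta(\mathbf{x}_t, t)$ as:

$$\boldsymbol{\mu}_\theta(\mathbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\right)$$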
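Here $\boldsymbol{\epsilon}_\theta$ is a function approximator implemented by a neural network, which means we switch from predicting the mean to predicting the noise $\boldsymbol{\epsilon}$. Substituting this into the objective gives:

$$L_{t-1} = \mathbb{E}_{\mathbf{x}_0, \boldsymbol{\epsilon}}\left[\frac{\beta_t^2}{2\sigma_t^2\,\alpha_t(1-\bar{\alpha}_t)}\left\|\boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta\left(\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon}, t\right)\right\|^2\right]$$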
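The weight $\frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar{\alpha}_t)}$ in front of the squared error is far from uniform across steps. The pure-Python sketch below evaluates it for every $t$, assuming DDPM's linear schedule with $T = 1000$ and the choice $\sigma_t^2 = \beta_t$ (one of the two options DDPM considers).

```python
import math

T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
weights = []
abar = 1.0
for beta in betas:
    alpha = 1.0 - beta
    abar *= alpha
    sigma2 = beta  # assumption: sigma_t^2 = beta_t
    weights.append(beta ** 2 / (2.0 * sigma2 * alpha * (1.0 - abar)))
print(f"w_1 = {weights[0]:.4f}, w_500 = {weights[499]:.4f}, w_T = {weights[-1]:.4f}")
```

With this choice the weight is about 0.5 at $t = 1$ and about 0.01 at $t = T$, so the VLB puts far more emphasis on the low-noise steps.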
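DDPM simplifies this objective further by dropping the weighting coefficient:

$$L_\text{simple}(\theta) = \mathbb{E}_{t, \mathbf{x}_0, \boldsymbol{\epsilon}}\left[\left\|\boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta\left(\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon}, t\right)\right\|^2\right]$$

Here $t$ is sampled uniformly from $[1, T]$ (as noted above, $t = 1$ corresponds to $L_0$). Because the per-step weights are removed, this simplified objective is a reweighted version of the VLB objective. DDPM's ablation experiments show that predicting the noise works better than predicting the mean, and that the simplified objective works better than the VLB objective: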
Model design
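Having covered the principles and the optimization objective, the core of a diffusion model is training the noise-prediction model $\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)$. Since the noise has the same dimensionality as the original data, an autoencoder-style architecture is a natural choice for the noise predictor. DDPM uses a U-Net built from residual blocks and attention blocks, as shown below: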
Code implementation
Finally, we give a concrete DDPM implementation in PyTorch, based mainly on three code bases:
- GitHub - hojonathanho/diffusion: Denoising Diffusion Probabilistic Models (the official TensorFlow implementation)
- GitHub - openai/improved-diffusion: Release for Improved Denoising Diffusion Probabilistic Models (OpenAI's PyTorch implementation of DDPM+)
- GitHub - lucidrains/denoising-diffusion-pytorch: Implementation of Denoising Diffusion Probabilistic Model in Pytorch
First comes the time embedding. It uses the sinusoidal position embedding designed in Attention Is All You Need, except that here it encodes the timestep:
```python
# imports used by all the code below
import math
from abc import abstractmethod

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm


# use sinusoidal position embedding to encode time step (https://arxiv.org/abs/1706.03762)
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    # geometric sequence of frequencies from 1 down to about 1/max_period
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # zero-pad the last column when dim is odd
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
```
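To see what `timestep_embedding` produces without needing PyTorch, here is a pure-Python mirror for a single timestep. It is a sketch that reproduces the same formula; `timestep_embedding_py` and `distance` are hypothetical helpers, not part of the reference implementations.

```python
import math


def timestep_embedding_py(t, dim, max_period=10000):
    """Pure-Python mirror of timestep_embedding for one (possibly fractional) timestep t."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:
        emb.append(0.0)  # zero-pad odd dimensions, as the torch version does
    return emb


def distance(u, v):
    """Euclidean distance between two embeddings."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


near = distance(timestep_embedding_py(100, 128), timestep_embedding_py(101, 128))
far = distance(timestep_embedding_py(100, 128), timestep_embedding_py(600, 128))
print(timestep_embedding_py(0, 8))
print(f"distance(100, 101) = {near:.3f}, distance(100, 600) = {far:.3f}")
```

At $t = 0$ the cosine half is all ones and the sine half all zeros, and adjacent timesteps map to much closer vectors than distant ones.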
Since only the residual blocks take the time embedding, we define some helper modules that pass it along automatically:
```python
# define TimestepEmbedSequential to support `time_emb` as extra input
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            else:
                x = layer(x)
        return x
```
The U-Net here uses GroupNorm for normalization, so we also define a simple norm layer for convenience:
```python
# use GN for norm layer (32 groups, so channels must be a multiple of 32)
def norm_layer(channels):
    return nn.GroupNorm(32, channels)
```
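The core module of the U-Net is the residual block, which contains two convolutional layers and a shortcut, and also takes the time embedding. An extra linear layer projects the time embedding to the same dimension as the features, and time is encoded by adding this projection to the output of the first conv: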
```python
# Residual block
class ResidualBlock(TimestepBlock):
    def __init__(self, in_channels, out_channels, time_channels, dropout):
        super().__init__()
        self.conv1 = nn.Sequential(
            norm_layer(in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        )
        # projection for time step embedding
        self.time_emb = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_channels, out_channels)
        )
        self.conv2 = nn.Sequential(
            norm_layer(out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        )
        # 1x1 conv on the shortcut when the channel count changes
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x, t):
        """
        `x` has shape `[batch_size, in_dim, height, width]`
        `t` has shape `[batch_size, time_dim]`
        """
        h = self.conv1(x)
        # Add time step embeddings (broadcast over height and width)
        h += self.time_emb(t)[:, :, None, None]
        h = self.conv2(h)
        return h + self.shortcut(x)
```
Attention is also added after some of the residual blocks. It is the same as the self-attention in Transformers:
```python
# Attention block with shortcut
class AttentionBlock(nn.Module):
    def __init__(self, channels, num_heads=1):
        super().__init__()
        self.num_heads = num_heads
        assert channels % num_heads == 0
        self.norm = norm_layer(channels)
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1, bias=False)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        B, C, H, W = x.shape
        qkv = self.qkv(self.norm(x))
        # split heads and flatten space: q, k, v each have shape [B*heads, C/heads, H*W]
        q, k, v = qkv.reshape(B*self.num_heads, -1, H*W).chunk(3, dim=1)
        # scaling q and k by d^(-1/4) each equals scaling q.k by 1/sqrt(d)
        scale = 1. / math.sqrt(math.sqrt(C // self.num_heads))
        attn = torch.einsum("bct,bcs->bts", q * scale, k * scale)
        attn = attn.softmax(dim=-1)
        h = torch.einsum("bts,bcs->bct", attn, v)
        h = h.reshape(B, -1, H, W)
        h = self.proj(h)
        return h + x
```
Upsampling is implemented with interpolation (optionally followed by a conv), and downsampling with a stride-2 conv or pooling:
```python
# upsample: nearest-neighbor interpolation, optionally followed by a 3x3 conv
class Upsample(nn.Module):
    def __init__(self, channels, use_conv):
        super().__init__()
        self.use_conv = use_conv
        if use_conv:
            self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x


# downsample: stride-2 conv, or 2x2 average pooling
class Downsample(nn.Module):
    def __init__(self, channels, use_conv):
        super().__init__()
        self.use_conv = use_conv
        if use_conv:
            self.op = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)
        else:
            self.op = nn.AvgPool2d(kernel_size=2, stride=2)  # AvgPool2d requires kernel_size

    def forward(self, x):
        return self.op(x)
```
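Because every skip connection concatenates a down-path feature map with an upsampled one, each down-path resolution must be exactly twice the next one. The sketch below replays the spatial sizes produced by the stride-2 conv (`kernel_size=3, stride=2, padding=1`, i.e. `conv_resample=True`) using the standard `Conv2d` output-size formula. The helper names are hypothetical, and the configurations tried are only illustrations.

```python
def conv_out_size(size, kernel_size=3, stride=2, padding=1):
    """Spatial output size of nn.Conv2d: floor((size + 2*padding - kernel_size) / stride) + 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


def resolution_path(image_size, num_levels):
    """Resolutions visited on the down path (one Downsample between consecutive levels)."""
    sizes = [image_size]
    for _ in range(num_levels - 1):
        sizes.append(conv_out_size(sizes[-1]))
    return sizes


for image_size, channel_mult in [(32, (1, 2, 2, 2)), (28, (1, 2, 2, 2)), (28, (1, 2, 2))]:
    down = resolution_path(image_size, len(channel_mult))
    # Upsample doubles the size, so each skip connection needs 2 * smaller == larger
    ok = all(2 * small == big for big, small in zip(down, down[1:]))
    print(image_size, channel_mult, down, "skip shapes match" if ok else "mismatch")
```

For 28×28 MNIST images, four levels give 28→14→7→4, and upsampling 4 yields 8 rather than 7, so the concatenation would fail; three levels (28→14→7) work.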
With all the U-Net components implemented, we can assemble the full U-Net:
```python
# The full UNet model with attention and timestep embedding
class UNetModel(nn.Module):
    def __init__(
        self,
        in_channels=3,
        model_channels=128,
        out_channels=3,
        num_res_blocks=2,
        attention_resolutions=(8, 16),
        dropout=0,
        channel_mult=(1, 2, 2, 2),
        conv_resample=True,
        num_heads=4
    ):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_heads = num_heads

        # time embedding
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        # down blocks
        self.down_blocks = nn.ModuleList([
            TimestepEmbedSequential(nn.Conv2d(in_channels, model_channels, kernel_size=3, padding=1))
        ])
        down_block_chans = [model_channels]  # channels of every skip connection
        ch = model_channels
        ds = 1  # current downsampling rate, compared against attention_resolutions
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResidualBlock(ch, mult * model_channels, time_embed_dim, dropout)
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(AttentionBlock(ch, num_heads=num_heads))
                self.down_blocks.append(TimestepEmbedSequential(*layers))
                down_block_chans.append(ch)
            if level != len(channel_mult) - 1:  # don't use downsample for the last stage
                self.down_blocks.append(TimestepEmbedSequential(Downsample(ch, conv_resample)))
                down_block_chans.append(ch)
                ds *= 2

        # middle block
        self.middle_block = TimestepEmbedSequential(
            ResidualBlock(ch, ch, time_embed_dim, dropout),
            AttentionBlock(ch, num_heads=num_heads),
            ResidualBlock(ch, ch, time_embed_dim, dropout)
        )

        # up blocks: each consumes one skip connection, hence num_res_blocks + 1 per level
        self.up_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                layers = [
                    ResidualBlock(
                        ch + down_block_chans.pop(),
                        model_channels * mult,
                        time_embed_dim,
                        dropout
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(AttentionBlock(ch, num_heads=num_heads))
                if level and i == num_res_blocks:
                    layers.append(Upsample(ch, conv_resample))
                    ds //= 2
                self.up_blocks.append(TimestepEmbedSequential(*layers))

        self.out = nn.Sequential(
            norm_layer(ch),
            nn.SiLU(),
            nn.Conv2d(ch, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x, timesteps):
        """
        Apply the model to an input batch.
        :param x: an [N x C x H x W] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x C x ...] Tensor of outputs.
        """
        hs = []
        # time step embedding
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        # down stage: keep every output as a skip connection
        h = x
        for module in self.down_blocks:
            h = module(h, emb)
            hs.append(h)
        # middle stage
        h = self.middle_block(h, emb)
        # up stage: concatenate the matching skip connection along channels
        for module in self.up_blocks:
            cat_in = torch.cat([h, hs.pop()], dim=1)
            h = module(cat_in, emb)
        return self.out(h)
```
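The skip-connection bookkeeping in `UNetModel.__init__` is the easiest part to get wrong. The pure-Python sketch below replays it without building any layers, so you can see which input channel count each up-path residual block receives. `unet_channel_plan` is a hypothetical helper that mirrors the constructor's loops for the default configuration.

```python
def unet_channel_plan(model_channels=128, channel_mult=(1, 2, 2, 2), num_res_blocks=2):
    """Replay UNetModel.__init__'s channel bookkeeping without building layers."""
    skip_chans = [model_channels]  # output of the first conv
    ch = model_channels
    for level, mult in enumerate(channel_mult):
        for _ in range(num_res_blocks):
            ch = mult * model_channels
            skip_chans.append(ch)  # output of each residual (+ attention) block
        if level != len(channel_mult) - 1:
            skip_chans.append(ch)  # output of Downsample
    num_down = len(skip_chans)
    up_inputs = []
    for level, mult in list(enumerate(channel_mult))[::-1]:
        for _ in range(num_res_blocks + 1):
            up_inputs.append(ch + skip_chans.pop())  # channels of cat([h, skip])
            ch = model_channels * mult
    return num_down, up_inputs, ch


num_down, up_inputs, final_ch = unet_channel_plan()
print(num_down, up_inputs, final_ch)
```

The number of up blocks equals the number of stored skip connections, and the last block ends with `model_channels` channels, which is what the output layer receives.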
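For the diffusion process, the main parameters are the number of timesteps and the noise schedule. DDPM uses a linear noise schedule over the range [0.0001, 0.02], with 1000 diffusion steps by default. The function below rescales both endpoints by 1000/timesteps when a different number of steps is used.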
```python
# beta schedule
def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
```
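Why rescale by 1000/timesteps? With a linear schedule, $\log\bar{\alpha}_T \approx -\sum_t \beta_t = -T\,\frac{\beta_\text{start}+\beta_\text{end}}{2}$, and scaling both endpoints by $1000/T$ keeps that sum fixed. The pure-Python mirror below (a sketch; `linear_betas_scaled` is a hypothetical name) checks that $\bar{\alpha}_T$ stays tiny for several chain lengths.

```python
import math


def linear_betas_scaled(timesteps):
    """Pure-Python mirror of linear_beta_schedule (same endpoints and scaling)."""
    scale = 1000 / timesteps
    start, end = scale * 0.0001, scale * 0.02
    return [start + (end - start) * i / (timesteps - 1) for i in range(timesteps)]


abar_by_T = {}
for T in (1000, 500, 250):
    betas = linear_betas_scaled(T)
    abar_by_T[T] = math.prod(1.0 - b for b in betas)
    print(f"T={T}: beta_1={betas[0]:.4g}, beta_T={betas[-1]:.4g}, abar_T={abar_by_T[T]:.2e}")
```

Shorter chains take larger steps, but all of them still end close to pure noise.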
Next we define a diffusion class. It precomputes a set of coefficients from the chosen noise schedule and implements the diffusion and generation processes:
```python
# cosine schedule (Improved DDPM): alpha_bar(t) = f(t) / f(0), f(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2)
def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)  # clip to avoid singularities near t = T


class GaussianDiffusion:
    def __init__(
        self,
        timesteps=1000,
        beta_schedule='linear'
    ):
        self.timesteps = timesteps

        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')
        self.betas = betas

        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # below: log calculation clipped because the posterior variance is 0 at the beginning
        # of the diffusion chain
        self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(min=1e-20))
        self.posterior_mean_coef1 = (
            self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * torch.sqrt(self.alphas)
            / (1.0 - self.alphas_cumprod)
        )

    # get the param of given timestep t, reshaped to broadcast over x
    def _extract(self, a, t, x_shape):
        batch_size = t.shape[0]
        out = a.to(t.device).gather(0, t).float()
        out = out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
        return out

    # forward diffusion (using the nice property): q(x_t | x_0)
    def q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)

        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    # Get the mean and variance of q(x_t | x_0).
    def q_mean_variance(self, x_start, t):
        mean = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    # Compute the mean and variance of the diffusion posterior: q(x_{t-1} | x_t, x_0)
    def q_posterior_mean_variance(self, x_start, x_t, t):
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    # compute x_0 from x_t and pred noise: the reverse of `q_sample`
    def predict_start_from_noise(self, x_t, t, noise):
        return (
            self._extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    # compute predicted mean and variance of p(x_{t-1} | x_t)
    def p_mean_variance(self, model, x_t, t, clip_denoised=True):
        # predict noise using model
        pred_noise = model(x_t, t)
        # get the predicted x_0: different from the algorithm2 in the paper
        x_recon = self.predict_start_from_noise(x_t, t, pred_noise)
        if clip_denoised:
            x_recon = torch.clamp(x_recon, min=-1., max=1.)
        model_mean, posterior_variance, posterior_log_variance = \
            self.q_posterior_mean_variance(x_recon, x_t, t)
        return model_mean, posterior_variance, posterior_log_variance

    # denoise_step: sample x_{t-1} from x_t and pred_noise
    @torch.no_grad()
    def p_sample(self, model, x_t, t, clip_denoised=True):
        # predict mean and variance
        model_mean, _, model_log_variance = self.p_mean_variance(model, x_t, t,
                                                                 clip_denoised=clip_denoised)
        noise = torch.randn_like(x_t)
        # no noise when t == 0
        nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1))))
        # compute x_{t-1}
        pred_img = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return pred_img

    # denoise: reverse diffusion
    @torch.no_grad()
    def p_sample_loop(self, model, shape):
        batch_size = shape[0]
        device = next(model.parameters()).device
        # start from pure noise (for each example in the batch)
        img = torch.randn(shape, device=device)
        imgs = []
        for i in tqdm(reversed(range(0, self.timesteps)), desc='sampling loop time step', total=self.timesteps):
            img = self.p_sample(model, img, torch.full((batch_size,), i, device=device, dtype=torch.long))
            imgs.append(img.cpu().numpy())
        return imgs

    # sample new images
    @torch.no_grad()
    def sample(self, model, image_size, batch_size=8, channels=3):
        return self.p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))

    # compute train losses (L_simple: MSE between true and predicted noise)
    def train_losses(self, model, x_start, t):
        # generate random noise
        noise = torch.randn_like(x_start)
        # get x_t
        x_noisy = self.q_sample(x_start, t, noise=noise)
        predicted_noise = model(x_noisy, t)
        loss = F.mse_loss(noise, predicted_noise)
        return loss
```
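The main functions are summarized as follows:
- `q_sample`: implements the diffusion from $\mathbf{x}_0$ to $\mathbf{x}_t$, i.e. sampling from $q(\mathbf{x}_t \mid \mathbf{x}_0)$;
- `q_posterior_mean_variance`: computes the mean and variance of the posterior $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$;
- `predict_start_from_noise`: the inverse of `q_sample`, recovering $\mathbf{x}_0$ from $\mathbf{x}_t$ and the predicted noise;
- `p_mean_variance`: computes the mean and variance of $p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ from the predicted noise;
- `p_sample`: a single denoising step;
- `p_sample_loop`: the whole denoising process, i.e. the generation process.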
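Two properties of these functions can be checked without PyTorch: `predict_start_from_noise` exactly inverts `q_sample` when given the true noise, and the posterior variance at the first step is exactly 0 (because $\bar{\alpha}_0$ is padded to 1), which is why the code clamps it before taking the log. The scalar mirrors below are a sketch under the linear schedule with $T = 1000$ and 0-based `t` as in the code; the values of `x0` and `noise` are arbitrary.

```python
import math

T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alphas_cumprod, running = [], 1.0
for b in betas:
    running *= 1.0 - b
    alphas_cumprod.append(running)
alphas_cumprod_prev = [1.0] + alphas_cumprod[:-1]  # same shift as F.pad(..., (1, 0), value=1.)


def q_sample(x0, t, noise):
    """x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise."""
    return math.sqrt(alphas_cumprod[t]) * x0 + math.sqrt(1.0 - alphas_cumprod[t]) * noise


def predict_start_from_noise(x_t, t, noise):
    """x_0 = sqrt(1 / abar_t) * x_t - sqrt(1 / abar_t - 1) * noise."""
    return math.sqrt(1.0 / alphas_cumprod[t]) * x_t - math.sqrt(1.0 / alphas_cumprod[t] - 1.0) * noise


x0, noise = 0.42, -1.1
recovered = [predict_start_from_noise(q_sample(x0, t, noise), t, noise) for t in (0, 10, 500, 999)]
posterior_var_0 = betas[0] * (1.0 - alphas_cumprod_prev[0]) / (1.0 - alphas_cumprod[0])
print(recovered, posterior_var_0)
```

During sampling the network's predicted noise replaces the true noise, so the recovered $\mathbf{x}_0$ is only an estimate; this is why `p_mean_variance` optionally clips it to $[-1, 1]$.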
Training a diffusion model is very simple:
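```python
# train
# setup for the MNIST demo; these hyperparameters are illustrative assumptions, not tuned values
from torchvision import datasets, transforms

device = "cuda" if torch.cuda.is_available() else "cpu"
timesteps = 500
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])  # pixels -> [-1, 1]
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST("./data", train=True, download=True, transform=transform), batch_size=64, shuffle=True)
# channel_mult=(1, 2, 2) so that 28x28 inputs downsample cleanly to 14x14 and 7x7
model = UNetModel(in_channels=1, model_channels=96, out_channels=1,
                  channel_mult=(1, 2, 2), attention_resolutions=[]).to(device)
gaussian_diffusion = GaussianDiffusion(timesteps=timesteps)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

epochs = 10
for epoch in range(epochs):
    for step, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        batch_size = images.shape[0]
        images = images.to(device)
        # sample t uniformly for every example in the batch
        t = torch.randint(0, timesteps, (batch_size,), device=device).long()
        loss = gaussian_diffusion.train_losses(model, images, t)
        if step % 200 == 0:
            print("Loss:", loss.item())
        loss.backward()
        optimizer.step()
```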
Here we built a simple demo on the MNIST dataset. Some generated samples are shown below:
Summary
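Compared with VAEs and GANs, the theory behind diffusion models is more involved, yet the optimization objective and the implementation are not complicated. It is remarkable that a long chain of complex mathematical derivations ends in such a simple result. DDPM is only the starting point for understanding diffusion models in depth; many improvements followed, such as DDIM, which accelerates sampling, and the improved versions of DDPM, DDPM+ and DDPM++. Note: my expertise is limited, so if you spot any mistakes, discussion is welcome.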
References
- Denoising Diffusion Probabilistic Models
- Understanding Diffusion Models: A Unified Perspective
- https://spaces.ac.cn/archives/9119
- https://keras.io/examples/generative/ddim/
- What are Diffusion Models?
- https://cvpr2022-tutorial-diffusion-models.github.io/
- https://github.com/openai/improved-diffusion
- https://huggingface.co/blog/annotated-diffusion
- https://github.com/lucidrains/denoising-diffusion-pytorch
- https://github.com/hojonathanho/diffusion