查看原文
其他

VQGAN 论文与源码解读:前Diffusion时代的高清图像生成模型

周弈帆 天才程序员周弈帆
2024-11-23

2022年中旬,以扩散模型为核心的图像生成模型将AI绘画带入了大众的视野。实际上,在更早的一年之前,就有了一个能根据文字生成高清图片的模型——VQGAN。VQGAN不仅本身具有强大的图像生成能力,更是传承了前作VQVAE把图像压缩成离散编码的思想,推广了「先压缩,再生成」的两阶段图像生成思路,启发了无数后续工作。

VQGAN生成出的高清图片

在这篇文章中,我将对VQGAN的论文和源码中的关键部分做出解读,提炼出VQGAN中的关键知识点。由于VQGAN的核心思想和VQVAE如出一辙,我不会过多地介绍VQGAN的核心思想,强烈建议读者先去学懂VQVAE,再来看VQGAN。

VQGAN 核心思想

VQGAN的论文名为Taming Transformers for High-Resolution Image Synthesis,直译过来是「驯服Transformer模型以实现高清图像合成」。可以看出,该方法是在用Transformer生成图像。可是,为什么这个模型叫做VQGAN,是一个GAN呢?这是因为,VQGAN使用了两阶段的图像生成方法:

  • 训练时,先训练一个图像压缩模型(包括编码器和解码器两个子模型),再训练一个生成压缩图像的模型。
  • 生成时,先用第二个模型生成出一个压缩图像,再用第一个模型复原成真实图像。

其中,第一个图像压缩模型叫做VQGAN,第二个压缩图像生成模型是一个基于Transformer的模型。

为什么会有这种乍看起来非常麻烦的图像生成方法呢?要理解VQGAN的这种设计动机,有两条路线可以走。两条路线看待问题的角度不同,但实际上是在讲同一件事。

第一条路线是从Transformer入手。Transformer已经在文本生成领域大展身手。同时,Transformer也在视觉任务中开始崭露头角。相比擅长捕捉局部特征的CNN,Transformer的优势在于它能更好地融合图像的全局信息。可是,Transformer的自注意力操作开销太大,只能生成一些分辨率较低的图像。因此,作者认为,可以综合CNN和Transformer的优势,先用基于CNN的VQGAN把图像压缩成一个尺寸更小、信息更丰富的小图像,再用Transformer来生成小图像。

第二条路线是从VQVAE入手。VQVAE是VQGAN的前作,它有着和VQGAN一模一样两阶段图像生成方法。不同的是,VQVAE没有使用GAN结构,且其配套的压缩图像生成模型是基于CNN的。为提升VQVAE的生成效果,作者提出了两项改进策略:1) 图像压缩模型VQVAE仅使用了均方误差,压缩图像的复原结果较为模糊,可以把图像压缩模型换成GAN;2) 在生成压缩图片这个任务上,基于CNN的图像生成模型比不过Transformer,可以用Transformer代替原来的CNN。

第一条思路是作者在论文的引言中描述的,听起来比较高大上;而第二条思路是读者读过文章后能够自然总结出来的,相对来说比较清晰易懂。如果你已经理解了VQVAE,你能通过第二条思路瞬间弄懂VQGAN的原理。说难听点,VQGAN就是一个改进版的VQVAE。然而,VQGAN的改进非常有效,且使用了若干技巧来实现带约束(比如根据文字描述)的高清图像生成,有非常多地方值得学习。

在下文中,我将先补充VQVAE的背景以方便讨论,再介绍VQGAN论文的四大知识点:VQGAN的设计细节、生成压缩图像的Transformer的设计细节、带约束图像生成的实现方法、高清图像生成的实现方法。

VQVAE 背景知识补充

VQVAE的学习目标是用一个编码器把图像压缩成离散编码,再用一个解码器把图像尽可能地还原回原图像。

通俗来说,VQVAE就是把一幅真实图像压缩成一个小图像。这个小图像和真实图像有着一些相同的性质:小图像的取值和像素值(0-255的整数)一样,都是离散的;小图像依然是二维的,保留了某些空间信息。因此,VQVAE的示意图画成这样会更形象一些:

但小图像和真实图像有一个关键性的区别:与像素值不同,小图像的离散取值之间没有关联。真实图像的像素值其实是一个连续颜色的离散采样,相邻的颜色值也更加相似。比如颜色254和颜色253和颜色255比较相似。而小图像的取值之间是没有关联的,你不能说编码为1与编码为0和编码为2比较相似。由于神经网络不能很好地处理这种离散量,在实际实现中,编码并不是以整数表示的,而是以类似于NLP中的嵌入向量的形式表示的。VAE使用了嵌入空间(又称codebook)来完成整数序号到向量的转换。

为了让任意一个编码器输出向量都变成一个固定的嵌入向量,VQVAE采取了一种离散化策略:把每个输出向量替换成嵌入空间中最近的那个向量的离散编码就是在嵌入空间的下标。这个过程和把254.9的输出颜色值离散化成255的整数颜色值的原理类似。

VQVAE的损失函数由两部分组成:重建误差和嵌入空间误差。

其中,重建误差就是输入和输出之间的均方误差。

嵌入空间误差为解码器输出向量和它在嵌入空间对应向量的均方误差。

作者在误差中还使用了一种「停止梯度」的技巧。这个技巧在VQGAN中被完全保留,此处就不过多介绍了。

图像压缩模型 VQGAN

回顾了VQVAE的背景知识后,我们来正式认识VQGAN的几个创新点。第一点,图像压缩模型VQVAE被改进成了VQGAN。

一般VAE重建出来出来的图像都会比较模糊。这是因为VAE只使用了均方误差,而均方误差只能保证像素值尽可能接近,却不能保证图像的感知效果更加接近。为此,作者把GAN的一些方法引入VQVAE,改造出了VQGAN。

具体来说,VQGAN有两项改进。第一,作者用感知误差(perceptual loss)代替原来的均方误差作为VQGAN的重建误差。第二,作者引入了GAN的对抗训练机制,加入了一个基于图块的判别器,把GAN误差加入了总误差。

计算感知误差的方法如下:把两幅图像分别输入VGG,取出中间某几层卷积层的特征,计算特征图像之间的均方误差。如果你之前没学过相关知识,请搜索"perceptual loss"。

基于图块的判别器,即判别器不为整幅图输出一个真或假的判断结果,而是把图像拆成若干图块,分别输出每个图块的判断结果,再对所有图块的判断结果取一个均值。这只是GAN的一种改进策略而已,没有对GAN本身做太大的改动。如果你之前没学过相关知识,请搜索"PatchGAN"。

这样,总的误差可以写成:

其中,是控制两种误差比例的权重。作者在论文中使用了一个公式来自适应地设置。和普通的GAN一样,VQGAN的编码器、解码器(即生成器)、codebook会最小化误差,判别器会最大化误差。

用VQGAN代替VQVAE后,重建图片中的模糊纹理清晰了很多。

有了一个保真度高的图像压缩模型,我们可以进入下一步,训练一个生成压缩图像的模型。

基于 Transformer 的压缩图像生成模型

如前所述,经VQGAN得到的压缩图像与真实图像有一个本质性的不同:真实图像的像素值具有连续性,相邻的颜色更加相似,而压缩图像的像素值则没有这种连续性。压缩图像的这一特性让寻找一个压缩图像生成模型变得异常困难。多数强大的真实图像生成模型(比如GAN)都是输出一个连续的浮点颜色值,再做一个浮点转整数的操作,得到最终的像素值。而对于压缩图像来说,这种输出连续颜色的模型都不适用了。因此,之前的VQVAE使用了一个能建模离散颜色的PixelCNN模型作为压缩图像生成模型。但PixelCNN的表现不够优秀。

恰好,功能强大的Transformer天生就支持建模离散的输出。在NLP中,每个单词都可以用一个离散的数字表示。Transformer会不断生成表示单词的数字,以达到生成句子的效果。

Transformer 随机生成句子的过程

为了让Transformer生成图像,我们可以把生成句子的一个个单词,变成生成压缩图像的一个个像素。但是,要让Transformer生成二维图像,还需要克服一个问题:在生成句子时,Transformer会先生成第一个单词,再根据第一个单词生成第二个单词,再根据第一、第二个单词生成第三个单词……。也就是说,Transformer每次会根据之前所有的单词来生成下一单词。而图像是二维数据,没有先后的概念,怎样让像素和文字一样有先后顺序呢?

VQGAN的作者使用了自回归图像生成模型的常用做法,给图像的每个像素从左到右,从上到下规定一个顺序。有了先后顺序后,图像就可以被视为一个一维句子,可以用Transfomer生成句子的方式来生成图像了。在第步,Transformer会根据前个像素生成第个像素

带约束的图像生成

在生成新图像时,我们更希望模型能够根据我们的需求生成图像。比如,我们希望模型生成「一幅优美的风景画」,又或者希望模型在一幅草图的基础上作画。这些需求就是模型的约束。为了实现带约束的图像生成,一般的做法是先有一个无约束(输入是随机数)的图像生成模型,再在这个模型的基础上把一个表示约束的向量插入进图像生成的某一步。

把约束向量插入进模型的方法是需要设计的,插入约束向量的方法往往和模型架构有着密切关系。比如假设一个生成模型是U-Net架构,我们可以把约束向量和当前特征图拼接在一起,输入进U-Net的每一大层。

为了实现带约束的图像生成,VQGAN的作者再次借鉴了Transformer实现带约束文字生成的方法。许多自然语言处理任务都可以看成是带约束的文字生成。比如机器翻译,其实可以看成在给定一种语言的句子的前提下,让模型「随机」生成一个另一种语言的句子。比如要把「简要访问非洲」翻译成英语,我们可以对之前无约束文字生成的Transformer做一些修改。

也就是说,给定约束的句子,在第步,Transformer会根据前个输出单词以及生成第个单词。表示约束的单词被添加到了所有输出之前,作为这次「随机生成」的额外输入。

上述方法并不是唯一的文字生成方法。这种文字生成方法被称为"decoder-only"。实际上,也有使用一个编码器来额外维护约束信息的文字生成方法。最早的Transformer就用到了带编码器的方法。

我们同样可以把这种思想搬到压缩图像生成里。比如对于MNIST数据集,我们希望模型只生成0~9这些数字中某一个数字的手写图像。也就是说,约束是类别信息,约束的取值是0~9。我们就可以把这个0~9的约束信息添加到Transformer的输入之前,以实现由类别约束的图像生成。

但这种设计又会产生一个新的问题。假设约束条件不能简单地表示成整数,而是一些其他类型的数据,比如语义分割图像,那该怎么办呢?对于这种以图像形式表示的约束,作者的做法是,再训练另一个VQGAN,把约束图像压缩成另一套压缩图片。这一套压缩图片和生成图像的压缩图片有着不同的codebook,就像两种语言有着不同的单词一样。这样,约束图像也变成了一系列的整数,可以用之前的方法进行带约束图像生成了。

生成高清图像

由于Transformer注意力计算的开销很大,作者在所有配置中都只使用了的压缩图像,再增大压缩图像尺寸的话计算资源就不够了。而另一方面,每张图像在VQGAN中的压缩比例是有限的。如果图像压缩得过多,则VQGAN的重建质量就不够好了。因此,设边长压缩了倍,则该方法一次能生成的图片的最大尺寸是。在多项实验中,的表现都较好。这样算下来,该方法一次只能生成的图片。这种尺寸的图片还称不上高清图片。

为了生成更大尺寸的图片,作者先训练好了一套能生成的图片的VQGAN+Transformer,再用了一种基于滑动窗口的采样机制来生成大图片。具体来说,作者把待生成图片划分成若干个像素的图块,每个图块对应压缩图像的一个像素。之后,在每一轮生成时,只有待生成图块周围的个图块(个像素)会被输入进VQGAN和Transformer,由Transformer生成一个新的压缩图像像素,再把该压缩图像像素解码成图块。(在下面的示意图中,每个方块是一个图块,transformer的输入是个图块)

这个滑动窗口算法不是那么好理解,需要多想一下才能理解它的具体做法。在理解这个算法时,你可能会有这样的问题:上面的示意图中,待生成像素有的时候在最左边,有的时候在中间,有的时候在右边,每次约束它的像素都不一样。这么复杂的约束逻辑怎么编写?其实,Transformer自动保证了每个像素只会由之前的像素约束,而看不到后面的像素。因此,在实现时,只需要把待生成像素框起来,直接用Transformer预测待生成像素即可,不需要编写额外的约束逻辑。

如果你没有学过Transformer的话,理解这部分会有点困难。Transformer可以根据第1~k-1个像素并行地生成第2~k个像素,且保证生成每个像素时不会偷看到后面像素的信息。因此,假设我们要生成第i个像素,其实是预测了所有第2~k个像素的结果,再取出第i个结果,填回待生成图像。

由于论文篇幅有限,作者没有对滑动窗口机制做过多的介绍,也没有讲带约束的滑动窗口是怎么实现的。如果你在理解这一部分时碰到了问题,不用担心,这很正常。稍后我们会在代码阅读章节彻底理解滑动窗口的实现方法。我也是看了代码才看懂此处的做法。

作者在论文中解释了为什么用滑动窗口生成高清图像是合理的。作者先是讨论了两种情况,只要满足这两种情况中的任意一种,拿滑动窗口生成图像就是合理的。第一种情况是数据集的统计规律是几乎空间不变,也就是说训练集图片每个像素的统计规律是类似的。这和我们拿卷积卷图像是因为图像每个像素的统计规律类似的原理是一样的。第二种情况是有空间上的约束信息。比如之前提到的用语义分割图来指导图像生成。由于语义分割也是一张图片,它给每个待生成像素都提供了额外信息。这样,哪怕是用滑动窗口,在局部语义的指导下,模型也足以生成图像了。

若是两种情况都不满足呢?比如在对齐的人脸数据集上做无约束生成。在对齐的人脸数据集里,每张图片中人的五官所在的坐标是差不多的,图片的空间不变性不满足;做无约束生成,自然也没有额外的空间信息。在这种情况下,我们可以人为地添加一个坐标约束,即从左到右、从上到下地给每个像素标一个序号,把每个滑动窗口里的坐标序号做为约束。有了坐标约束后,就还原成了上面的第二种情况,每个像素有了额外的空间信息,基于滑动窗口的方法依然可行。

学完了论文的四大知识点,我们知道VQGAN是怎么根据约束生成高清图像的了。接下来,我们来看看论文的实验部分,看看作者是怎么证明方法的有效性的。

实验

在实验部分,作者先是分别验证了基于Transformer的压缩图像生成模型较以往模型的优越性(4.1节)、VQGAN较以往模型的优越性(4.4节末尾)、使用VQGAN做图像压缩的必要性及相关消融实验(4.3节),再把整个生成方法综合起来,在多项图像生成任务上与以往的图像生成模型做定量对比(4.4节),最后展示了该方法惊艳的带约束生成效果(4.2节)。

在论文4.1节中,作者验证了基于Transformer的压缩图像生成模型的有效性。之前,压缩图像都是使用能输出离散分布的PixelCNN系列模型来生成的。PixelCNN系列的最强模型是PixelSNAIL。为确保公平,作者对比了相同训练时间、相同训练步数下两个网络在不同训练集下的负对数似然(NLL)指标。结果表明,基于Transformer的模型确实训练得更快。

对于直接能建模离散分布的模型来说,NLL就是交叉熵损失函数。

在论文4.4节末尾,作者将VQGAN和之前的图像压缩模型对比,验证了引入感知误差和GAN结构的有效性。作者汇报了各模型重建图像集与原数据集(ImageNet的训练集和验证集)的FID(指标FID是越低越好)。同时,结果也说明,增大codebook的尺寸或者编码种类都能提升重建效果。

在论文4.3节中,作者验证了使用VQGAN的必要性。作者训了两个模型,一个直接让Transformer做真实图像生成,一个用VQGAN把图像边长压缩2倍,再用Transformer生成压缩图像。经比较,使用了VQGAN后,图像生成速度快了10多倍,且图像生成效果也有所提升。

另外,作者还做了有关图像边长压缩比例的消融实验。作者固定让Transformer生成的压缩图片,即每次训练时用到的图像尺寸都是。之后,作者训练训练了不同下的模型,用各个模型来生成图片。结果显示时效果最好。这是因为,在固定Transformer的生成分辨率的前提下,越小,Transformer的感受野越小。如果Transformer的感受野过小,就学习不到足够的信息。

在论文4.4节中,作者探究了VQGAN+Transformer在多项基准测试(benchmark)上的结果。

首先是语义图像合成(根据语义分割图像来生成)任务。本文的这套方法还不错。

接着是人脸生成任务。这套方法表现还行,但还是比不过专精于某一任务的GAN。

作者还比较了各模型在ImageNet上的生成结果。这一比较的数据量较多,欢迎大家自行阅读原论文。

在论文4.2节中,作者展示了多才多艺的VQGAN+Transformer在各种约束下的图像生成结果。这些图像都是按照默认配置生成的,大小为

作者还展示了使用了滑动窗口算法后,模型生成的不同分辨率的图像。

本文开头的那张高清图片也来自论文。

总结

VQGAN是一个改进版的VQVAE,它将感知误差和GAN引入了图像压缩模型,把压缩图像生成模型替换成了更强大的Transformer。相比纯种的GAN(如StyleGAN),VQGAN的强大之处在于它支持带约束的高清图像生成。VQGAN借助NLP中"decoder-only"策略实现了带约束图像生成,并使用滑动窗口机制实现了高清图像生成。虽然在某些特定任务上VQGAN还是落后于其他GAN,但VQGAN的泛化性和灵活性都要比纯种GAN要强。它的这些潜力直接促成了Stable Diffusion的诞生。

如果你是读完了VQVAE再来读的VQGAN,为了完全理解VQGAN,你只需要掌握本文提到的4个知识点:VQVAE到VQGAN的改进方法、使用Transformer做图像生成的方法、使用"decoder-only"策略做带约束图像生成的方法、用滑动滑动窗口生成任意尺寸的图片的思想。

代码阅读

在代码阅读章节中,我将先简略介绍官方源码的项目结构以方便大家学习,再介绍代码中的几处核心代码。具体来说,我会介绍模型是如何组织配置文件的、模型的定义代码在哪、训练代码在哪、采样代码在哪,同时我会主要分析VQGAN的结构、Transformer的结构、损失函数、滑动窗口采样算法这几部分的代码。

官方源码地址:https://github.com/CompVis/taming-transformers。

官方的Git仓库里有很多很大的图片,且git记录里还藏了一些很大的数据,整个Git仓库非常大。如果你的网络不好,建议以zip形式下载仓库,或者只把代码部分下载下来。

项目结构

├─assets
├─configs
├─scripts
└─taming
├─data
│ └─conditional_builder
├─models
└─modules
├─diffusionmodules
├─discriminator
├─losses
├─misc
├─transformer
└─vqvae

configs目录下存放的是模型配置文件。VQGAN和Transformer的模型配置是分开来放的。每个模型配置文件都会指向一个Python模型类,比如taming.models.vqgan.VQModel,配置里的参数就是模型类的初始化参数。我们可用通过阅读配置文件找到模型的定义位置。

运行脚本包括根目录下的main.pyscripts文件夹下的脚本。main.py是用于训练的。scripts文件夹下有各种采样脚本和数据集可视化脚本。

taming是源代码的主目录。其data子文件夹下放置了各数据集的预处理代码,models放置了VQGAN和Transformer PyTorch模型的定义代码,modules则放置了模型中用到的模块,主要包括VQGAN编码解码模块(diffusionmodules)、判别器模块(discriminator)、误差模块(losses)、Transformer模块(transformer)、codebook模块(vqvae)。

VQGAN 模型结构

打开configs\faceshq_vqgan.yaml,我们能够找到高清人脸生成任务使用的VQGAN模型配置。我们来学习一下这个模型的定义方法。

model:
  base_learning_rate: 4.5e-6
  target: taming.models.vqgan.VQModel
  params:
    embed_dim: 256
    n_embed: 1024
    ddconfig:
      ...

    lossconfig:
      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        ...

从配置文件的target字段中,我们知道VQGAN定义在模块taming.models.vqgan.VQModel中。我们可以打开taming\models\vqgan.py这个文件,查看其中VQModel类的代码。

首先先看一下初始化函数。初始化函数主要是初始化了encoderdecoderlossquantize这几个模块,我们可以从文件开头的import语句中找到这几个模块的定义位置。不过,先不急,我们来继续看一下模型的前向传播函数。

from taming.modules.diffusionmodules.model import Encoder, Decoder
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from taming.modules.vqvae.quantize import GumbelQuantize
from taming.modules.vqvae.quantize import EMAVectorQuantizer

class VQModel(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 ):

        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap, sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.image_key = image_key
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 11))
        if monitor is not None:
            self.monitor = monitor

模型的前向传播逻辑非常清晰。self.encoder可以把一张图片变为特征,self.decoder可以把特征变回图片。self.quant_convpost_quant_conv则分别完成了编码器到codebook、codebook到解码器的通道数转换。self.quantize实现了VQVAE和VQGAN中那个找codebook里的最近邻、替换成最近邻的操作。

def encode(self, x):
    h = self.encoder(x)
    h = self.quant_conv(h)
    quant, emb_loss, info = self.quantize(h)
    return quant, emb_loss, info

def decode(self, quant):
    quant = self.post_quant_conv(quant)
    dec = self.decoder(quant)
    return dec

def forward(self, input):
    quant, diff, _ = self.encode(input)
    dec = self.decode(quant)
    return dec, diff

接下来,我们再看一看VQGAN的各个模块的定义。编码器和解码器的定义都可以在taming\modules\diffusionmodules\model.py里找到。VQGAN使用的编码器和解码器基于DDPM论文中的U-Net架构(而此架构又可以追溯到PixelCNN++的模型架构)。相比于最经典的U-Net,此U-Net每一层由若干个残差块和若干个自注意力块构成。为了把这个U-Net用到VQGAN里,U-Net的下采样部分和上采样部分被拆开,分别做成了VQGAN的编码器和解码器。

此处代码过长,我就只贴出部分关键代码了。以下是编码器的__init__函数和forward函数的关键代码。self.down存储了U-Net各层的模块。对于第i层,down[i].block是所有残差块,down[i].attn是所有自注意力块,down[i].downsample是下采样操作。它们在forward里会被依次调用。解码器的结构与之类似,只不过下采样变成了上采样。

class Encoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, **ignore_kwargs):

        super().__init__()
        ...
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

       ...


    def forward(self, x):
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))
        ...

        return h

之后,我们再看看离散化层的代码,即把编码器的输出变成codebook里的嵌入的实现代码。作者在taming\modules\vqvae\quantize.py中提供了VQVAE原版的离散化操作以及若干个改进过的离散化操作。我们就来看一下原版的离散化模块VectorQuantizer是怎么实现的。

离散化模块的初始化非常简洁,主要是初始化了一个嵌入层。

class VectorQuantizer(nn.Module):
    def __init__(self, n_e, e_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

在前向传播时,作者先是算出了编码器输出z和所有嵌入的距离d,再用argmin算出了最近邻嵌入的下标min_encodings,最后根据下标取出解码器输入z_q。同时,该函数还计算了其他几个可能用到的量,比如和codebook有关的误差 loss。注意,在计算lossz_q时,作者都使用到了停止梯度算子(.detach())。

    def forward(self, z):
        z = z.permute(0231).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
            torch.matmul(z_flattened, self.embedding.weight.t())

        ## could possible replace this here
        # #\start...
        # find closest encodings
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)

        min_encodings = torch.zeros(
            min_encoding_indices.shape[0], self.n_e).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        #.........\end


        # compute loss for embedding
        loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
            torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # perplexity
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        # reshape back to match original input shape
        z_q = z_q.permute(0312).contiguous()

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

VQGAN的三个主要模块已经看完了。最后,我们来看一下误差的定义。误差的定义在taming\modules\losses\vqperceptual.pyVQLPIPSWithDiscriminator类里。误差类名里的LPIPS(Learned Perceptual Image Patch Similarity,学习感知图像块相似度)就是感知误差的全称,"WithDiscriminator"表示误差是带了判定器误差的。我们来把这两类误差分别看一下。

说实话,这个误差模块乱得一塌糊涂,一边自己在算误差,一边又维护了codebook误差和重建误差的权重,最后会把自己维护的两个误差和其他误差合在一起输出。功能全部耦合在一起。我们就跳过这个类的实现细节,主要关注self.perceptual_lossself.discriminator是怎么调用其他模块的。

from taming.modules.losses.lpips import LPIPS
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init

class VQLPIPSWithDiscriminator(nn.Module):
    def __init__(self, ...):
        super().__init__()

        self.perceptual_loss = LPIPS().eval()

        self.discriminator = NLayerDiscriminator...

感知误差模块在taming\modules\losses\vqperceptual.py文件里。这个文件来自GitHub项目 PerceptualSimilarity。

感知误差可以简单地理解为两张图片在VGG中几个卷积层输出的误差的加权和。加权的权重是可以学习的。作者使用的是已经学习好的感知误差。感知误差的初始化函数如下。其中,self.lin0等模块就是算权重的模块,self.net是VGG。

class LPIPS(nn.Module):
    # Learned perceptual metric
    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64128256512512]  # vg16 features
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False

在算误差时,先是把图像inputtarget都输入进VGG,获取各层输出outs0, outs1,再求出两个图像的输出的均方误差diffs,最后用lins给各层误差加权,求和。

def forward(self, input, target):
    in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
    outs0, outs1 = self.net(in0_input), self.net(in1_input)
    feats0, feats1, diffs = {}, {}, {}
    lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
    for kk in range(len(self.chns)):
        feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
        diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

    res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=Truefor kk in range(len(self.chns))]
    val = res[0]
    for l in range(1, len(self.chns)):
        val += res[l]
    return val

GAN的判别器写在taming\modules\discriminator\model.py文件里。这个文件来自GitHub上的 pytorch-CycleGAN-and-pix2pix 项目。这个判别器非常简单,就是一个全卷积网络。

class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """

        super(NLayerDiscriminator, self).__init__()
        if not use_actnorm:
            norm_layer = nn.BatchNorm2d
        else:
            norm_layer = ActNorm
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2True)
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)

Transformer 模型结构

此方法使用的Transformer是GPT2。我们先看一下该项目封装Transformer的模型类taming.models.cond_transformer.Net2NetTransformer,再稍微看一下GPT类taming.modules.transformer.mingpt.GPT的具体实现。

Net2NetTransformer主要是实现了论文中提到的带约束生成。它会把输入x和约束c分别用一个VQGAN转成压缩图像,把图像压扁成一维,再调用GPT。我们来看一下这个类的主要内容。

初始化函数主要是初始化了输入图像的VQGAN self.first_stage_model、约束图像的VQGAN self.cond_stage_model、Transformer self.transformer

class Net2NetTransformer(pl.LightningModule):
    def __init__(self,
                 transformer_config,
                 first_stage_config,
                 cond_stage_config,
                 permuter_config=None,
                 ckpt_path=None,
                 ignore_keys=[],
                 first_stage_key="image",
                 cond_stage_key="depth",
                 downsample_cond_size=-1,
                 pkeep=1.0,
                 sos_token=0,
                 unconditional=False,
                 ):

        super().__init__()
        self.be_unconditional = unconditional
        self.sos_token = sos_token
        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key
        self.init_first_stage_from_ckpt(first_stage_config)
        self.init_cond_stage_from_ckpt(cond_stage_config)
        if permuter_config is None:
            permuter_config = {"target""taming.modules.transformer.permuter.Identity"}
        self.permuter = instantiate_from_config(config=permuter_config)
        self.transformer = instantiate_from_config(config=transformer_config)

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.downsample_cond_size = downsample_cond_size
        self.pkeep = pkeep

    def init_first_stage_from_ckpt(self, config):
        model = instantiate_from_config(config)
        model = model.eval()
        model.train = disabled_train
        self.first_stage_model = model

    def init_cond_stage_from_ckpt(self, config):
        ...
        self.cond_stage_model = ...

模型的前向传播函数如下。一开始,函数调用encode_to_zencode_to_c,根据self.cond_stage_modelself.first_stage_model把约束图像和输入图像编码成压扁至一维的压缩图像。之后函数做了一个类似Dropout的操作,根据self.pkeep随机替换掉约束编码。最后,函数把约束编码和输入编码拼接起来,使用通常方法调用Transformer。

def forward(self, x, c):
    # one step to produce the logits
    _, z_indices = self.encode_to_z(x)
    _, c_indices = self.encode_to_c(c)

    if self.training and self.pkeep < 1.0:
        mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
                                                      device=z_indices.device))
        mask = mask.round().to(dtype=torch.int64)
        r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
        a_indices = mask*z_indices+(1-mask)*r_indices
    else:
        a_indices = z_indices

    cz_indices = torch.cat((c_indices, a_indices), dim=1)

    # target includes all sequence elements (no need to handle first one
    # differently because we are conditioning)
    target = z_indices
    # make the prediction
    logits, _ = self.transformer(cz_indices[:, :-1])
    # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
    logits = logits[:, c_indices.shape[1]-1:]

    return logits, target

GPT2的结构不是本文的重点,我们就快速把模型结构过一遍了。GPT2的模型定义在taming.modules.transformer.mingpt.GPT里。GPT2的结构并不复杂,就是一个只有解码器的Transformer。前向传播时,数据先通过嵌入层self.tok_emb,再经过若干个Transformer模块self.blocks,最后过一个LayerNorm层self.ln_f和线性层self.head

class GPT(nn.Module):

    def forward(self, idx, embeddings=None, targets=None):
        # forward the GPT model
        token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector

        if embeddings is not None# prepend explicit embeddings
            token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)

        t = token_embeddings.shape[1]
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
        position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

每个Transformer块就是非常经典的自注意力加全连接层。

class Block(nn.Module):
    """ an unassuming Transformer block """
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),  # nice
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x, layer_past=None, return_present=False):
        # TODO: check that training still works
        if return_present: assert not self.training
        # layer past: tuple of length two with B, nh, T, hs
        attn, present = self.attn(self.ln1(x), layer_past=layer_past)

        x = x + attn
        x = x + self.mlp(self.ln2(x))
        if layer_past is not None or return_present:
            return x, present
        return x

基于滑动窗口的带约束图像生成

看完了所有模型的结构,我们最后来学习一下论文中没能详细介绍的滑动窗口算法。在scripts\taming-transformers.ipynb里有一个采样算法的最简实现,我们就来学习一下这份代码。

这份代码可以根据一幅语义分割图像来生成高清图像。一开始,代码会读入模型和语义分割图像。大致的代码为:

from taming.models.cond_transformer import Net2NetTransformer
model = Net2NetTransformer(**config.model.params)
from PIL import Image
import numpy as np
segmentation_path = "data/sflckr_segmentations/norway/25735082181_999927fe5a_b.png"
segmentation = Image.open(segmentation_path)
...



之后,代码把约束图像用对应的VQGAN编码进压缩空间,得到c_indices。由于待生成图像为空,我们可以随便生成一个待生成图像的压缩图像z_indices,代码中使用了randint初始化待生成的压缩图像。

c_code, c_indices = model.encode_to_c(segmentation)
z_indices = torch.randint(codebook_size, z_indices_shape, device=model.device)

idx = z_indices
idx = idx.reshape(z_code_shape[0],z_code_shape[2],z_code_shape[3])

cidx = c_indices
cidx = cidx.reshape(c_code.shape[0],c_code.shape[2],c_code.shape[3])

最后就是最关键的滑动窗口采样部分了。我们先稍微浏览一遍代码,再详细地一行一行看过去。

temperature = 1.0
top_k = 100

for i in range(0, z_code_shape[2]-0):
  if i <= 8:
    local_i = i
  elif z_code_shape[2]-i < 8:
    local_i = 16-(z_code_shape[2]-i)
  else:
    local_i = 8
  for j in range(0,z_code_shape[3]-0):
    if j <= 8:
      local_j = j
    elif z_code_shape[3]-j < 8:
      local_j = 16-(z_code_shape[3]-j)
    else:
      local_j = 8

    i_start = i-local_i
    i_end = i_start+16
    j_start = j-local_j
    j_end = j_start+16
    
    patch = idx[:,i_start:i_end,j_start:j_end]
    patch = patch.reshape(patch.shape[0],-1)
    cpatch = cidx[:, i_start:i_end, j_start:j_end]
    cpatch = cpatch.reshape(cpatch.shape[0], -1)
    patch = torch.cat((cpatch, patch), dim=1)
    logits,_ = model.transformer(patch[:,:-1])
    logits = logits[:, -256:, :]
    logits = logits.reshape(z_code_shape[0],16,16,-1)
    logits = logits[:,local_i,local_j,:]

    logits = logits/temperature

    if top_k is not None:
      logits = model.top_k_logits(logits, top_k)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    idx[:,i,j] = torch.multinomial(probs, num_samples=1)

x_sample = model.decode_to_img(idx, z_code_shape)
show_image(x_sample)

一开始的temperaturetop_k是得到logit后的采样参数,和滑动窗口算法无关。

temperature = 1.0
top_k = 100

进入生成图像循环后,i, j分别表示压缩图像的竖索引和横索引,i_start, i_end, j_start, j_end是滑动窗口上下左右边界。

for i in range(0, z_code_shape[2]-0):
  ...
  for j in range(0,z_code_shape[3]-0):
    ...
    i_start = i-local_i
    i_end = i_start+16
    j_start = j-local_j
    j_end = j_start+16

为了获取这四个滑动窗口的范围,代码用了若干条件语句计算待生成像素在滑动窗口里的相对位置local_i, local_j

for i in range(0, z_code_shape[2]-0):
  if i <= 8:
    local_i = i
  elif z_code_shape[2]-i < 8:
    local_i = 16-(z_code_shape[2]-i)
  else:
    local_i = 8
  for j in range(0,z_code_shape[3]-0):
    if j <= 8:
      local_j = j
    elif z_code_shape[3]-j < 8:
      local_j = 16-(z_code_shape[3]-j)
    else:
      local_j = 8

得到了滑动窗口的边界后,代码用滑动窗口从约束图像的压缩图像和待生成图像的压缩图像上各取出一个图块,并拼接起来。

patch = idx[:,i_start:i_end,j_start:j_end]
patch = patch.reshape(patch.shape[0],-1)
cpatch = cidx[:, i_start:i_end, j_start:j_end]
cpatch = cpatch.reshape(cpatch.shape[0], -1)
patch = torch.cat((cpatch, patch), dim=1)

之后,只需要把拼接的图块直接输入进Transformer,得到输出logits,再用local_i,local_j去输出图块的对应位置取出下一个压缩图像像素的概率分布,就可以随机生成下一个压缩图像像素了。如前文所述,Transformer类会把二维的图块压扁到一维,输入进GPT。同时,GPT会自动保证前面的像素看不到后面的像素,我们不需要人为地指定约束像素。这个地方的调用逻辑其实非常简单。

logits,_ = model.transformer(patch[:,:-1])
logits = logits[:, -256:, :]
logits = logits.reshape(z_code_shape[0],16,16,-1)
logits = logits[:,local_i,local_j,:]

最后只要从logits里采样,把采样出的压缩图像像素填入idx,就完成了一步生成。

logits = logits/temperature

if top_k is not None:
    logits = model.top_k_logits(logits, top_k)

probs = torch.nn.functional.softmax(logits, dim=-1)
idx[:,i,j] = torch.multinomial(probs, num_samples=1)

反复执行循环,就能将压缩图像生成完毕。最后将压缩图像过一遍VQGAN的解码器即可得到最终的生成图像。

x_sample = model.decode_to_img(idx, z_code_shape)
show_image(x_sample)

参考资料

VQGAN论文:https://arxiv.org/abs/2012.09841

VQGAN GitHub:https://github.com/CompVis/taming-transformers

如果你需要补充学习早期工作,欢迎阅读我之前的文章。

Transformer解读

PixelCNN解读

VQVAE解读


继续滑动看下一个
天才程序员周弈帆
向上滑动看下一个

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存