查看原文
其他

ICLR 2023 | 扩散生成模型新方法:极度简化,一步生成

刘星超 PaperWeekly 2023-03-18

©作者 | 刘星超
单位 | 德州大学奥斯汀分校
研究方向 | 生成式模型


Diffusion Generative Models(扩散式生成模型)已经在各种生成式建模任务中大放异彩,但是,其复杂的数学推导却常常让大家望而却步,缓慢的生成速度也极大地阻碍了研究的快速迭代和高效部署。研究过 DDPM 的同学可能见到过这种画风的变分法(Variational Inference)推导(截取自 What are Diffusion Models):



总体上推导的难度和对数学的要求还是比较高的。在连续时间的形式下,还需要随机微分方程(Stochastic Differential Equation(SDE))的知识,有不低的入门门槛。除此以外,扩散式生成模型的一个众所周知的老大难问题就是生成速度慢:生成一张图需要模拟一整个基于复杂的深度模型的扩散过程。缓慢的生成速度是阻碍这些模型更广泛的普及的一个主要瓶颈。


Rectified Flow,一个“简简单单走直线”生成模型,是我们对这些挑战的一个回答:极度简单,一步生成。我们的方法有以下要点:


(1)我们无需一般扩散模型复杂的推导,代之以一个简单的“沿直线生成”的思想。算法理解上不需要变分法或随机微分方程等基础知识。我们的方法是基于一个简单的常微分方程(ODE),通过构造一个“尽量走直线”的连续运动系统来产生想要的数据分布。


(2)“尽量走直线”的目的是让我们模型实现快速生成。通过一个叫“reflow”的方法,我们可以实现梦想中的“一步生成”:只需一步计算就直接产生高质量的结果,而不需要调用计算量大的数值求解器来迭代式地模拟整个扩散过程。


(3)通常的扩散模型是把高斯白噪声转换成想要的数据(比如图片)。我们的方法可以把任何一种数据或噪声(比如猫脸照片)转换成另外一种数据(比如人脸照片)。所以我们的方法不仅可以做生成模型,还可以应用于很多更广泛的迁移学习(比如 domain transfer)任务上。


有兴趣的同学可以参见我们的论文(Arxiv 或 OpenReview,以及和最优传输(optimal transport)相关的深入理论 Arxiv)。代码,示例 Colab Notebook 和预训练模型已经开源在 github。一个英文版简介在这里。欢迎大家使用和交流!


▲ Rectified Flow 可以实现生成式模型或者无监督图像转换(图中是人 ↔ 猫)。同时,通过新颖的 Reflow 算法,我们可以将 ODE 的轨迹拉直,在 N=1 时也取得较好的生成效果(图中 N 指我们所使用的 Euler 求解器的步数)。




问题-传输映射(将一个分布搬运到另一个分布)



我们先定义好要解决的问题。无论是从噪声生成图片(generative modeling),还是将人脸转化为猫脸(domain transfer),都可以这样概括成将一个分布转化成另一个分布的问题:


给定从两个分布 中的采样,我们希望找到一个传输映射 使得,当 时,

比如,在生成模型里, 是高斯噪声分布, 是数据的分布(比如图片),我们想找到一个方法,把噪声 映射成一个服从 的数据 。在数据迁移(domain transfer)里, 分别是人脸和猫脸的图片。所以这个问题是生成模型和数据迁移的统一表述。

在我们的框架下,映射 是通过以下连续运动系统,也就是一个常微分方程(ordinary differential equation(ODE)),或者叫流模型(flow),来隐式定义的:


我们可以想象从 里采样出来的 是一个粒子。它从 时刻开始连续运动,在 时刻以 为速度。直到 时刻得到 。我们希望 服从分布 。这里我们假设 是一个神经网络。我们的任务是从数据里学习出 来达到 的目的。



走直线,走得快


除了希望 ,我们还希望这个连续运动系统能够在计算机里快速地模拟出来。注意到,在实际计算过程中,上面的连续系统通常是用 Euler 法(或其变种)在离散化的时间上近似:


这里 是一个步长参数。我们需要适当的选择 来平衡速度和精度: 需要足够小来保证近似的精度,但同时小的 意味着我们从 要跑很多步,速度就慢。


那么问题来了,什么样的系统能最快地用 Euler 法来模拟呢?也就是说,什么样的体系能允许我们在用较大的步长 的同时还能得到很好的精度呢?


答案是“走直线”。如下图所示,如果粒子的运动轨迹是弯曲的,我们需要很细的离散化来得到很好的结果。如果粒子的轨迹是直线,那么即使我们取最大的步长(),只用一步走到 时刻,还是能得到正确的结果!


所以,我们希望我们学习出来的速度模型 既能保证 ,又能给出尽量直的轨迹。怎么同时实现这两个目的在数学上是一个非常不简单(non-trivial)的问题,涉及最优传输(optimal transport)的一些深刻理论。但是我们发现其实可以用一个非常简单的方法来解决这个问题。


▲ 蓝色:真实 ODE 轨迹;绿色:Euler 法得到的离散轨迹。左:弯曲的运动轨迹需要较小的步长来离散化才能得到较小误差,所以需要更多的步数;右:笔直的运动轨迹甚至可以在计算机里用一步进行完美的模拟。




Rectified Flow-基于直线ODE学习生成模型


假设我们有从两个分布中的采样 (比如 是从 里出来的随机噪声, 是一个随机的数据(服从 ))。我们把 用一个线性插值连接起来,得到



这里 是随机,或者说,以任意方式配对的。你也许觉得 应该用一种有意义的方式配对好,这样能够得到更好的效果。我们先忽略这个问题,待会回来解决它。


现在,如果我们拿 对时间 求导,我们其实已经可以得到一个能够将数据从 传输到 的“ODE”了,



但是,这个“ODE”并不实用而且很奇怪,所以要打个引号:它不是一个“因果”(causal),或者“可前向模拟”(forward simulatable)的系统,因为要计算 时刻的速度 需要提前(在 时)知道 ODE 轨迹的终点 。如果我们都已经知道 了,那其实也就没有必要模拟 ODE 了。


那么我们能不能学习 ,使得我们想要的“可前向模拟”的 ODE 能尽可能逼近刚才这个“不可前向模拟”的过程呢?最简单的方法就是优化 来最小化这两个系统的速度函数(分别是 )之间的平方误差:



这是一个标准的优化任务。我们可以将 设置成一个神经网络,并用随机梯度下降或者 Adam 来优化,进而得到我们的可模拟 ODE 模型。


这就是我们的基本方法。数学上,我们可以证明这样学出来的 确实可以保证生成想要的分布 。对数学感兴趣的同学可以看一看论文里的理论推导。下面我们只用这个图来给一些直观的解释。



图(a):在我们用直线连接 时,有些线会在中间的地方相交,这是导致 非因果的原因(在交叉点, 既可以沿蓝线走,也可以沿绿线走,因此粒子不知该向岔路的哪边走)。


图(b):我们学习出的 ODE 因为必须是因果的,所以不能出现道路相交的情况,它会在原来相交的地方把道路交换成不交叉的形式。这样,我们学习出来的 ODE 仍然保留了原来的基本路径,但是做了一个重组来避免相交的情况。这样的结果是,图(a)和图(b)里的系统在每个时刻 的边际分布是一样的,即使总体的路径不一样。


我们的方法起名为 Rectified Flow。这里 rectified 是“拉直”,“规整”的意思。我们这个框架其实也可以用来推导和解释其他的扩散模型(如 DDPM)。我们论文里有详细说明,这里就不赘述了。我们现在的算法版本应该是在已知的算法空间里最简单的选项了。我们提供了 Colab Notebook 来帮助大家通过实践来理解这个过程。





Reflow-拉直轨迹,一步生成


因为 Rectified Flow 要在直线轨迹的交叉点做路径重组,所以上面的 ODE 模型(或者说 flow)的轨迹仍然可能是弯曲的(如上面的图(b)),不能达到一步生成。我们提出一个“Reflow”方法,将 ODE 的轨迹进一步变直。

具体的做法非常简单: 假设我们从 里采样出一批 。然后,从 出发,我们模拟上面学出的 flow(叫它 1-Rectified Flow),得到 。我们用这样得到的 对来学一个新的“2-Rectified Flow”:


这里,2-Rectified Flow 和 1-Rectified Flow 在训练过程中唯一的区别就是数据配对不同:在 1-Rectified Flow 中, 是随机或者任意配对的;在 2-Rectified Flow 中, 是通过 1-Rectified Flow 配对的。

上面的动图中,图(c)展示了 Reflow 的效果。因为从 1-Rectified Flow 里出来的 已经有很好的配对, 他们的直线插值交叉数减少,所以 2-Rectified Flow 的轨迹也就(比起 1-Rectified Flow)变得很直了(虽然仔细看还不完美)。

理论上,我们可以重复 Reflow 多次,从而得到 3-Rectified Flow, 4-Rectified Flow... 我们可以证明这个过程其实是在单调地减小最优传输理论中的传输代价(transport cost),而且最终收敛到完全直的状态。

当然,实际中,因为每次 优化得不完美,多次 Reflow 会积累误差,所以我们不建议做太多次的Reflow。幸运的是,在我们的实验中,我们发现对生成图片和很多我们感兴趣的问题而言,像上面的图(c)一样,1次 Reflow 已经可以得到非常直的轨迹了,配合蒸馏足够达到一步生成的效果了。



Reflow与Distillation


给定一个配对 ,要想实现一步生成,也就是 , 我们好像也可以通过优化下面的平方误差来直接"蒸馏(distillation)"出一个一步模型:



这个目标函数和上面的 Reflow 的目标函数很像,只是把所有的时间 都设成 了。


尽管如此,Distillation 和 Reflow 是有本质的区别的。Distillation 试图一五一十地复现 配对的关系。但是,如果 的配对是随机的,Distillation最多只能得到 在给定 时的条件平均,也就是 ,并不能成功地完全匹配 。即使 有确定的一一对应关系,他们的配对关系也可能很复杂,导致直接蒸馏很困难。


Reflow 解决了 Distillation 的这些困难。它的意义在于 :

1)给定任何 配对,就算是随机的配对,他都能学出一个给出正确边际分布(marginal distribution)的 flow。Reflow 不会去试图完全复现 的配对关系,而只注重于得到正确的边际分布。

2)从 Reflow 出的 ODE 里采样,我们还可以得到一个更好的配对 ,从而给出更好的 flow。重复这个过程可以最终得到保证一步生成的直线 ODE。

形象地来讲,如果 太复杂,Reflow会“拒绝”完全复现 ,转而给出一个新的,更简单的,但仍然满足 的配对 。所以,Distillation 更像“模仿者”,只会机械地模仿,就算问题无解也要“硬做”。Reflow 更像“创造者”,懂得变通,发现新方法来解决问题。

当然,Reflow 和 Distillation 也可以组合使用:先用 Reflow 得到比较好的配对,最后再用已经很好的配对进行 Distillation 。我们在论文里发现,这个结合的策略确实有用。

下面,我们进一步基于具体例子解释一下 Reflow 对配对的提高效果。如果一个配对 是好的,那么从这个配对里随机产生的两条直线 就不会相交。在我们的论文里,这种直线不相交的配对我们叫做“Straight Coupling”。我们的 Reflow 过程就是在不停地降低这个相交概率的过程。下图我们展示随着 Reflow 的不断进行,配对的直线交叉数确实逐渐降低。

在图中,对每种配对方法,我们随机选择两个配对,分别用直线段连接它们,然后若它们相交,就用红色点标出这两条直线段的交点。对于这种交叉的配对,Reflow 就有可能改善它们。

我们重复 10000 次并统计交叉的概率。我们发现:1)每次 Reflow 都降低了交叉的概率和 L2 传输代价;2)即使 2-Rectified Flow 在肉眼观察时已经很直,但它的交叉概率仍不为 0,更多的 Reflow 次数就可能进一步使它变直并降低传输代价。相比之下,单纯的蒸馏是不能改善配对的,这是 Reflow 与蒸馏的本质区别。

▲ 图中,每个红点代表一次两随机的直线交叉的事件。随着 reflow,交叉的概率逐渐降低,对应的 ODE 的轨迹也越来越直。




理论保证


Rectified Flow 不仅简洁,而且在理论上也有很好的性质。我们在此给出一些理论保证的非正式表述,如果大家对理论部分感兴趣,欢迎大家阅读我们文章的细节。


1.边际分布不变: 取得最优值时,对任意时间 ,我们有 的分布相等。因为 ,因此 确实可以将 转移到


2.降低传输损失:每次 Reflow 都可以降低两个分布之间的传输代价。特别的,Reflow 并不优化一个特定的损失函数,而是同时优化所有的凸损失函数。
3.拉直 ODE 轨迹:通过不停重复 Reflow,ODE 轨迹的直线性(Straightness)以 的速率下降,这里, 是 reflow 的次数。



实验结果-Rectified Flow能做到什么?


▲ CIFAR-10实验结果


使用 Runge Kutta-45 求解器,1-Rectified Flow 在 CIFAR10 上得到 IS=9.6, FID=2.58,recall=0.57,基本与之前的 VP SDE/sub-VP SDE [2] 相同,但是平均只需要 127 步进行模拟。


Reflow 可以使 ODE 轨迹变直,因此2-Rectified Flow 和 3-Rectified Flow 在仅用一步(N=1)时也可以有效的生成图片(FID=12.21/8.15)


Reflow 可以降低传输损失,因此在进行蒸馏时会得到更好的表现。用 2-Rectified Flow + 蒸馏,我们在仅用一步生成时得到了 FID=4.85,远超之前最好的仅基于蒸馏/基于 GAN loss 的快速扩散式生成模型(当用一步采样时 FID=8.91)。同时,比起 GAN,Rectified Flow + 蒸馏有更好的多样性(recall>0.5)


我们的方法也可以用于高清图片生成无监督图像转换


▲ 1-rectified flow: 256分辨率图像生成


▲ 1-rectified flow: 256分辨率无监督图像转换


同期相关工作


有意思的是,今年 ICLR 在 openreview 上出现了好几篇投稿论文提出了类似的想法。


(1) Flow Matching for Generative Modeling:https://openreview.net/forum?id=PqvMRDCJT9t
(2) Building Normalizing Flows with Stochastic Interpolants:https://openreview.net/forum?id=li7qeBbCR1t
(3) Iterative -alpha (de)Blending: Learning a Deterministic Mapping Between Arbitrary Densities:https://openreview.net/forum?id=s7gnrEtWSm
(4) Action Matching: A Variational Method for Learning Stochastic Dynamics from Samples: https://openreview.net/forum?id=T6HPzkhaKeS


这些工作都或多或少地提出了用拟合插值过程来构建生成式 ODE 模型的方法。除此之外,我们的工作还阐明了这个路径相交重组的直观解释和最优传输的内在联系,提出了 Reflow 算法,实现了一步生成,建立了比较完善的理论基础。大家不约而同地在一个地方发力,说明这个方法的出现是有很大的必然性的。因为它的简单形式和很好的效果,相信以后有很大的潜力。


如有任何问题,欢迎留言或者发邮件!


主要论文:

X. Liu, C. Gong, Q. Liu. Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow. ICLR 2023, arXiv:2209.03003Q. Liu. Rectified flow: A marginal preserving approach to optimal transport. arXiv preprint arXiv:2209.14577, 2022.


参考文献

[1] Song Y, Sohl-Dickstein J, Kingma D P, et al. Score-Based Generative Modeling through Stochastic Differential Equations. International Conference on Learning Representations.[2] Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models. Advances in Neural Information Processing Systems, 2020, 33: 6840-6851.[3] Song J, Meng C, Ermon S. Denoising Diffusion Implicit Models. International Conference on Learning Representations.[4] Lu C, Zhou Y, Bao F, et al. DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps. Advances in Neural Information Processing Systems.[5] Bansal A, Borgnia E, Chu H M, et al. Cold diffusion: Inverting arbitrary image transforms without noise. arXiv preprint arXiv:2208.09392, 2022.[6] Liu X, Wu L, Ye M. Learning Diffusion Bridges on Constrained Domains//International Conference on Learning Representations.[7] Liu Q. Rectified flow: A marginal preserving approach to optimal transport. arXiv preprint arXiv:2209.14577, 2022.


更多阅读



#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编



🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧


·
·

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存