Unifying Bayesian Flow Networks and Diffusion Models through Stochastic Differential Equations论文作者:
薛凯文*、周聿浩*、聂燊、闵旭、张晓露、周军、李崇轩论文链接:
https://arxiv.org/abs/2404.15766代码链接:
https://github.com/ML-GSAI/BFN-Solver背景
如今深度生成式模型被广泛运用在计算机视觉和自然语言处理等领域。这些模型面临的主要挑战在于如何有效地表示复杂的概率分布,这些分布通常高度依赖于具体的数据或模态,同时能需要够进行有效的训练和推理。自回归模型(ARM)擅长对序列和离散数据(如文本)进行建模,但在推理速度上存在明显瓶颈,另一方面,扩散模型(DM)采用从粗糙到精细的渐进式方法,实现了生成质量与效率的更好平衡。目前扩散模型在图像生成领域取得了很大的进展,但它在处理离散变量时面临挑战,分数匹配算法在这里并不直接适用。贝叶斯流网络(BFN)是一种新的生成式模型,与扩散模型(DM)不同,BFN 不直接处理样本,而是通过贝叶斯推理迭代地细化不同噪声级别的分布参数。这种模型因其可微性质显示出在连续与离散数据建模上的巨大潜力,并且能够实现快速采样。概述
为了更深入地理解和改进 BFN,本文利用随机微分方程(SDE)将其与扩散模型建立了联系。我们发现了与 BFN 中的加噪过程相对应的线性 SDE,证明了 BFN 的回归损失等价于降噪分数匹配(DSM)。
基于此发现,我们验证 BFN 中的采样器近似等同于反向 SDE 的一阶 Solver。此外,通过借鉴扩散模型中现有的快速采样方案,我们提出了一种新的采样方法——BFN-Solvers,通过在图像和文本数据集上的测试,BFN-Solvers 在进行有限次数(例如 10 次)的函数评估下,能显著提升样本质量,并超越原始 BFN 采样器。特别地,我们的最优采样器实现了 5~20 倍的速度提升如图 3、4 所示。
贝叶斯流网络
BFN通过迭代不同噪声水平下的分布参数,这种策略使 BFN 在连续和离散数据上都是可微分的,同时使得少步数生成成为可能。给定一个从真实数据分布采样得到的样本点 ,BFN 定义了一个贝叶斯更新过程,通过不同噪声水平的带噪数据 和贝叶斯公式更新样本的先验分布参数,总共 步,产生了参数序列 。每一步的噪声水平由准确率 决定:其中 可以被设置为一个简单的先验。随着 的增加, 趋近于样本点 的狄拉克分布,当 趋近于无穷时,分布收敛到样本点 的狄拉克分布。与扩散模型类似,贝叶斯更新过程同样具有单步采样的性质,即我们可以得到分布 的解析形式(具体推导可以参考 BFN 论文 sec3.4):其中准确率时间表 (accuracy scheduel) 。上述的贝叶斯更新过程定义了给定样本 作为条件时, 从固定先验到样本的狄拉克分布的过程,如果我们可以不依赖 , 从 中采样,我们能够通过从固定先验出发得到样本点 的狄拉克分布,并从中采样得到生成数据。困难在于估计 的计算开销是巨大的,因为需要用到整个数据集,我们可以用神经网络学习这些条件概率,这样我们得到了由神经网络定义的参数更新过程 。特别地,为了方便之后损失函数的化简,我们定义 成以下期望的形式:
到这里,我们完成了 BFN 模型的定义,接下来我们考虑如何训练 BFN 和 BFN 的具体参数化形式。
BFN 的优化目标是负对数似然的变分下界:
作者从信息论角度理解 。数据所有者根据噪声时间表向接受者传输有损信息,接受者根据当前时刻的先验分布接受此信息,并通过贝叶斯更新得到后验分布。 表示传输 所需的 nat (natural unit of information) 的期望数量。接着我们需要考虑 和 的具体参数化形式。这里直接列出了关于离散数据 的参数化形式。令:其中 代表神经网络,输出维度为 。我们可以计算得到:假设加噪分布为正态分布给训练带来了立即的好处,我们可以进一步简化损失函数中的 KL 项。在这里我们直接给出了连续时间的化简后的损失函数,具体推导见 BFN 论文 Sec 3.9:
准确率时间表 需要满足单调递增的性质,这里经验上被设置为 作为一个超参数。通过对化简后的损失函数做蒙特卡洛估计,我们得到了可以直接用于训练的损失函数。具体训练和采样流程如算法 8、9 所示。
我们基本完成了 BFN 的介绍。接下来介绍我们近期对 BFN 的一些研究。简而言之,我们通过对应于 BFN 加噪策略的线性随机微分方程建立 BFN 和扩散模型的联系,将 BFN 训练损失函数对应于去噪得分匹配损失(DSM)。基于这个认识,我们发现了 BFN 的原始采样算法对应于反向 SDE 的离散化的某种近似,并为 BFN 开发了加速采样算法。
通过SDE统一BFN和DM
我们发现连续时间 BFN 对离散数据的噪声添加过程唯一求解了一个线性 SDE,总结为定理 5.1。
song 等人(2021)指出线性 SDE 对应一个由未知得分函数定义的反向 SDE:
值得注意的是定理 5.1 描述的是隐变量 而不是 的动态,如图 2 所示,这暗示了 BFN 原始采样算法不是通过直接离散化 SDE 进行采样。我们证明连续时间 BFN 在离散数据上的训练目标是 DSM 的重新参数化形式,总结为定理 5.2。
定理 5.1 和定理 5.2 将 BFN 与现有的离散状态扩散模型区分开来。具体来说,应用于离散数据的 BFN 求解线性 SDE,并使用 DSM 进行训练,与连续状态扩散模型无缝对齐。因此,在不改变离散数据的情况下,BFN 可以直接利用连续状态扩散模型的现有的经验进行改进。在这个工作中我们尝试改进 BFN 的采样效率。
加速采样
首先我们建立了 BFN 原始采样算法和反向 SDE 离散化的联系,总结为命题 5.3。
离散分布采样的的作用在理论上仍然不清楚。然而如图 6 所示,text8 数据集上的实验表明,去除分类采样步骤会在少于 50 网络前传次数(NFE)时一致提高效果,在其他 NFE 时几乎相同的结果。
一件重要的事情在于我们可以利用来自扩散模型的快速采样配方直接应用于 BFN。具体来说,首先我们可以得到线性 SDE 对应的概率流 ODE(Song 等人,2021):
这个 ODE 产生与具有无穷小步长的相应 SDE 相同的数据分布,并且由于其确定性性质,在大步长时具有较小的离散化误差。为了求解这个 ODE,我们借鉴 DPM-Solvers (Lu 等人,2022)的做法,首先利用 ODE 的半线性性将其进一步简化为神经网络的积分:
然后通过数值方法近似神经网络来降低离散化误差,一阶近似和二阶近似分别诱导出了 BFN-Solver1 和 BFN-Solver2 如算法 6,7 所示。
实验结果表明不同阶数的 BFN-Solver 在 Text8 数据集上显著优于基于相同预训练模型的同数量 NFE 的原始 BFN-Solver 如图 4 所示。同样在 CIFAR10 图像数据集上我们也观察到了类似的结果如图 3 所示。
总结
我们通过线性随机微分方程建立 BFN 和连续扩散模型的联系,受到连续扩散模型现有工作的启发,我们对采样算法进行了改进,取得了不错的结果。未来可以通过改进训练策略进一步改进和提高 BFN 的能力。
[1] Graves, A., Srivastava, R. K., Atkinson, T. & Gomez, F. "Bayesian Flow Networks."
[2] Song, Y. et al. "Score-Based Generative Modeling Through Stochastic Differential Equations."
[3] Lu, C. et al. "DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps."
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:hr@paperweekly.site
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧