扩散模型中的去噪分数匹配(Denoising Score Matching)原理与训练目标
题目描述
在扩散模型与基于分数的生成模型中,去噪分数匹配是一种重要的训练方法。它旨在通过向数据添加噪声并训练一个模型来估计扰动后数据分布的梯度(即分数),从而避开直接计算真实数据分布分数的困难。本题目将深入讲解去噪分数匹配的核心思想、数学推导、与普通分数匹配的区别,以及其在扩散模型训练中的具体应用。
解题过程
步骤1:理解“分数”的定义
- 在概率建模中,分数(Score) 定义为数据分布的对数概率密度的梯度。
- 对于一个概率分布 \(p(\mathbf{x})\)(其中 \(\mathbf{x} \in \mathbb{R}^D\)),其分数函数为:
\[ \nabla_{\mathbf{x}} \log p(\mathbf{x}) \]
- 这个梯度向量指向 \(p(\mathbf{x})\) 概率密度增加最快的方向。
步骤2:为何需要分数匹配
- 在生成模型中,我们通常想学习一个模型 \(s_{\theta}(\mathbf{x})\) 来逼近真实数据分布 \(p_{\text{data}}(\mathbf{x})\) 的分数 \(\nabla_{\mathbf{x}} \log p_{\text{data}}(\mathbf{x})\)。
- 一旦获得分数估计,可以通过朗之万动力学(Langevin Dynamics) 等采样方法从分布中生成样本。
- 直接计算真实分数需要知道 \(p_{\text{data}}(\mathbf{x})\) 的归一化常数,这通常是未知的。
步骤3:普通分数匹配的问题
- 普通分数匹配(Score Matching)的目标是最小化模型分数与真实分数之间的 Fisher 散度:
\[ J(\theta) = \mathbb{E}_{p_{\text{data}}(\mathbf{x})} \left[ \frac{1}{2} \left\| s_{\theta}(\mathbf{x}) - \nabla_{\mathbf{x}} \log p_{\text{data}}(\mathbf{x}) \right\|^2 \right] \]
- 但该目标依赖真实分数 \(\nabla_{\mathbf{x}} \log p_{\text{data}}(\mathbf{x})\),无法直接计算。
步骤4:去噪分数匹配的核心思想
- 去噪分数匹配(Denoising Score Matching)的关键思路是:对数据加噪声,然后让模型学习去噪后的分数。
- 具体步骤:
- 选择一个噪声分布 \(q(\tilde{\mathbf{x}} | \mathbf{x})\),通常为高斯噪声:
\[ q(\tilde{\mathbf{x}} | \mathbf{x}) = \mathcal{N}(\tilde{\mathbf{x}}; \mathbf{x}, \sigma^2 I) \]
- 构造加噪后的数据分布:
\[ q_{\sigma}(\tilde{\mathbf{x}}) = \int p_{\text{data}}(\mathbf{x}) q(\tilde{\mathbf{x}} | \mathbf{x}) d\mathbf{x} \]
- 训练一个模型 \(s_{\theta}(\tilde{\mathbf{x}})\) 来估计加噪分布的分数 \(\nabla_{\tilde{\mathbf{x}}} \log q_{\sigma}(\tilde{\mathbf{x}})\)。
步骤5:去噪分数匹配的训练目标
- 目标函数推导:
- 对于加噪分布 \(q_{\sigma}(\tilde{\mathbf{x}})\),其分数为:
\[ \nabla_{\tilde{\mathbf{x}}} \log q_{\sigma}(\tilde{\mathbf{x}}) = \mathbb{E}_{p_{\text{data}}(\mathbf{x})} \left[ \nabla_{\tilde{\mathbf{x}}} \log q(\tilde{\mathbf{x}} | \mathbf{x}) \right] \]
- 对于高斯噪声 \(q(\tilde{\mathbf{x}} | \mathbf{x}) = \mathcal{N}(\tilde{\mathbf{x}}; \mathbf{x}, \sigma^2 I)\),有:
\[ \nabla_{\tilde{\mathbf{x}}} \log q(\tilde{\mathbf{x}} | \mathbf{x}) = - \frac{\tilde{\mathbf{x}} - \mathbf{x}}{\sigma^2} \]
- 因此,去噪分数匹配的目标是让模型 \(s_{\theta}(\tilde{\mathbf{x}})\) 逼近这个条件分数:
\[ \min_{\theta} \mathbb{E}_{p_{\text{data}}(\mathbf{x})} \mathbb{E}_{q(\tilde{\mathbf{x}} | \mathbf{x})} \left[ \frac{1}{2} \left\| s_{\theta}(\tilde{\mathbf{x}}) + \frac{\tilde{\mathbf{x}} - \mathbf{x}}{\sigma^2} \right\|^2 \right] \]
- 直观理解:
- 模型 \(s_{\theta}(\tilde{\mathbf{x}})\) 学习预测加噪样本 \(\tilde{\mathbf{x}}\) 指向原始干净数据 \(\mathbf{x}\) 的方向。
- 在扩散模型中,这对应去噪方向。
步骤6:与扩散模型的关联
- 在扩散模型中,前向过程逐步加噪:
\[ q(\mathbf{x}_t | \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}_{t-1}, \beta_t I) \]
- 每个时间步的分数匹配目标可写为:
\[ \mathbb{E}_{t, \mathbf{x}_0, \mathbf{x}_t} \left[ \lambda(t) \left\| s_{\theta}(\mathbf{x}_t, t) - \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t | \mathbf{x}_0) \right\|^2 \right] \]
- 由于 \(q(\mathbf{x}_t | \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1 - \bar{\alpha}_t) I)\),其中 \(\bar{\alpha}_t = \prod_{s=1}^t (1 - \beta_s)\),分数可解析计算为:
\[ \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t | \mathbf{x}_0) = - \frac{\mathbf{x}_t - \sqrt{\bar{\alpha}_t} \mathbf{x}_0}{1 - \bar{\alpha}_t} \]
- 模型 \(s_{\theta}(\mathbf{x}_t, t)\) 通常参数化为预测噪声 \(\epsilon\) 或去噪后的 \(\mathbf{x}_0\)。
步骤7:训练与采样流程
- 训练:
- 从训练集中采样干净数据 \(\mathbf{x}_0 \sim p_{\text{data}}\)。
- 随机采样时间步 \(t \sim \text{Uniform}(1, T)\)。
- 采样噪声 \(\epsilon \sim \mathcal{N}(0, I)\),构造加噪样本 \(\mathbf{x}_t = \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon\)。
- 优化损失函数(简化形式):
\[ \mathcal{L}(\theta) = \mathbb{E}_{t, \mathbf{x}_0, \epsilon} \left[ \| \epsilon - \epsilon_{\theta}(\mathbf{x}_t, t) \|^2 \right] \]
其中 $ \epsilon_{\theta} $ 是预测噪声的模型,与分数模型的关系为:
\[ s_{\theta}(\mathbf{x}_t, t) = - \frac{\epsilon_{\theta}(\mathbf{x}_t, t)}{\sqrt{1 - \bar{\alpha}_t}} \]
- 采样:
- 从随机噪声 \(\mathbf{x}_T \sim \mathcal{N}(0, I)\) 开始。
- 从 \(t = T\) 到 \(t = 1\) 逐步去噪:
\[ \mathbf{x}_{t-1} = \frac{1}{\sqrt{1 - \beta_t}} \left( \mathbf{x}_t + \beta_t s_{\theta}(\mathbf{x}_t, t) \right) + \sqrt{\beta_t} \mathbf{z}, \quad \mathbf{z} \sim \mathcal{N}(0, I) \]
- 最终得到生成样本 \(\mathbf{x}_0\)。
步骤8:去噪分数匹配的优势
- 避免计算归一化常数:通过加噪将目标转为条件分布,避免了真实分布的直接计算。
- 适用于多尺度建模:通过使用不同噪声强度 \(\sigma\),可以学习数据分布在不同尺度下的分数,从而改善采样质量。
- 与扩散模型自然结合:在扩散模型中,每个时间步的加噪分布均对应一个去噪分数匹配问题。
总结
去噪分数匹配通过向数据添加噪声,将难以直接计算的真实数据分布分数估计问题,转化为可求解的条件分数匹配问题。在扩散模型中,该方法对应着训练一个模型来预测每个时间步的噪声(或去噪方向),从而通过逐步去噪过程生成新样本。该方法为基于分数的生成模型提供了稳定且高效的训练基础。