扩散模型中的CFG(Classifier-Free Guidance)原理与条件生成引导机制
我会为你详解CFG算法的核心思想、数学原理及实现细节,确保每一步都清晰易懂。
1. 问题背景:扩散模型的条件生成困境
扩散模型(如DDPM)能够生成高质量样本,但在条件生成任务中面临挑战:
- 传统方法:使用分类器引导(Classifier Guidance),需额外训练一个分类器网络,在去噪过程中利用分类器的梯度调整生成方向。但分类器训练成本高,且易受对抗样本干扰。
- 核心需求:如何实现高效、稳定的条件生成,避免额外分类器?
CFG的提出:通过一种巧妙的训练策略,在单一扩散模型中同时支持无条件生成和条件生成,无需分类器。
2. CFG核心思想:联合训练条件与无条件扩散模型
关键洞察
- 训练时,以一定概率随机丢弃条件信息(如类别标签、文本描述),使模型同时学会:
- 条件生成:给定条件 \(y\) 时生成数据。
- 无条件生成:不依赖条件时生成数据。
- 推理时,通过条件与无条件预测的线性组合引导生成方向,增强条件控制。
数学符号说明
- \(x_t\):\(t\) 时刻的噪声数据。
- \(y\):条件(如类别标签、文本)。
- \(\epsilon_\theta(x_t, t, y)\):条件扩散模型预测的噪声。
- \(\epsilon_\theta(x_t, t, \varnothing)\):无条件扩散模型预测的噪声(条件被替换为空集 \(\varnothing\))。
3. 训练阶段:随机条件丢弃
步骤详解
- 输入构造:
每个训练样本为三元组 \((x_0, t, y)\),其中 \(x_0\) 是干净数据,\(t\) 是随机时间步,\(y\) 是条件。 - 随机丢弃:
以概率 \(p_{\text{drop}}\)(通常设为 0.1~0.2)将条件 \(y\) 替换为空标识 \(\varnothing\)。 - 训练目标:
模型需预测添加到 \(x_0\) 中的噪声 \(\epsilon\):
\[ \mathcal{L} = \mathbb{E}_{x_0, t, y, \epsilon} \left[ \| \epsilon - \epsilon_\theta(x_t, t, \tilde{y}) \|^2 \right] \]
其中 \(\tilde{y}\) 为可能被丢弃后的条件,\(x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon\)。
效果:
- 当 \(y\) 保留时,模型学习条件生成。
- 当 \(y\) 被丢弃时,模型学习无条件生成。
- 二者共享同一网络参数 \(\theta\),仅输入条件不同。
4. 推理阶段:引导生成
噪声预测组合公式
在去噪的每一步,用条件与无条件预测的加权组合作为最终噪声预测:
\[\hat{\epsilon}_\theta(x_t, t, y) = \epsilon_\theta(x_t, t, \varnothing) + w \cdot \left( \epsilon_\theta(x_t, t, y) - \epsilon_\theta(x_t, t, \varnothing) \right) \]
其中:
- \(w \geq 1\) 是引导尺度。
- \(\epsilon_\theta(x_t, t, y) - \epsilon_\theta(x_t, t, \varnothing)\) 表示“条件引起的噪声调整方向”。
物理意义
- 若 \(w=1\):退化为普通条件生成。
- 若 \(w>1\):沿条件方向增强,使生成更贴合条件,但可能降低多样性。
- 极端情况 \(w \gg 1\):生成高度贴合条件,但可能过拟合。
5. 为什么有效?概率视角解释
扩散模型的去噪过程可视为逼近数据分布 \(p(x|y)\)。CFG实际在隐式调整条件分布的对数梯度:
\[\hat{\epsilon}_\theta \approx -\nabla_{x_t} \log p(x_t) + w \cdot \left( -\nabla_{x_t} \log p(x_t|y) + \nabla_{x_t} \log p(x_t) \right) \]
化简后相当于:
\[\hat{\epsilon}_\theta \approx -\nabla_{x_t} \left[ \log p(x_t) + w \cdot \log p(y|x_t) \right] \]
这等价于用权重 \(w\) 放大条件似然项 \(p(y|x_t)\),从而强化条件控制。
6. 实现细节与注意事项
(1)条件编码
- 文本条件:通过CLIP或BERT编码为向量,输入扩散模型(如U-Net的交叉注意力层)。
- 类别条件:嵌入为向量后与时间步编码融合。
(2)引导尺度 \(w\) 的选择
- 通常 \(w \in [1, 20]\),需验证调优。
- 过大导致图像质量下降(过饱和、伪影)。
(3)训练技巧
- 在条件输入层添加零初始化的投影层,使初始时条件影响微弱,训练更稳定。
- 使用更高的无条件丢弃概率(如 0.2)以提升无条件生成质量。
7. 与分类器引导的对比
| 方面 | 分类器引导 | CFG |
|---|---|---|
| 额外网络 | 需单独训练分类器 | 无需额外网络 |
| 训练成本 | 高(分类器+扩散模型) | 低(仅扩散模型) |
| 对抗鲁棒性 | 敏感(依赖分类器梯度) | 稳定(端到端训练) |
| 条件控制灵活性 | 弱(依赖分类器性能) | 强(可调节 \(w\) ) |
8. 总结:CFG的核心贡献
- 统一框架:单一模型同时支持条件/无条件生成。
- 训练简单:仅需在原始扩散模型基础上增加条件随机丢弃。
- 效果显著:在文本到图像生成(如DALL·E 2、Stable Diffusion)中广泛应用,显著提升语义对齐质量。