生成对抗网络(GAN)中的Wasserstein距离优化算法(WGAN-GP)
题目描述
我们将深入讲解生成对抗网络(GAN)训练中一个重要的改进算法:带有梯度惩罚的Wasserstein GAN(WGAN-GP)。原始GAN在训练时存在模式崩溃、梯度消失或不稳定等问题。WGAN-GP算法通过使用Wasserstein距离衡量真实与生成数据分布的差异,并引入梯度惩罚项来满足Lipschitz约束,从而稳定训练、提升生成质量。你需要理解WGAN-GP的核心思想、Wasserstein距离的优势、梯度惩罚的数学原理,以及完整的算法步骤。
解题过程
1. 传统GAN的训练问题回顾
传统GAN由生成器G和判别器D构成。判别器D输出一个概率(0到1之间),表示输入是真实数据的置信度。其损失函数为最小最大博弈:
\
\[ \min_G \max_D V(D, G) = \mathbb{E}_{x \sim P_r} [\log D(x)] + \mathbb{E}_{z \sim p_z} [\log (1 - D(G(z)))] \ \]
其中$P_r$是真实数据分布,$p_z$是噪声先验分布(如标准正态分布),$z$是噪声输入。
训练中常见的问题包括:
- 梯度消失:当判别器D训练得过于强大时,生成器G的梯度会趋近于零,导致无法更新。
- 模式崩溃:生成器G只产生少数几种样本,缺乏多样性。
- 评价指标不敏感:判别器输出的概率值(如JS散度)在分布不重叠时饱和,无法提供有效的梯度信息。
2. Wasserstein距离的引入
Wasserstein距离(Earth-Mover距离)衡量两个概率分布$P_r$和$P_g$之间的差异。其定义为:
\
\[ W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma} [\|x - y\|] \ \]
其中$\Pi(P_r, P_g)$是联合分布集合,其边缘分布分别为$P_r$和$P_g$。直观上,它表示将分布$P_r$“搬运”成分布$P_g$所需的最小“工作量”。
关键优势:即使两个分布没有重叠(支撑集不相交),Wasserstein距离仍然能提供有意义的梯度,避免了JS散度在这种情况下变为常数的缺陷。
3. 从Wasserstein距离到WGAN的损失函数
通过Kantorovich-Rubinstein对偶,Wasserstein距离可表达为:
\
\[ W(P_r, P_g) = \sup_{\|f\|_L \leq 1} \mathbb{E}_{x \sim P_r} [f(x)] - \mathbb{E}_{x \sim P_g} [f(x)] \ \]
其中上确界取遍所有1-Lipschitz函数$f: \mathcal{X} \to \mathbb{R}$。Lipschitz约束要求存在常数$K$使得$|f(x) - f(y)| \leq K |x - y| $,此处$K=1$。
在WGAN中,判别器被替换为Critic(评价函数)$D_w$(参数为$w$),其目标是最大化上述差值,而生成器$G_\theta$(参数为$θ$)则最小化该差值。因此损失函数为:
- Critic的目标:\
\[ L_D = \mathbb{E}_{x \sim P_g} [D_w(x)] - \mathbb{E}_{x \sim P_r} [D_w(x)] \ \]
- 生成器的目标:\
\[ L_G = - \mathbb{E}_{x \sim P_g} [D_w(x)] = -\mathbb{E}_{z \sim p_z} [D_w(G_\theta(z))] \ \]
注意:Critic的输出现在是一个实数分数,而非概率值。
4. 满足Lipschitz约束的挑战与WGAN-GP的解决方案
原始WGAN通过权重裁剪(将Critic参数限制在某个区间如[-0.01, 0.01])来近似满足Lipschitz约束,但这可能导致梯度爆炸或消失,且限制了Critic的表达能力。
WGAN-GP提出用梯度惩罚来直接强制Lipschitz约束。其思想是:对于1-Lipschitz函数,其梯度的范数应几乎处处不超过1。因此,在Critic的损失中添加一个惩罚项,惩罚那些梯度范数偏离1的输入点。
具体地,定义惩罚项为:
\
\[ \lambda \, \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} [(\|\nabla_{\hat{x}} D_w(\hat{x})\|_2 - 1)^2] \ \]
其中:
- $\lambda$是惩罚系数,通常设为10。
- $P_{\hat{x}}$定义为沿着真实数据分布$P_r$与生成数据分布$P_g$之间采样点的直线上的随机插值点。即:
\
\[ \hat{x} = \epsilon x + (1 - \epsilon) \tilde{x}, \quad x \sim P_r, \, \tilde{x} \sim P_g, \, \epsilon \sim U[0,1] \ \]
这样做是因为理论上最优的Critic在$P_r$和$P_g$之间的区域梯度范数应为1,惩罚该区域可有效约束Critic。
5. WGAN-GP的完整算法步骤
输入:
- 真实数据分布$P_r$,噪声先验分布$p_z$(如$z \sim \mathcal{N}(0, I)$)
- Critic网络$D_w$,生成器网络$G_\theta$
- 学习率$\alpha$,批大小$m$,Critic每轮训练次数$n_{\text{critic}}$(通常$n_{\text{critic}}=5$),梯度惩罚系数$\lambda=10$
训练循环(直到收敛):
- 训练Critic(重复$n_{\text{critic}}$次):
a. 采样一批真实数据$\{x^{(i)}\}{i=1}^m \sim P_r$,一批噪声$\{z^{(i)}\}{i=1}^m \sim p_z$,生成假数据$\tilde{x}^{(i)} = G_\theta(z^{(i)})$。
b. 随机采样$\epsilon^{(i)} \sim U[0,1]$,计算插值样本:
\
\[ \hat{x}^{(i)} = \epsilon^{(i)} x^{(i)} + (1 - \epsilon^{(i)}) \tilde{x}^{(i)} \ \]
c. 计算Critic损失(带梯度惩罚):
\
\[ L_D = \underbrace{\frac{1}{m} \sum_{i=1}^m D_w(\tilde{x}^{(i)}) - \frac{1}{m} \sum_{i=1}^m D_w(x^{(i)})}_{\text{Wasserstein距离估计}} + \lambda \cdot \frac{1}{m} \sum_{i=1}^m (\|\nabla_{\hat{x}^{(i)}} D_w(\hat{x}^{(i)})\|_2 - 1)^2 \ \]
d. 更新Critic参数$w$:$w \gets w - \alpha \cdot \text{Adam}(\nabla_w L_D, w)$(通常使用Adam优化器)。
- 训练生成器(每$n_{\text{critic}}$轮Critic训练后执行一次):
a. 采样一批噪声$\{z^{(i)}\}_{i=1}^m \sim p_z$。
b. 计算生成器损失:
\
\[ L_G = - \frac{1}{m} \sum_{i=1}^m D_w(G_\theta(z^{(i)})) \ \]
c. 更新生成器参数$θ$:$θ \gets θ - \alpha \cdot \text{Adam}(\nabla_\theta L_G, θ)$。
输出:训练好的生成器$G_\theta$。
6. 关键细节说明
- 梯度计算:在计算梯度惩罚项时,需要对每个插值点$\hat{x}^{(i)}$计算$D_w$对其输入的梯度$\nabla_{\hat{x}^{(i)}} D_w(\hat{x}^{(i)})$,然后计算该梯度的L2范数。现代自动微分框架(如PyTorch、TensorFlow)可方便地计算这些梯度。
- 为何插值采样有效:理论分析表明,最优Critic在$P_r$和$P_g$之间的直线区域上梯度范数为1。惩罚这些点可促使整个函数满足1-Lipschitz约束,且在实践中稳定有效。
- 与原始WGAN对比:WGAN-GP移除了权重裁剪,允许Critic学习更复杂的函数,同时梯度惩罚提供了更平滑的约束,缓解了训练不稳定性。
7. 算法优势
- 训练稳定:Critic的梯度更可靠,减少了模式崩溃。
- 生成质量高:Wasserstein距离提供有意义的训练信号,有助于生成多样化、高质量的样本。
- 超参数鲁棒:梯度惩罚系数$\lambda$和Critic训练次数$n_{\text{critic}}$在较宽范围内有效,易于调整。
总结
WGAN-GP通过结合Wasserstein距离的分布度量优势和梯度惩罚的软Lipschitz约束,显著改善了GAN的训练稳定性与生成效果。其核心在于将原始GAN的概率判断转变为基于Wasserstein距离的分数优化,并用可微的梯度惩罚项替代权重裁剪,使Critic能够更好地引导生成器的训练。