生成对抗网络中的Wasserstein距离与WGAN原理
题目描述
在原始生成对抗网络(GAN)中,生成器与判别器的训练存在梯度不稳定、模式崩溃等问题。Wasserstein GAN(WGAN)通过引入Wasserstein距离替代JS散度作为损失函数,有效改善了训练稳定性。本题要求详解Wasserstein距离的定义、WGAN的改进原理及实现细节。
1. 原始GAN的问题根源
核心矛盾:原始GAN的判别器使用JS散度(Jensen-Shannon Divergence)衡量真实分布与生成分布的距离。当两者重叠部分可忽略时(常见于高维空间),JS散度会饱和为常数,导致梯度消失。
举例说明:
- 假设真实分布\(P_r\)和生成分布\(P_g\)是两条平行直线上的均匀分布,二者在二维空间无重叠。
- 此时KL散度(Kullback-Leibler Divergence)为无穷大,JS散度恒为\(\log 2\),梯度为0,生成器无法更新。
2. Wasserstein距离的定义
Wasserstein距离(又称Earth-Mover距离)描述将概率分布\(P_r\)“搬移”成分布\(P_g\)所需的最小成本:
\[W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x,y) \sim \gamma} [\|x - y\|] \]
- \(\Pi(P_r, P_g)\)是\(P_r\)和\(P_g\)的所有联合分布集合;
- \(\gamma(x,y)\)表示将\(x\)处的质量搬到\(y\)处的运输方案;
- 直观理解:计算分布形态差异时考虑几何空间中的实际距离,即使分布不重叠也能提供有意义的梯度。
3. WGAN的改进原理
3.1 损失函数重构
WGAN的判别器改为批评器(Critic),直接拟合Wasserstein距离:
\[L = \mathbb{E}_{x \sim P_r} [D(x)] - \mathbb{E}_{z \sim P_z} [D(G(z))] \]
- 批评器\(D\)的目标是最大化\(L\)(拉大真实样本与生成样本的得分差距);
- 生成器\(G\)的目标是最小化\(\mathbb{E}_{z \sim P_z} [D(G(z))]\)(让生成样本的得分接近真实样本)。
3.2 Lipschitz约束的实现
Wasserstein距离的计算需满足函数\(D\)的1-Lipschitz连续性(即梯度范数不超过1)。WGAN通过以下方法约束:
- 权重裁剪(原始WGAN):强制限制批评器参数在区间\([-c, c]\)内,但易导致梯度不稳定或消失。
- 梯度惩罚(WGAN-GP改进):在损失函数中增加梯度惩罚项:
\[L_{GP} = \lambda \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} [(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2] \]
其中\(\hat{x}\)是真实样本与生成样本连线上的随机插值点。
4. WGAN的训练步骤
- 初始化:批评器\(D\)和生成器\(G\)的参数。
- 批评器更新(多次迭代):
- 采样真实数据批次\(\{x_i\} \sim P_r\)和噪声\(\{z_i\} \sim P_z\);
- 计算生成数据\(G(z_i)\);
- 计算Wasserstein损失\(L_D = -\left( \frac{1}{m} \sum D(x_i) - \frac{1}{m} \sum D(G(z_i)) \right)\);
- 若使用WGAN-GP,增加梯度惩罚项;
- 更新批评器参数。
- 生成器更新(单次迭代):
- 采样噪声\(\{z_i\} \sim P_z\);
- 计算生成器损失\(L_G = -\frac{1}{m} \sum D(G(z_i))\);
- 更新生成器参数。
- 重复步骤2-3至收敛。
5. WGAN的优势
- 训练稳定性:Wasserstein距离提供平滑梯度,避免模式崩溃;
- 损失函数可解释性:损失值下降直接反映生成质量提升;
- 兼容复杂架构:适用于强化学习、文本生成等场景。
关键注意点:WGAN-GP虽提升性能,但梯度惩罚需额外计算成本,需权衡效率与效果。