生成对抗网络中的Wasserstein距离与WGAN原理
字数 1893 2025-10-29 12:21:34

生成对抗网络中的Wasserstein距离与WGAN原理

题目描述

在原始生成对抗网络(GAN)中,生成器与判别器的训练存在梯度不稳定、模式崩溃等问题。Wasserstein GAN(WGAN)通过引入Wasserstein距离替代JS散度作为损失函数,有效改善了训练稳定性。本题要求详解Wasserstein距离的定义、WGAN的改进原理及实现细节。


1. 原始GAN的问题根源

核心矛盾:原始GAN的判别器使用JS散度(Jensen-Shannon Divergence)衡量真实分布与生成分布的距离。当两者重叠部分可忽略时(常见于高维空间),JS散度会饱和为常数,导致梯度消失。

举例说明

  • 假设真实分布\(P_r\)和生成分布\(P_g\)是两条平行直线上的均匀分布,二者在二维空间无重叠。
  • 此时KL散度(Kullback-Leibler Divergence)为无穷大,JS散度恒为\(\log 2\),梯度为0,生成器无法更新。

2. Wasserstein距离的定义

Wasserstein距离(又称Earth-Mover距离)描述将概率分布\(P_r\)“搬移”成分布\(P_g\)所需的最小成本:

\[W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x,y) \sim \gamma} [\|x - y\|] \]

  • \(\Pi(P_r, P_g)\)\(P_r\)\(P_g\)的所有联合分布集合;
  • \(\gamma(x,y)\)表示将\(x\)处的质量搬到\(y\)处的运输方案;
  • 直观理解:计算分布形态差异时考虑几何空间中的实际距离,即使分布不重叠也能提供有意义的梯度。

3. WGAN的改进原理

3.1 损失函数重构

WGAN的判别器改为批评器(Critic),直接拟合Wasserstein距离:

\[L = \mathbb{E}_{x \sim P_r} [D(x)] - \mathbb{E}_{z \sim P_z} [D(G(z))] \]

  • 批评器\(D\)的目标是最大化\(L\)(拉大真实样本与生成样本的得分差距);
  • 生成器\(G\)的目标是最小化\(\mathbb{E}_{z \sim P_z} [D(G(z))]\)(让生成样本的得分接近真实样本)。

3.2 Lipschitz约束的实现

Wasserstein距离的计算需满足函数\(D\)的1-Lipschitz连续性(即梯度范数不超过1)。WGAN通过以下方法约束:

  • 权重裁剪(原始WGAN):强制限制批评器参数在区间\([-c, c]\)内,但易导致梯度不稳定或消失。
  • 梯度惩罚(WGAN-GP改进):在损失函数中增加梯度惩罚项:

\[L_{GP} = \lambda \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} [(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2] \]

其中\(\hat{x}\)是真实样本与生成样本连线上的随机插值点。


4. WGAN的训练步骤

  1. 初始化:批评器\(D\)和生成器\(G\)的参数。
  2. 批评器更新(多次迭代):
    • 采样真实数据批次\(\{x_i\} \sim P_r\)和噪声\(\{z_i\} \sim P_z\)
    • 计算生成数据\(G(z_i)\)
    • 计算Wasserstein损失\(L_D = -\left( \frac{1}{m} \sum D(x_i) - \frac{1}{m} \sum D(G(z_i)) \right)\)
    • 若使用WGAN-GP,增加梯度惩罚项;
    • 更新批评器参数。
  3. 生成器更新(单次迭代):
    • 采样噪声\(\{z_i\} \sim P_z\)
    • 计算生成器损失\(L_G = -\frac{1}{m} \sum D(G(z_i))\)
    • 更新生成器参数。
  4. 重复步骤2-3至收敛。

5. WGAN的优势

  • 训练稳定性:Wasserstein距离提供平滑梯度,避免模式崩溃;
  • 损失函数可解释性:损失值下降直接反映生成质量提升;
  • 兼容复杂架构:适用于强化学习、文本生成等场景。

关键注意点:WGAN-GP虽提升性能,但梯度惩罚需额外计算成本,需权衡效率与效果。

生成对抗网络中的Wasserstein距离与WGAN原理 题目描述 在原始生成对抗网络(GAN)中,生成器与判别器的训练存在梯度不稳定、模式崩溃等问题。Wasserstein GAN(WGAN)通过引入Wasserstein距离替代JS散度作为损失函数,有效改善了训练稳定性。本题要求详解Wasserstein距离的定义、WGAN的改进原理及实现细节。 1. 原始GAN的问题根源 核心矛盾 :原始GAN的判别器使用JS散度(Jensen-Shannon Divergence)衡量真实分布与生成分布的距离。当两者重叠部分可忽略时(常见于高维空间),JS散度会饱和为常数,导致梯度消失。 举例说明 : 假设真实分布\( P_ r \)和生成分布\( P_ g \)是两条平行直线上的均匀分布,二者在二维空间无重叠。 此时KL散度(Kullback-Leibler Divergence)为无穷大,JS散度恒为\( \log 2 \),梯度为0,生成器无法更新。 2. Wasserstein距离的定义 Wasserstein距离(又称Earth-Mover距离)描述将概率分布\( P_ r \)“搬移”成分布\( P_ g \)所需的最小成本: \[ W(P_ r, P_ g) = \inf_ {\gamma \in \Pi(P_ r, P_ g)} \mathbb{E}_ {(x,y) \sim \gamma} [ \|x - y\| ] \] \( \Pi(P_ r, P_ g) \)是\( P_ r \)和\( P_ g \)的所有联合分布集合; \( \gamma(x,y) \)表示将\( x \)处的质量搬到\( y \)处的运输方案; 直观理解:计算分布形态差异时考虑几何空间中的实际距离,即使分布不重叠也能提供有意义的梯度。 3. WGAN的改进原理 3.1 损失函数重构 WGAN的判别器改为 批评器 (Critic),直接拟合Wasserstein距离: \[ L = \mathbb{E} {x \sim P_ r} [ D(x)] - \mathbb{E} {z \sim P_ z} [ D(G(z)) ] \] 批评器\( D \)的目标是最大化\( L \)(拉大真实样本与生成样本的得分差距); 生成器\( G \)的目标是最小化\( \mathbb{E}_ {z \sim P_ z} [ D(G(z)) ] \)(让生成样本的得分接近真实样本)。 3.2 Lipschitz约束的实现 Wasserstein距离的计算需满足函数\( D \)的1-Lipschitz连续性(即梯度范数不超过1)。WGAN通过以下方法约束: 权重裁剪 (原始WGAN):强制限制批评器参数在区间\([ -c, c ]\)内,但易导致梯度不稳定或消失。 梯度惩罚 (WGAN-GP改进):在损失函数中增加梯度惩罚项: \[ L_ {GP} = \lambda \mathbb{E} {\hat{x} \sim P {\hat{x}}} [ (\|\nabla_ {\hat{x}} D(\hat{x})\|_ 2 - 1)^2 ] \] 其中\( \hat{x} \)是真实样本与生成样本连线上的随机插值点。 4. WGAN的训练步骤 初始化 :批评器\( D \)和生成器\( G \)的参数。 批评器更新 (多次迭代): 采样真实数据批次\( \{x_ i\} \sim P_ r \)和噪声\( \{z_ i\} \sim P_ z \); 计算生成数据\( G(z_ i) \); 计算Wasserstein损失\( L_ D = -\left( \frac{1}{m} \sum D(x_ i) - \frac{1}{m} \sum D(G(z_ i)) \right) \); 若使用WGAN-GP,增加梯度惩罚项; 更新批评器参数。 生成器更新 (单次迭代): 采样噪声\( \{z_ i\} \sim P_ z \); 计算生成器损失\( L_ G = -\frac{1}{m} \sum D(G(z_ i)) \); 更新生成器参数。 重复步骤2-3至收敛。 5. WGAN的优势 训练稳定性 :Wasserstein距离提供平滑梯度,避免模式崩溃; 损失函数可解释性 :损失值下降直接反映生成质量提升; 兼容复杂架构 :适用于强化学习、文本生成等场景。 关键注意点 :WGAN-GP虽提升性能,但梯度惩罚需额外计算成本,需权衡效率与效果。