变分自编码器(VAE)的损失函数推导与优化过程
题目描述
变分自编码器(VAE)是一种生成模型,结合了神经网络与变分推断,目标是学习输入数据(如图像)的概率分布,并从中生成新样本。VAE的核心挑战是直接建模真实分布不可行,因此通过引入潜在变量(latent variable)并优化证据下界(ELBO)来近似真实分布。本题要求详细推导VAE的损失函数(即ELBO),并解释其优化过程。
解题过程
-
问题建模
- 设输入数据为 \(x\),潜在变量为 \(z\)(通常服从高斯先验 \(p(z) = \mathcal{N}(0, I)\))。
- 目标:学习真实后验分布 \(p(z|x)\),但直接计算困难(因涉及积分 \(p(x) = \int p(x|z)p(z)dz\))。
- VAE引入变分分布 \(q_\phi(z|x)\)(编码器)近似 \(p(z|x)\),并定义解码器 \(p_\theta(x|z)\) 生成数据。
-
证据下界(ELBO)推导
- 从对数似然 \(\log p(x)\) 出发,引入 \(q_\phi(z|x)\):
\[ \log p(x) = \mathbb{E}_{q_\phi(z|x)} \left[ \log \frac{p(x, z)}{q_\phi(z|x)} \right] + D_{\text{KL}}(q_\phi(z|x) \| p(z|x)) \]
- 由于KL散度非负,推导出ELBO:
\[ \log p(x) \geq \mathbb{E}_{q_\phi(z|x)} \left[ \log p_\theta(x|z) \right] - D_{\text{KL}}(q_\phi(z|x) \| p(z)) \]
- 第一项:重构损失(reconstruction loss),衡量解码器重建数据的能力,常用交叉熵或均方误差。
- 第二项:正则化损失(regularization loss),约束 \(q_\phi(z|x)\) 接近先验 \(p(z)\),避免过拟合。
- 重参数化技巧(Reparameterization Trick)
- 问题:直接采样 \(z \sim q_\phi(z|x)\) 导致梯度无法反向传播。
- 解决方案:将采样过程解耦为确定性部分和随机噪声。例如,若 \(q_\phi(z|x) = \mathcal{N}(\mu, \sigma^2)\),则令:
\[ z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \]
- 梯度可通过 \(\mu\) 和 \(\sigma\) 回传,而 \(\epsilon\) 作为随机变量不参与求导。
- 损失函数优化
- 总损失函数为负ELBO:
\[ \mathcal{L}(\theta, \phi) = -\mathbb{E}_{q_\phi(z|x)} \left[ \log p_\theta(x|z) \right] + D_{\text{KL}}(q_\phi(z|x) \| p(z)) \]
- 重构项:通过蒙特卡洛采样估计,对每个样本 \(x_i\),采样 \(z \sim q_\phi(z|x_i)\) 计算 \(\log p_\theta(x_i|z)\)。
- KL散度项:若 \(p(z) = \mathcal{N}(0, I)\),\(q_\phi(z|x) = \mathcal{N}(\mu, \sigma^2)\),则有闭式解:
\[ D_{\text{KL}} = -\frac{1}{2} \sum_{j=1}^J (1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2) \]
其中 $ J $ 为潜在空间维度。
-
训练流程
- 步骤1:编码器输入 \(x\),输出 \(\mu\) 和 \(\log \sigma^2\)(保证方差非负)。
- 步骤2:重参数化采样 \(z = \mu + \epsilon \odot \exp(\log \sigma^2/2)\)。
- 步骤3:解码器将 \(z\) 映射为重建数据 \(\hat{x}\)。
- 步骤4:计算重构损失(如二元交叉熵)和KL散度,反向传播更新参数 \(\theta\) 和 \(\phi\)。
-
关键设计细节
- 平衡重构与正则化:KL项可能过早趋近零("后验坍塌"),可通过加权KL项(如 \(\beta\)-VAE)控制。
- 潜在空间维度:维度过低限制生成能力,过高导致训练不稳定。
- 解码器分布假设:连续数据用高斯分布,离散数据用伯努利分布。
总结
VAE通过变分推断将生成问题转化为ELBO优化,结合重参数化技巧实现端到端训练。其损失函数直接反映了生成质量与潜在空间规整性的权衡,是理解概率生成模型的基础。