变分自编码器(Variational Autoencoder, VAE)中重参数化技巧(Reparameterization Trick)的推导与作用
题目描述
在变分自编码器(VAE)中,我们的目标是学习一个概率生成模型,使得模型能从潜在空间 \(\mathbf{z}\) 中采样并生成新的数据 \(\mathbf{x}\)。在训练过程中,我们需要对潜在变量 \(\mathbf{z}\) 进行采样,以计算证据下界(ELBO)损失函数的梯度。然而,直接从由编码器网络输出的参数(如均值和方差)定义的后验分布 \(q_{\phi}(\mathbf{z} | \mathbf{x})\) 中采样,是一个随机过程,这使得梯度无法通过采样操作直接回传到编码器网络的参数 \(\phi\) 上。为了解决这个梯度不可传的问题,VAE 引入了重参数化技巧(Reparameterization Trick)。本题目要求详细解释重参数化技巧的动机、数学推导过程,以及它如何使得 VAE 能够通过标准的反向传播算法进行端到端的训练。
解题过程
我们将循序渐进地拆解这个问题,从问题的根源开始,到技巧的推导,最后阐明其在整个训练流程中的作用。
第一步:理解问题——为什么直接采样会导致梯度不可传?
- VAE 的编码器输出:对于输入数据点 \(\mathbf{x}_i\),VAE 的编码器(推断网络)输出的是潜在变量 \(\mathbf{z}\) 的后验分布 \(q_{\phi}(\mathbf{z} | \mathbf{x}_i)\) 的参数。通常,我们假设这个后验分布是各向同性的高斯分布:
\[ q_{\phi}(\mathbf{z} | \mathbf{x}_i) = \mathcal{N}(\mathbf{z}; \boldsymbol{\mu}_i, \boldsymbol{\sigma}_i^2 \mathbf{I}) \]
其中,$ \boldsymbol{\mu}_i = \text{Encoder}_{\boldsymbol{\mu}}(\mathbf{x}_i; \phi) $ 和 $ \log \boldsymbol{\sigma}_i^2 = \text{Encoder}_{\boldsymbol{\sigma}}(\mathbf{x}_i; \phi) $ 是编码器网络输出的向量。
- 采样操作:为了计算 ELBO 损失(包含重构项和 KL 散度项),我们需要从分布 \(q_{\phi}(\mathbf{z} | \mathbf{x}_i)\) 中采样一个具体的潜在变量 \(\mathbf{z}^{(i,l)}\)(其中 \(l\) 表示蒙特卡洛采样的索引)。直接采样可以表示为:
\[ \mathbf{z}^{(i,l)} \sim \mathcal{N}(\boldsymbol{\mu}_i, \boldsymbol{\sigma}_i^2 \mathbf{I}) \]
这个采样操作可以看作:
\[ \mathbf{z}^{(i,l)} = \boldsymbol{\mu}_i + \boldsymbol{\sigma}_i \odot \boldsymbol{\epsilon}^{(l)}, \quad \boldsymbol{\epsilon}^{(l)} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \]
其中 $ \odot $ 表示逐元素相乘。
- 梯度阻断:在上面的直接采样表达式中,采样过程依赖于分布参数 \(\boldsymbol{\mu}_i, \boldsymbol{\sigma}_i\),而 \(\boldsymbol{\mu}_i, \boldsymbol{\sigma}_i\) 又是编码器参数 \(\phi\) 的函数。当我们尝试计算损失函数 \(\mathcal{L}\) 关于 \(\phi\) 的梯度 \(\nabla_{\phi} \mathcal{L}\) 时,由于 \(\mathbf{z}^{(i,l)}\) 是通过一个从高斯分布中采样的随机过程得到的,这个随机性使得梯度 \(\nabla_{\phi} \mathbf{z}^{(i,l)}\) 无法定义或计算(因为采样操作本身不可微)。在计算图中,随机采样节点阻断了梯度从 \(\mathbf{z}\) 流向 \(\boldsymbol{\mu}_i, \boldsymbol{\sigma}_i\) 再到 \(\phi\) 的路径。
第二步:引入重参数化技巧——将随机性分离出去
- 核心思想:重参数化技巧的关键在于,将随机采样操作从一个依赖于参数 \(\phi\) 的过程中剥离出来,转化为一个固定的、与参数无关的分布采样,加上一个确定性的、可微的参数变换。
- 数学表达:我们引入一个辅助的随机变量 \(\boldsymbol{\epsilon}\),它服从一个简单的、固定的分布(通常为标准正态分布 \(\mathcal{N}(\mathbf{0}, \mathbf{I})\))。然后,将我们需要采样的变量 \(\mathbf{z}\) 表示为 \(\boldsymbol{\epsilon}\) 的一个确定性函数,而该函数的参数正是我们想要学习的目标。
对于高斯后验 \(q_{\phi}(\mathbf{z} | \mathbf{x}) = \mathcal{N}(\mathbf{z}; \boldsymbol{\mu}, \boldsymbol{\sigma}^2 \mathbf{I})\),重参数化形式为:
\[ \mathbf{z} = g_{\phi}(\boldsymbol{\epsilon}, \mathbf{x}) = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}, \quad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \]
这里,$ \boldsymbol{\mu} = \text{Encoder}_{\boldsymbol{\mu}}(\mathbf{x}; \phi) $,$ \log \boldsymbol{\sigma}^2 = \text{Encoder}_{\boldsymbol{\sigma}}(\mathbf{x}; \phi) $。
- 推导验证:我们需要验证,通过上述变换得到的 \(\mathbf{z}\) 确实服从分布 \(\mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\sigma}^2 \mathbf{I})\)。
- \(\boldsymbol{\epsilon}\) 的均值为 \(\mathbf{0}\),方差为 \(\mathbf{I}\)。
- 线性变换的性质:对于一个随机变量 \(\mathbf{y} = A\mathbf{x} + b\),如果 \(E[\mathbf{x}] = \mathbf{m}\), \(\text{Cov}[\mathbf{x}] = \Sigma\),那么 \(E[\mathbf{y}] = A\mathbf{m} + b\), \(\text{Cov}[\mathbf{y}] = A \Sigma A^T\)。
- 在 \(\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}\) 中,\(A = \text{diag}(\boldsymbol{\sigma})\),\(b = \boldsymbol{\mu}\)。
- 因此,\(E[\mathbf{z}] = \boldsymbol{\mu} + \text{diag}(\boldsymbol{\sigma}) \cdot \mathbf{0} = \boldsymbol{\mu}\)。
- \(\text{Cov}[\mathbf{z}] = \text{diag}(\boldsymbol{\sigma}) \cdot \mathbf{I} \cdot \text{diag}(\boldsymbol{\sigma})^T = \text{diag}(\boldsymbol{\sigma}^2)\)。
- 所以 \(\mathbf{z} \sim \mathcal{N}(\boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2)) = \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\sigma}^2 \mathbf{I})\)。验证成功。
第三步:分析技巧如何解决梯度问题
- 新的计算图:现在,采样过程被改写为:
\[ \boldsymbol{\epsilon} \sim p(\boldsymbol{\epsilon}) = \mathcal{N}(\mathbf{0}, \mathbf{I}) \quad (\text{随机性来源}) \]
\[ \mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon} \quad (\text{确定性变换}) \]
其中,$ \boldsymbol{\mu} = f_{\mu}(\mathbf{x}; \phi) $,$ \boldsymbol{\sigma} = f_{\sigma}(\mathbf{x}; \phi) $。
- 梯度的流动:
- 随机性被隔离在 \(\boldsymbol{\epsilon}\) 上。在反向传播时,我们将 \(\boldsymbol{\epsilon}\) 视为一个常量(一个从标准正态分布中采样的具体数值)。虽然它是随机的,但在一次前向传播中,它是一个固定的输入值。
- 损失函数 \(\mathcal{L}\) 关于 \(\mathbf{z}\) 的梯度 \(\nabla_{\mathbf{z}} \mathcal{L}\) 可以被正常计算(因为解码器的前向传播和损失计算是可微的)。
- 由于 \(\mathbf{z}\) 是 \(\boldsymbol{\mu}\) 和 \(\boldsymbol{\sigma}\) 的确定性、可微函数(一个简单的加法和平移),我们可以通过链式法则计算梯度:
\[ \frac{\partial \mathcal{L}}{\partial \boldsymbol{\mu}} = \frac{\partial \mathcal{L}}{\partial \mathbf{z}} \cdot \frac{\partial \mathbf{z}}{\partial \boldsymbol{\mu}} = \frac{\partial \mathcal{L}}{\partial \mathbf{z}} \]
\[ \frac{\partial \mathcal{L}}{\partial \boldsymbol{\sigma}} = \frac{\partial \mathcal{L}}{\partial \mathbf{z}} \cdot \frac{\partial \mathbf{z}}{\partial \boldsymbol{\sigma}} = \frac{\partial \mathcal{L}}{\partial \mathbf{z}} \odot \boldsymbol{\epsilon} \]
- 然后,梯度 $ \nabla_{\boldsymbol{\mu}} \mathcal{L} $ 和 $ \nabla_{\boldsymbol{\sigma}} \mathcal{L} $ 可以继续通过编码器网络 $ f_{\mu} $ 和 $ f_{\sigma} $ 反向传播到其参数 $ \phi $ 上,因为从 $ \mathbf{x} $ 到 $ \boldsymbol{\mu}, \boldsymbol{\sigma} $ 的映射(即编码器网络)是完全可微的。
- 蒙特卡洛估计:在实际训练中,ELBO 中的期望项 \(\mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\cdot]\) 通过蒙特卡洛采样来近似。使用重参数化技巧后,这个估计变为:
\[ \mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})}[f(\mathbf{z})] \approx \frac{1}{L} \sum_{l=1}^{L} f\left( \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}^{(l)} \right), \quad \boldsymbol{\epsilon}^{(l)} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \]
这个估计量是**关于参数 $ \phi $ 可微**的。
第四步:总结重参数化技巧的作用与优势
- 核心作用:它将一个从参数化分布中采样的随机过程,转换为一个从固定简单分布中采样,再经过参数化确定性变换的过程。这使得梯度可以畅通无阻地通过这个确定性变换,从而实现了对随机生成模型(如 VAE 的编码器部分)的端到端梯度优化。
- 降低方差:相比于其他可能的方法(如得分函数估计器/REINFORCE),重参数化技巧通常能提供更低方差的梯度估计,这使得训练更加稳定和高效。
- 通用性:虽然我们以高斯分布为例,但重参数化技巧可以推广到任何能表示为固定噪声源的可微变换的分布。例如,对于服从 Logistic 分布、Gumbel 分布、Gamma 分布(通过形状-尺度参数化)的变量,都有相应的重参数化方法。
- 在 VAE 训练中的位置:
- 前向传播:输入 \(\mathbf{x}\) → 编码器输出 \(\boldsymbol{\mu}, \boldsymbol{\sigma}\) → 采样 \(\boldsymbol{\epsilon}\) → 计算 \(\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}\) → 解码器输出 \(\mathbf{x}'\) → 计算重构损失和 KL 散度损失(ELBO)。
- 反向传播:损失梯度流经解码器到 \(\mathbf{z}\) → 通过 \(\frac{\partial \mathbf{z}}{\partial \boldsymbol{\mu}}, \frac{\partial \mathbf{z}}{\partial \boldsymbol{\sigma}}\) 流到编码器输出 → 流经编码器网络更新参数 \(\phi\)。同时,编码器输出 \(\boldsymbol{\mu}, \boldsymbol{\sigma}\) 也直接用于计算 KL 散度的解析梯度。
结论
重参数化技巧是变分自编码器能够成功训练的关键。它通过一个巧妙的数学变换,将不可微的采样操作转化为可微的计算,使得标准基于梯度的优化算法(如 SGD、Adam)能够应用于包含潜在变量采分的概率模型。这一技巧不仅限于 VAE,也广泛用于其他需要学习概率分布参数的生成模型和变分推断场景中。