基于变分推断(Variational Inference, VI)的隐变量模型参数估计算法详解
1. 问题背景与目标
在许多自然语言处理任务中,我们常常会使用含有隐变量的概率图模型,例如主题模型(LDA)、变分自编码器(VAE) 等。这类模型可以表示为观测数据 \(\mathbf{x}\) 和隐变量 \(\mathbf{z}\) 的联合分布 \(p_{\theta}(\mathbf{x}, \mathbf{z})\),其中 \(\theta\) 是模型的参数。
我们的目标通常有两个:
- 学习模型参数 \(\theta\),使其能最好地解释我们观测到的数据。
- 推断隐变量 \(\mathbf{z}\) 的后验分布 \(p_{\theta}(\mathbf{z} | \mathbf{x})\),例如,给定一篇文档,推断其主题分布。
模型的证据(Evidence)或对数似然是 \(p_{\theta}(\mathbf{x}) = \int p_{\theta}(\mathbf{x}, \mathbf{z}) d\mathbf{z}\)。然而,对于许多有趣的模型,这个积分是难以直接计算的,因为隐变量空间可能很复杂,导致后验分布 \(p_{\theta}(\mathbf{z} | \mathbf{x}) = p_{\theta}(\mathbf{x}, \mathbf{z}) / p_{\theta}(\mathbf{x})\) 也无法直接求解(分母就是难以计算的 \(p_{\theta}(\mathbf{x})\))。
核心挑战:如何高效地学习模型参数 \(\theta\) 并近似隐变量的后验分布?变分推断就是解决这个问题的关键框架。
2. 变分推断的核心思想
变分推断的核心理念是将推断问题转化为一个优化问题。
-
用“简单”分布近似“复杂”分布:我们引入一个由参数 \(\phi\) 定义的、形式简单的分布族 \(q_{\phi}(\mathbf{z})\),称为变分分布。我们的目标是从这个分布族中,寻找一个与真实后验分布 \(p_{\theta}(\mathbf{z} | \mathbf{x})\) 最接近的成员 \(q_{\phi^*}(\mathbf{z})\) 来近似它。
-
“接近”的度量:我们用Kullback-Leibler散度来衡量两个分布 \(q\) 和 \(p\) 之间的“距离”。
\[KL(q_{\phi}(\mathbf{z}) || p_{\theta}(\mathbf{z} | \mathbf{x})) = \mathbb{E}_{q_{\phi}(\mathbf{z})} \left[ \log \frac{q_{\phi}(\mathbf{z})}{p_{\theta}(\mathbf{z} | \mathbf{x})} \right] \]
KL散度越小,说明 $q_{\phi}(\mathbf{z})$ 与 $p_{\theta}(\mathbf{z} | \mathbf{x})$ 越接近。
- 优化目标的构建:
直接最小化 \(KL(q || p)\) 是不可能的,因为它依赖于难以计算的后验 \(p_{\theta}(\mathbf{z} | \mathbf{x})\)。我们对目标进行数学变换:
\[ \begin{aligned} \log p_{\theta}(\mathbf{x}) &= \log \int p_{\theta}(\mathbf{x}, \mathbf{z}) d\mathbf{z} \\ &= \log \int q_{\phi}(\mathbf{z}) \frac{p_{\theta}(\mathbf{x}, \mathbf{z})}{q_{\phi}(\mathbf{z})} d\mathbf{z} \\ &\ge \int q_{\phi}(\mathbf{z}) \log \frac{p_{\theta}(\mathbf{x}, \mathbf{z})}{q_{\phi}(\mathbf{z})} d\mathbf{z} \quad \text{(根据Jensen不等式)} \\ &= \mathbb{E}_{q_{\phi}(\mathbf{z})} [\log p_{\theta}(\mathbf{x}, \mathbf{z})] - \mathbb{E}_{q_{\phi}(\mathbf{z})} [\log q_{\phi}(\mathbf{z})] \\ &\triangleq \mathcal{L}(\theta, \phi; \mathbf{x}) \end{aligned} \]
这里我们推导出了一个**证据下界**,记作 $\mathcal{L}(\theta, \phi; \mathbf{x})$,也称为**ELBO**。
进一步推导可以发现:
\[ \log p_{\theta}(\mathbf{x}) = \mathcal{L}(\theta, \phi; \mathbf{x}) + KL(q_{\phi}(\mathbf{z}) || p_{\theta}(\mathbf{z} | \mathbf{x})) \]
这个等式非常优美地揭示了变分推断的本质:
* $\log p_{\theta}(\mathbf{x})$ 是我们想最大化的**模型证据**(常数)。
* $\mathcal{L}$ 是我们构建的、可以优化的**证据下界**。
* $KL(q || p)$ 是变分分布与真实后验的差异。
**结论**:**最大化ELBO $\mathcal{L}$ 等价于同时最小化KL散度 $KL(q || p)$**。通过最大化一个可计算的下界 $\mathcal{L}$,我们同时完成了两件事:1) 让 $q_{\phi}$ 逼近真实后验;2) 让模型对数似然 $\log p_{\theta}(\mathbf{x})$ 尽可能大(从而学习到更好的参数 $\theta$)。
3. 算法的详细步骤
一个标准的变分推断EM算法步骤如下:
步骤1:初始化
- 随机初始化模型参数 \(\theta^{(0)}\) 和变分参数 \(\phi^{(0)}\)。
步骤2:E步(变分E步)
- 固定模型参数 \(\theta\) 不变。
- 通过最大化ELBO \(\mathcal{L}(\theta, \phi; \mathbf{x})\) 来更新变分参数 \(\phi\)。这一步的目的是改进变分分布 \(q_{\phi}\),使其更接近当前模型下的真实后验。
- 最大化方法取决于 \(q\) 的选择。如果 \(q\) 属于指数族,并且满足“平均场”假设,我们可能得到闭式更新公式。更通用的方法是使用随机梯度下降。
步骤3:M步
- 固定变分分布 \(q_{\phi}(\mathbf{z})\) 不变。
- 通过最大化ELBO \(\mathcal{L}(\theta, \phi; \mathbf{x})\) 来更新模型参数 \(\theta\)。这一步的目的是改进模型本身,使其能更好地解释观测数据。
- 目标函数是:\(\theta^{(new)} = \arg\max_{\theta} \mathbb{E}_{q_{\phi}(\mathbf{z})}[\log p_{\theta}(\mathbf{x}, \mathbf{z})]\)。这可以看作是在变分分布 \(q_{\phi}\) 下,关于联合概率 \(\log p_{\theta}(\mathbf{x}, \mathbf{z})\) 的期望最大化。
步骤4:迭代
- 重复步骤2和步骤3,直到ELBO \(\mathcal{L}\) 的值收敛(即变化小于某个阈值)。
与经典EM算法的联系:
- 在经典EM中,E步是精确计算后验 \(p_{\theta}(\mathbf{z} | \mathbf{x})\),M步是最大化完整数据的期望似然。
- 在变分推断中,E步是近似计算后验(通过优化 \(q_{\phi}\) 逼近 \(p_{\theta}(\mathbf{z} | \mathbf{x})\)),因此也称为“变分E步”。当 \(q_{\phi}\) 的分布族能完美拟合真实后验时,变分EM就退化为经典EM。
4. 关键技术:黑盒变分推断与重参数化技巧
上述算法在实际中面临一个主要挑战:如何计算ELBO \(\mathcal{L}\) 及其关于 \(\phi\) 的梯度?\(\mathcal{L}\) 中包含对 \(q_{\phi}(\mathbf{z})\) 的期望,而 \(q_{\phi}\) 的参数 \(\phi\) 存在于期望的概率分布内部,直接求梯度非常困难。
解决方案:
- 蒙特卡洛估计:用采样的方法来近似计算期望。从 \(q_{\phi}(\mathbf{z})\) 中采样 \(L\) 个样本 \(\mathbf{z}^{(l)} \sim q_{\phi}(\mathbf{z})\),则ELBO的估计为:
\[\mathcal{L}(\theta, \phi; \mathbf{x}) \approx \frac{1}{L} \sum_{l=1}^{L} [\log p_{\theta}(\mathbf{x}, \mathbf{z}^{(l)}) - \log q_{\phi}(\mathbf{z}^{(l)})] \]
- 重参数化技巧:为了计算 \(\nabla_{\phi} \mathcal{L}\),关键在于如何使梯度信号通过随机采样操作 \(\mathbf{z} \sim q_{\phi}(\mathbf{z})\) 反向传播。技巧是将采样过程参数化分离。
- 假设我们可以将 \(\mathbf{z}\) 表示为一个确定性函数 \(\mathbf{z} = g_{\phi}(\epsilon, \mathbf{x})\),其中 \(\epsilon\) 是来自一个固定简单分布(如标准正态分布)的随机噪声。
- 例如,如果 \(q_{\phi}(z) = \mathcal{N}(\mu, \sigma^2)\),我们可以令 \(z = \mu + \sigma \cdot \epsilon\),其中 \(\epsilon \sim \mathcal{N}(0, 1)\)。
- 这样,ELBO的梯度可以重写为:
\[\nabla_{\phi} \mathcal{L} \approx \frac{1}{L} \sum_{l=1}^{L} \nabla_{\phi} [\log p_{\theta}(\mathbf{x}, g_{\phi}(\epsilon^{(l)}, \mathbf{x})) - \log q_{\phi}(g_{\phi}(\epsilon^{(l)}, \mathbf{x}) | \mathbf{x})] \]
现在,梯度可以直接通过确定性的函数 $g_{\phi}$ 进行反向传播了。
黑盒变分推断:结合蒙特卡洛估计和重参数化技巧,我们发展出了一种通用算法。它不依赖于模型 \(p_{\theta}\) 和变分分布 \(q_{\phi}\) 的具体形式,只需能计算联合概率 \(\log p_{\theta}(\mathbf{x}, \mathbf{z})\) 并从 \(q_{\phi}\) 中采样即可。这使得我们可以利用自动微分框架(如PyTorch, TensorFlow)来同时优化 \(\theta\) 和 \(\phi\),无需单独执行E步和M步,而是用随机梯度下降联合优化它们。
5. 总结与应用
总结:变分推断通过引入一个参数化的变分分布 \(q_{\phi}\) 来近似复杂的真实后验 \(p_{\theta}(\mathbf{z} | \mathbf{x})\),并将逼近问题转化为最大化证据下界 \(\mathcal{L}\) 的优化问题。黑盒变分推断结合蒙特卡洛采样和重参数化技巧,使其成为一种强大且通用的近似推断工具。
在NLP中的应用:
- 主题模型:LDA的核心推导就是基于变分推断。其中,变分分布 \(q\) 通常取为平均场形式,并对模型参数 \(\theta\) (文档-主题分布) 和 \(\phi\) (主题-词语分布) 进行迭代更新。
- 变分自编码器:这是深度学习与变分推断结合的典范。VAE的编码器输出变分分布的参数 \(\phi\),解码器对应生成模型 \(p_{\theta}\),其训练目标就是最大化ELBO。
- 深度生成模型:在文本生成任务中,许多先进的模型,如深度隐变量语言模型,都依赖变分推断来学习有意义的隐表示和生成高质量的文本。