基于变分推断(Variational Inference, VI)的隐变量模型参数估计算法详解
字数 5241
更新时间 2025-12-14 01:20:56

基于变分推断(Variational Inference, VI)的隐变量模型参数估计算法详解

1. 问题背景与目标

在许多自然语言处理任务中,我们常常会使用含有隐变量的概率图模型,例如主题模型(LDA)变分自编码器(VAE) 等。这类模型可以表示为观测数据 \(\mathbf{x}\) 和隐变量 \(\mathbf{z}\) 的联合分布 \(p_{\theta}(\mathbf{x}, \mathbf{z})\),其中 \(\theta\) 是模型的参数。

我们的目标通常有两个:

  • 学习模型参数 \(\theta\),使其能最好地解释我们观测到的数据。
  • 推断隐变量 \(\mathbf{z}\) 的后验分布 \(p_{\theta}(\mathbf{z} | \mathbf{x})\),例如,给定一篇文档,推断其主题分布。

模型的证据(Evidence)或对数似然是 \(p_{\theta}(\mathbf{x}) = \int p_{\theta}(\mathbf{x}, \mathbf{z}) d\mathbf{z}\)。然而,对于许多有趣的模型,这个积分是难以直接计算的,因为隐变量空间可能很复杂,导致后验分布 \(p_{\theta}(\mathbf{z} | \mathbf{x}) = p_{\theta}(\mathbf{x}, \mathbf{z}) / p_{\theta}(\mathbf{x})\) 也无法直接求解(分母就是难以计算的 \(p_{\theta}(\mathbf{x})\))。

核心挑战:如何高效地学习模型参数 \(\theta\) 并近似隐变量的后验分布?变分推断就是解决这个问题的关键框架。

2. 变分推断的核心思想

变分推断的核心理念是将推断问题转化为一个优化问题

  1. 用“简单”分布近似“复杂”分布:我们引入一个由参数 \(\phi\) 定义的、形式简单的分布族 \(q_{\phi}(\mathbf{z})\),称为变分分布。我们的目标是从这个分布族中,寻找一个与真实后验分布 \(p_{\theta}(\mathbf{z} | \mathbf{x})\) 最接近的成员 \(q_{\phi^*}(\mathbf{z})\) 来近似它。

  2. “接近”的度量:我们用Kullback-Leibler散度来衡量两个分布 \(q\)\(p\) 之间的“距离”。

\[KL(q_{\phi}(\mathbf{z}) || p_{\theta}(\mathbf{z} | \mathbf{x})) = \mathbb{E}_{q_{\phi}(\mathbf{z})} \left[ \log \frac{q_{\phi}(\mathbf{z})}{p_{\theta}(\mathbf{z} | \mathbf{x})} \right] \]

KL散度越小,说明 $q_{\phi}(\mathbf{z})$ 与 $p_{\theta}(\mathbf{z} | \mathbf{x})$ 越接近。
  1. 优化目标的构建
    直接最小化 \(KL(q || p)\) 是不可能的,因为它依赖于难以计算的后验 \(p_{\theta}(\mathbf{z} | \mathbf{x})\)。我们对目标进行数学变换:

\[ \begin{aligned} \log p_{\theta}(\mathbf{x}) &= \log \int p_{\theta}(\mathbf{x}, \mathbf{z}) d\mathbf{z} \\ &= \log \int q_{\phi}(\mathbf{z}) \frac{p_{\theta}(\mathbf{x}, \mathbf{z})}{q_{\phi}(\mathbf{z})} d\mathbf{z} \\ &\ge \int q_{\phi}(\mathbf{z}) \log \frac{p_{\theta}(\mathbf{x}, \mathbf{z})}{q_{\phi}(\mathbf{z})} d\mathbf{z} \quad \text{(根据Jensen不等式)} \\ &= \mathbb{E}_{q_{\phi}(\mathbf{z})} [\log p_{\theta}(\mathbf{x}, \mathbf{z})] - \mathbb{E}_{q_{\phi}(\mathbf{z})} [\log q_{\phi}(\mathbf{z})] \\ &\triangleq \mathcal{L}(\theta, \phi; \mathbf{x}) \end{aligned} \]

这里我们推导出了一个**证据下界**,记作 $\mathcal{L}(\theta, \phi; \mathbf{x})$,也称为**ELBO**。

进一步推导可以发现:

\[ \log p_{\theta}(\mathbf{x}) = \mathcal{L}(\theta, \phi; \mathbf{x}) + KL(q_{\phi}(\mathbf{z}) || p_{\theta}(\mathbf{z} | \mathbf{x})) \]

这个等式非常优美地揭示了变分推断的本质:
*   $\log p_{\theta}(\mathbf{x})$ 是我们想最大化的**模型证据**(常数)。
*   $\mathcal{L}$ 是我们构建的、可以优化的**证据下界**。
*   $KL(q || p)$ 是变分分布与真实后验的差异。

**结论**:**最大化ELBO $\mathcal{L}$ 等价于同时最小化KL散度 $KL(q || p)$**。通过最大化一个可计算的下界 $\mathcal{L}$,我们同时完成了两件事:1) 让 $q_{\phi}$ 逼近真实后验;2) 让模型对数似然 $\log p_{\theta}(\mathbf{x})$ 尽可能大(从而学习到更好的参数 $\theta$)。

3. 算法的详细步骤

一个标准的变分推断EM算法步骤如下:

步骤1:初始化

  • 随机初始化模型参数 \(\theta^{(0)}\) 和变分参数 \(\phi^{(0)}\)

步骤2:E步(变分E步)

  • 固定模型参数 \(\theta\) 不变。
  • 通过最大化ELBO \(\mathcal{L}(\theta, \phi; \mathbf{x})\) 来更新变分参数 \(\phi\)。这一步的目的是改进变分分布 \(q_{\phi}\),使其更接近当前模型下的真实后验
  • 最大化方法取决于 \(q\) 的选择。如果 \(q\) 属于指数族,并且满足“平均场”假设,我们可能得到闭式更新公式。更通用的方法是使用随机梯度下降

步骤3:M步

  • 固定变分分布 \(q_{\phi}(\mathbf{z})\) 不变。
  • 通过最大化ELBO \(\mathcal{L}(\theta, \phi; \mathbf{x})\) 来更新模型参数 \(\theta\)。这一步的目的是改进模型本身,使其能更好地解释观测数据
  • 目标函数是:\(\theta^{(new)} = \arg\max_{\theta} \mathbb{E}_{q_{\phi}(\mathbf{z})}[\log p_{\theta}(\mathbf{x}, \mathbf{z})]\)。这可以看作是在变分分布 \(q_{\phi}\) 下,关于联合概率 \(\log p_{\theta}(\mathbf{x}, \mathbf{z})\) 的期望最大化。

步骤4:迭代

  • 重复步骤2和步骤3,直到ELBO \(\mathcal{L}\) 的值收敛(即变化小于某个阈值)。

与经典EM算法的联系

  • 在经典EM中,E步是精确计算后验 \(p_{\theta}(\mathbf{z} | \mathbf{x})\),M步是最大化完整数据的期望似然。
  • 在变分推断中,E步是近似计算后验(通过优化 \(q_{\phi}\) 逼近 \(p_{\theta}(\mathbf{z} | \mathbf{x})\)),因此也称为“变分E步”。当 \(q_{\phi}\) 的分布族能完美拟合真实后验时,变分EM就退化为经典EM。

4. 关键技术:黑盒变分推断与重参数化技巧

上述算法在实际中面临一个主要挑战:如何计算ELBO \(\mathcal{L}\) 及其关于 \(\phi\) 的梯度?\(\mathcal{L}\) 中包含对 \(q_{\phi}(\mathbf{z})\) 的期望,而 \(q_{\phi}\) 的参数 \(\phi\) 存在于期望的概率分布内部,直接求梯度非常困难。

解决方案

  1. 蒙特卡洛估计:用采样的方法来近似计算期望。从 \(q_{\phi}(\mathbf{z})\) 中采样 \(L\) 个样本 \(\mathbf{z}^{(l)} \sim q_{\phi}(\mathbf{z})\),则ELBO的估计为:

\[\mathcal{L}(\theta, \phi; \mathbf{x}) \approx \frac{1}{L} \sum_{l=1}^{L} [\log p_{\theta}(\mathbf{x}, \mathbf{z}^{(l)}) - \log q_{\phi}(\mathbf{z}^{(l)})] \]

  1. 重参数化技巧:为了计算 \(\nabla_{\phi} \mathcal{L}\),关键在于如何使梯度信号通过随机采样操作 \(\mathbf{z} \sim q_{\phi}(\mathbf{z})\) 反向传播。技巧是将采样过程参数化分离
    • 假设我们可以将 \(\mathbf{z}\) 表示为一个确定性函数 \(\mathbf{z} = g_{\phi}(\epsilon, \mathbf{x})\),其中 \(\epsilon\) 是来自一个固定简单分布(如标准正态分布)的随机噪声。
    • 例如,如果 \(q_{\phi}(z) = \mathcal{N}(\mu, \sigma^2)\),我们可以令 \(z = \mu + \sigma \cdot \epsilon\),其中 \(\epsilon \sim \mathcal{N}(0, 1)\)
    • 这样,ELBO的梯度可以重写为:

\[\nabla_{\phi} \mathcal{L} \approx \frac{1}{L} \sum_{l=1}^{L} \nabla_{\phi} [\log p_{\theta}(\mathbf{x}, g_{\phi}(\epsilon^{(l)}, \mathbf{x})) - \log q_{\phi}(g_{\phi}(\epsilon^{(l)}, \mathbf{x}) | \mathbf{x})] \]

   现在,梯度可以直接通过确定性的函数 $g_{\phi}$ 进行反向传播了。

黑盒变分推断:结合蒙特卡洛估计重参数化技巧,我们发展出了一种通用算法。它不依赖于模型 \(p_{\theta}\) 和变分分布 \(q_{\phi}\) 的具体形式,只需能计算联合概率 \(\log p_{\theta}(\mathbf{x}, \mathbf{z})\) 并从 \(q_{\phi}\) 中采样即可。这使得我们可以利用自动微分框架(如PyTorch, TensorFlow)来同时优化 \(\theta\)\(\phi\),无需单独执行E步和M步,而是用随机梯度下降联合优化它们。

5. 总结与应用

总结:变分推断通过引入一个参数化的变分分布 \(q_{\phi}\) 来近似复杂的真实后验 \(p_{\theta}(\mathbf{z} | \mathbf{x})\),并将逼近问题转化为最大化证据下界 \(\mathcal{L}\) 的优化问题。黑盒变分推断结合蒙特卡洛采样和重参数化技巧,使其成为一种强大且通用的近似推断工具。

在NLP中的应用

  1. 主题模型LDA的核心推导就是基于变分推断。其中,变分分布 \(q\) 通常取为平均场形式,并对模型参数 \(\theta\) (文档-主题分布) 和 \(\phi\) (主题-词语分布) 进行迭代更新。
  2. 变分自编码器:这是深度学习与变分推断结合的典范。VAE的编码器输出变分分布的参数 \(\phi\),解码器对应生成模型 \(p_{\theta}\),其训练目标就是最大化ELBO。
  3. 深度生成模型:在文本生成任务中,许多先进的模型,如深度隐变量语言模型,都依赖变分推断来学习有意义的隐表示和生成高质量的文本。
相似文章
相似文章
 全屏