隐马尔可夫模型(Hidden Markov Model, HMM)中前向-后向算法(Forward-Backward Algorithm)的联合概率计算过程
问题描述
给定一个隐马尔可夫模型(HMM),它包括:
- 隐藏状态集合 \(S = \{s_1, s_2, ..., s_N\}\)。
- 观测符号集合 \(V = \{v_1, v_2, ..., v_M\}\)。
- 状态转移概率矩阵 \(A\),其中 \(a_{ij} = P(q_{t+1}=s_j | q_t = s_i)\)。
- 观测概率矩阵 \(B\),其中 \(b_j(k) = P(o_t = v_k | q_t = s_j)\)。
- 初始状态概率分布 \(\pi\),其中 \(\pi_i = P(q_1 = s_i)\)。
以及一个观测序列 \(O = (o_1, o_2, ..., o_T)\),其中每个 \(o_t \in V\)。
目标:高效计算观测序列 \(O\) 在给定模型参数 \(\lambda = (A, B, \pi)\) 下的联合概率 \(P(O, q_t = s_i | \lambda)\),即观测序列与在时刻 \(t\) 处于特定状态 \(s_i\) 的联合概率。这个联合概率是HMM许多核心算法(如Baum-Welch参数学习、序列解码)的关键中间量。
解题过程详解
1. 理解问题与动机
在HMM中,直接计算 \(P(O, q_t = s_i | \lambda)\) 需要对所有可能的状态序列进行求和:
\[P(O, q_t = s_i | \lambda) = \sum_{所有经过状态 s_i 在时刻 t 的路径} P(O, Q | \lambda) \]
这里 \(Q = (q_1, q_2, ..., q_T)\) 是一个隐藏状态序列。直接计算的复杂度是 \(O(N^T)\),是指数级的,不可行。前向-后向算法通过动态规划思想,将复杂度降低到 \(O(N^2 T)\),是一种高效计算方法。
计算该联合概率的目的:
- 平滑:在已知完整观测序列 \(O\) 的条件下,估计中间时刻 \(t\) 处于状态 \(s_i\) 的概率 \(P(q_t = s_i | O, \lambda)\)(即后验状态概率),这需要用到 \(P(O, q_t = s_i | \lambda)\)。
- 参数学习:Baum-Welch算法(EM算法在HMM中的实现)利用该联合概率来计算状态转移和观测发射的期望计数,从而迭代更新模型参数 \(\lambda\)。
2. 关键思想:分解为前向与后向两部分
联合概率 \(P(O, q_t = s_i | \lambda)\) 可以分解为两个独立部分的乘积:
- 前向部分:从序列开始到时刻 \(t\),并且以状态 \(s_i\) 结束的概率。
- 后向部分:在时刻 \(t\) 处于状态 \(s_i\) 的条件下,从 \(t+1\) 到序列结束的观测概率。
形式化地:
\[P(O, q_t = s_i | \lambda) = \underbrace{P(o_1, o_2, ..., o_t, q_t = s_i | \lambda)}_{\text{前向概率 } \alpha_t(i)} \times \underbrace{P(o_{t+1}, o_{t+2}, ..., o_T | q_t = s_i, \lambda)}_{\text{后向概率 } \beta_t(i)} \]
这里定义了:
- 前向概率 \(\alpha_t(i) = P(o_1, o_2, ..., o_t, q_t = s_i | \lambda)\)。
- 后向概率 \(\beta_t(i) = P(o_{t+1}, o_{t+2}, ..., o_T | q_t = s_i, \lambda)\)。
注意:后向概率 \(\beta_t(i)\) 的条件是 \(q_t = s_i\),并且观测是从 \(t+1\) 开始的未来观测,不包括 \(o_t\)。这是一个条件概率,不是联合概率。
3. 前向算法(Forward Algorithm)计算 \(\alpha_t(i)\)
步骤1:初始化(\(t = 1\))
对于每个状态 \(s_i\):
\[\alpha_1(i) = \pi_i \cdot b_i(o_1) \]
解释:初始时刻处于状态 \(s_i\) 的概率是 \(\pi_i\),并且在该状态下观测到 \(o_1\) 的概率是 \(b_i(o_1)\)。两者相乘得到从开始到时刻1,且状态为 \(s_i\) 的联合概率。
步骤2:递推(\(t = 1, 2, ..., T-1\))
对于每个时刻 \(t+1\) 和每个状态 \(s_j\):
\[\alpha_{t+1}(j) = \left[ \sum_{i=1}^{N} \alpha_t(i) \cdot a_{ij} \right] \cdot b_j(o_{t+1}) \]
解释:要计算在时刻 \(t+1\) 处于状态 \(s_j\) 且观测到前 \(t+1\) 个观测符号的联合概率 \(\alpha_{t+1}(j)\):
- 对前一时刻 \(t\) 的所有可能状态 \(s_i\),计算从 \(s_i\) 转移到 \(s_j\) 的概率 \(a_{ij}\),并乘以该状态的前向概率 \(\alpha_t(i)\)。求和得到在时刻 \(t+1\) 到达 \(s_j\) 的所有路径概率。
- 然后,在状态 \(s_j\) 下观测到 \(o_{t+1}\) 的概率 \(b_j(o_{t+1})\) 乘以该到达概率,得到新的联合概率。
步骤3:终止(计算整个观测序列的概率)
整个观测序列的概率可以通过对所有最终状态求和得到:
\[P(O | \lambda) = \sum_{i=1}^{N} \alpha_T(i) \]
4. 后向算法(Backward Algorithm)计算 \(\beta_t(i)\)
步骤1:初始化(\(t = T\))
对于每个状态 \(s_i\):
\[\beta_T(i) = 1 \]
解释:在最后时刻 \(T\),没有未来的观测需要生成,所以条件概率为1。这是一个约定,确保递推正确。
步骤2:递推(\(t = T-1, T-2, ..., 1\))
对于每个时刻 \(t\) 和每个状态 \(s_i\):
\[\beta_t(i) = \sum_{j=1}^{N} a_{ij} \cdot b_j(o_{t+1}) \cdot \beta_{t+1}(j) \]
解释:要计算在时刻 \(t\) 处于状态 \(s_i\) 的条件下,未来观测序列 \(o_{t+1}, ..., o_T\) 的概率 \(\beta_t(i)\):
- 考虑下一个时刻 \(t+1\) 的所有可能状态 \(s_j\)。
- 从状态 \(s_i\) 转移到 \(s_j\) 的概率是 \(a_{ij}\)。
- 在状态 \(s_j\) 下观测到 \(o_{t+1}\) 的概率是 \(b_j(o_{t+1})\)。
- 在时刻 \(t+1\) 处于状态 \(s_j\) 的条件下,剩余观测序列 \(o_{t+2}, ..., o_T\) 的概率是 \(\beta_{t+1}(j)\)。
- 将这三项相乘,并对所有可能的 \(s_j\) 求和,得到 \(\beta_t(i)\)。
5. 联合概率计算与归一化
利用计算好的 \(\alpha_t(i)\) 和 \(\beta_t(i)\),对于任意时刻 \(t\) 和状态 \(s_i\),联合概率为:
\[P(O, q_t = s_i | \lambda) = \alpha_t(i) \cdot \beta_t(i) \]
这个公式直接来自之前的分解定义。
注意:通常我们更关心的是后验状态概率 \(\gamma_t(i) = P(q_t = s_i | O, \lambda)\),即给定整个观测序列的条件下,时刻 \(t\) 处于状态 \(s_i\) 的概率。这可以通过联合概率除以观测序列的总概率得到:
\[\gamma_t(i) = \frac{P(O, q_t = s_i | \lambda)}{P(O | \lambda)} = \frac{\alpha_t(i) \cdot \beta_t(i)}{\sum_{j=1}^{N} \alpha_t(j) \cdot \beta_t(j)} \]
分母 \(P(O | \lambda)\) 可以用前向算法在 \(t=T\) 时的结果计算:\(P(O | \lambda) = \sum_{j=1}^{N} \alpha_T(j)\)。注意,根据概率一致性,对于任意 \(t\), \(\sum_{j=1}^{N} \alpha_t(j) \cdot \beta_t(j)\) 也应该等于 \(P(O | \lambda)\),这可以作为计算的校验。
总结与意义
前向-后向算法的核心贡献在于,它通过动态规划(前向递推和后向递推)高效地计算了所有时刻、所有状态下的联合概率 \(P(O, q_t = s_i | \lambda)\)。
- 前向算法(\(\alpha\)):从序列起点向终点计算,累积了过去观测和当前状态的信息。
- 后向算法(\(\beta\)):从序列终点向起点计算,编码了未来观测在当前状态下的可能性。
- 联合概率(\(\alpha \cdot \beta\)):将过去和未来的信息结合在一起,给出了在完整观测序列上下文中,特定时刻处于特定状态的完整证据。
这个联合概率是HMM进行平滑(计算后验状态概率 \(\gamma_t(i)\))、学习(Baum-Welch算法中计算转移和发射的期望计数)以及解码(例如维特比算法的软解码版本)的基石。整个算法的计算复杂度为 \(O(N^2 T)\),其中 \(N\) 是状态数,\(T\) 是序列长度,相比穷举搜索的指数复杂度是巨大的效率提升。