基于自回归语言模型的文本生成算法:最小贝叶斯风险解码(Minimum Bayes Risk Decoding)详解
算法描述
最小贝叶斯风险解码是自然语言处理中,尤其是在文本生成任务(如机器翻译、文本摘要、对话生成)中,一种先进的解码策略。与贪心搜索、束搜索等旨在直接寻找概率最高的输出序列的方法不同,MBR解码基于“风险最小化”的贝叶斯决策理论框架。它的核心思想是:从一组候选输出序列中,选择一个能够最小化“期望风险”的序列作为最终输出。这里的“风险”通常定义为生成序列与“真实”或“理想”输出之间的损失函数(如负BLEU、负ROUGE、负语义相似度)。由于真实的参考序列在生成时未知,MBR通过模型自身生成的多个候选序列来近似计算期望风险,并选择那个在候选集上平均表现最好(即风险最低)的序列。这种方法能够绕过模型概率分布的某些缺陷(如概率高但不流畅、无意义的序列),直接针对最终评价指标进行优化,从而生成质量更高、更稳健的文本。
解题过程循序渐进讲解
第一步:理解解码问题的本质与MBR的动机
- 标准解码的局限性:在自回归语言模型中,生成文本是一个序列决策过程。标准的解码目标通常是寻找使序列概率 \(P(y|x)\) 最大化的输出 \(y\),即最大后验概率解码:\(y^* = \arg\max_y P(y|x)\)。然而,\(P(y|x)\) 是一个自回归概率的乘积,最大化它可能偏向于生成短、安全但可能乏味或与人类偏好不符的序列。此外,概率最高的序列未必是在BLEU、ROUGE等最终评测指标上得分最高的序列。
- MBR的决策理论视角:MBR将解码视为一个决策问题。其目标不是直接最大化概率,而是最小化一个“期望损失”。给定输入 \(x\),我们选择一个输出 \(y\),这个选择会带来一个损失 \(L(y, y_{ref})\),其中 \(y_{ref}\) 是未知的真实参考序列。由于 \(y_{ref}\) 未知,我们计算在所有可能参考序列上的期望损失(即贝叶斯风险),并选择使其最小的 \(y\):
\[ y^*_{MBR} = \arg\min_{y \in \mathcal{Y}} \mathbb{E}_{P(y_{ref}|x)}[L(y, y_{ref})] \]
这里,$\mathcal{Y}$ 是所有可能输出序列的集合,$P(y_{ref}|x)$ 是给定输入下真实参考序列的后验分布(我们同样不知道)。
- MBR的关键近似:我们无法知道 \(P(y_{ref}|x)\)。MBR的核心假设是,用于生成候选序列的模型 \(P(y|x)\) 是真实后验分布 \(P(y_{ref}|x)\) 的一个合理近似。因此,我们用从 \(P(y|x)\) 中采样或生成的一组候选序列 \(\mathcal{S} = \{y_1, y_2, ..., y_N\}\) 来近似这个期望积分。同时,我们考虑的候选输出 \(y\) 也来自这个集合 \(\mathcal{S}\)。
第二步:MBR解码的具体步骤
MBR解码可以分为四个主要步骤,下面我们详细拆解:
步骤1:候选集生成
- 目标:从模型 \(P(y|x)\) 中生成一个多样化的、高质量的候选输出序列集合 \(\mathcal{S}\)。
- 方法:不直接使用束搜索的单一最优结果,而是采用能产生多样性输出的采样方法。
- 采样法:从模型输出的下一个词分布中随机采样。可以使用核采样 或温度调节的随机采样 来平衡多样性和质量。
- 多样化束搜索:修改束搜索,鼓励搜索路径的多样性,得到N个不同的候选。
- 重要性采样:从提议分布中采样,然后重新加权。
- 关键:候选集应尽可能覆盖高概率区域,并具有一定的多样性,以便后续比较和选择。
步骤2:损失函数定义
- 目标:定义一个函数 \(L(y_i, y_j)\),用于量化两个候选序列 \(y_i\) 和 \(y_j\) 之间的“差异”或“不好”的程度。在MBR中,我们希望这个损失函数与最终的评价指标(如负BLEU、负ROUGE、负语义相似度)相关联。
- 常见选择:
- 基于N-gram重叠的损失:\(L(y_i, y_j) = 1 - BLEU(y_i, y_j)\) 或 \(L(y_i, y_j) = 1 - ROUGE(y_i, y_j)\)。这里,我们把 \(y_j\) 当作“伪参考”,计算 \(y_i\) 相对于它的得分。取负值是因为我们希望最小化损失,等价于最大化BLEU/ROUGE。
- 基于模型得分的损失:\(L(y_i, y_j) = -\log P(y_i | x, y_j)\),但这不常用,因为计算成本高。
- 基于嵌入的损失:\(L(y_i, y_j) = 1 - \text{cosine\_sim}(E(y_i), E(y_j))\),其中 \(E(\cdot)\) 是句子编码器(如BERT)。这能捕捉语义相似度。
步骤3:期望风险计算
- 目标:对于候选集 \(\mathcal{S}\) 中的每一个候选序列 \(y_i\),计算它相对于整个候选集的期望风险 \(\mathcal{R}(y_i)\)。
- 公式:使用蒙特卡洛近似,用候选集 \(\mathcal{S}\) 近似模型分布 \(P(y|x)\):
\[ \mathcal{R}(y_i) = \frac{1}{N} \sum_{j=1}^{N} L(y_i, y_j) \]
这个式子的含义是:将候选 $y_j$ 视为一个可能的“真实参考”的样本,损失 $L(y_i, y_j)$ 是选择 $y_i$ 而真实是 $y_j$ 时的代价。我们对所有可能的 $y_j$ 求平均,就得到了选择 $y_i$ 的**期望风险**。求和时通常包括 $j=i$ 的情况(即与自身的损失,通常是0或一个固定值)。
步骤4:最终序列选择
- 目标:从候选集 \(\mathcal{S}\) 中选择期望风险最小的序列作为最终输出。
- 公式:
\[ y^*_{MBR} = \arg\min_{y_i \in \mathcal{S}} \mathcal{R}(y_i) = \arg\min_{y_i \in \mathcal{S}} \frac{1}{N} \sum_{j=1}^{N} L(y_i, y_j) \]
- 解读:被选中的 \(y^*_{MBR}\) 不是在模型概率意义上最高的,而是在所有候选序列中“最具有代表性”、“最稳健”或“共识度最高”的序列。它与其他候选的平均差异最小,意味着它最接近候选集的“中心”,从而可能更安全、更流畅、更符合人类对质量的综合判断。
第三步:算法伪代码与示例
假设我们有一个训练好的文本摘要模型,输入文档为 \(x\),我们要生成摘要。
- 候选生成:使用核采样(top-p=0.9)从模型 \(P(y|x)\) 中独立采样 \(N=5\) 个候选摘要:
[y1, y2, y3, y4, y5]。 - 定义损失:使用基于ROUGE-L的损失,\(L(y_i, y_j) = 1 - ROUGE\_L(y_i, y_j)\)。
- 计算期望风险:
- 对于
y1,计算它与y1, y2, y3, y4, y5各自的 \(1-ROUGE\_L\),然后求平均,得到R(y1)。 - 同理计算
R(y2),R(y3),R(y4),R(y5)。
- 对于
- 最终选择:比较
R(y1)到R(y5),找出最小值。假设R(y3)最小,则最终输出摘要为y3。
第四步:MBR解码的优缺点与适用范围
- 优点:
- 直接优化评估指标:通过自定义损失函数,可以与下游任务的评估指标(如BLEU, ROUGE)直接对齐。
- 缓解模式崩溃:不盲目追求最高概率,能避免生成短、重复、无意义的“安全”文本,提高多样性。
- 提高鲁棒性:选择的是候选中的“共识”序列,对模型概率校准误差和噪声更稳健。
- 理论优雅:基于坚实的贝叶斯决策理论。
- 缺点:
- 计算成本高:需要生成多个候选(通常N>10),并计算所有候选对之间的损失(\(O(N^2)\))。损失函数如ROUGE计算也有开销。
- 依赖候选质量:如果候选集质量普遍很差,MBR选出的也只是“矮子里的将军”。
- 损失函数设计:如何设计一个既能快速计算又能准确反映最终质量的损失函数是一个挑战。
- 适用范围:
- 对生成质量要求高,且有一定计算资源的场景。
- 特别适用于机器翻译、文本摘要、对话生成等任务,这些任务有相对明确的自动评估指标(BLEU, ROUGE)可以作为损失函数的基础。
- 当标准束搜索或采样结果不尽如人意时,可作为后处理优化手段。
总结:最小贝叶斯风险解码(MBR)将文本生成从概率最大化问题,转化为风险最小化的决策问题。它通过生成多样化的候选集,并选择其中期望风险(通常定义为与其它候选的平均差异)最小的序列作为输出,从而绕过模型概率分布的某些偏差,生成更稳健、更符合人类评估标准的文本。虽然计算成本较高,但其在提升生成质量方面的潜力使其成为高级解码策略中的重要一员。