基于预训练语言模型的文本生成算法:基于强化学习的解码策略(Reinforcement Learning-based Decoding)详解
题目描述
在文本生成任务(如机器翻译、对话生成、文本摘要)中,传统的解码策略(如贪心搜索、束搜索)通常依赖极大似然估计(MLE)目标进行训练,但这类方法存在曝光偏差(Exposure Bias)和目标不匹配问题:
- 曝光偏差:训练时模型使用真实的上文词(Teacher Forcing),而推理时依赖自身生成的错误上文,错误会累积。
- 目标不匹配:MLE目标追求逐词概率最大化,但实际任务更关注整体质量(如流畅度、信息量、多样性)。
基于强化学习(RL)的解码策略通过将文本生成建模为序列决策过程,直接优化与任务相关的评价指标(如BLEU、ROUGE、人类评分),从而提升生成质量。
解题过程详解
步骤1:将文本生成转化为RL问题
- 智能体(Agent):文本生成模型(如GPT、T5等)。
- 动作(Action):每一步从词表中选择一个词。
- 状态(State):已生成的部分序列(即上文上下文)。
- 奖励(Reward):生成完整序列后,根据评价指标计算的整体奖励(如ROUGE分数)。
关键挑战:
- 动作空间巨大(词表大小通常为万级别);
- 奖励稀疏(仅在序列结束时获得),且需对抗随机策略的方差问题。
步骤2:设计RL训练目标
假设生成模型参数为 \(\theta\),输入为 \(x\),输出序列为 \(y=(y_1, y_2, ..., y_T)\),RL的目标是最大化期望奖励:
\[J(\theta) = \mathbb{E}_{y \sim p_\theta(y|x)} [R(y)] \]
其中 \(R(y)\) 是奖励函数(如ROUGE分数)。通过策略梯度方法(如REINFORCE)更新参数:
\[\nabla_\theta J(\theta) \approx \mathbb{E}_{y \sim p_\theta} \left[ R(y) \nabla_\theta \log p_\theta(y|x) \right] \]
问题:直接使用蒙特卡洛采样估计梯度方差较大,训练不稳定。
步骤3:降低方差的方法
- 基准值(Baseline):
引入一个基准值 \(b\)(如奖励的移动平均),调整梯度公式:
\[ \nabla_\theta J(\theta) \approx \mathbb{E}_{y \sim p_\theta} \left[ (R(y) - b) \nabla_\theta \log p_\theta(y|x) \right] \]
- 若 \(R(y) > b\),增加生成 \(y\) 的概率;
- 若 \(R(y) < b\),降低生成 \(y\) 的概率。
-
优势函数(Advantage Function):
用更复杂的基准值(如训练一个价值网络预测预期奖励)进一步减少方差。 -
结合MLE的混合目标:
为避免RL训练偏离自然语言分布,混合MLE损失和RL损失:
\[ \mathcal{L}_{\text{混合}} = \lambda \mathcal{L}_{\text{MLE}} + (1-\lambda) \mathcal{L}_{\text{RL}} \]
其中 \(\lambda\) 是超参数,平衡两种目标。
步骤4:具体算法实现(以Self-Critical序列训练为例)
Self-Critical Sequence Training (SCST) 是经典RL解码算法,步骤如下:
- 采样生成:用当前模型生成一个序列 \(y^s \sim p_\theta(y|x)\)。
- 基准生成:用贪心解码(或束搜索)生成一个基准序列 \(y^g\)(即 \(y^g = \arg\max_y p_\theta(y|x)\))。
- 计算奖励:计算两个序列的奖励 \(R(y^s)\) 和 \(R(y^g)\)。
- 梯度更新:
\[ \nabla_\theta J(\theta) \approx (R(y^s) - R(y^g)) \nabla_\theta \log p_\theta(y^s|x) \]
- 若采样序列 \(y^s\) 的奖励高于基准序列 \(y^g\),则鼓励模型生成类似 \(y^s\) 的序列;
- 否则抑制生成 \(y^s\) 的概率。
优点:
- 基准值 \(R(y^g)\) 来自模型自身,无需额外训练价值网络;
- 实验显示在图像描述、文本摘要等任务中显著提升指标。
步骤5:处理多维度奖励
实际任务中可能需要综合多个指标(如流畅度、重复惩罚、内容覆盖等)。可设计多目标奖励函数:
\[R(y) = \sum_{i=1}^k w_i R_i(y) \]
例如:
- \(R_1(y)\):ROUGE分数(衡量内容覆盖);
- \(R_2(y)\):语言模型概率(衡量流畅度);
- \(R_3(y)\):重复词惩罚(提升多样性)。
总结
基于RL的解码策略通过直接优化任务相关指标,缓解了MLE训练的局限性。其核心在于:
- 将生成建模为序列决策问题;
- 使用策略梯度方法优化期望奖励;
- 通过基准值、混合目标等方法稳定训练;
- 结合多维度奖励提升生成质量。
局限性:
- 训练计算成本高;
- 对奖励函数设计敏感,需针对任务调优。