基于指针网络(Pointer Network)的文本摘要算法
题目描述
指针网络(Pointer Network)是一种专为解决输出序列元素直接来自输入序列的子集而设计的神经网络结构,常用于文本摘要、问答等任务。在文本摘要中,传统序列到序列(Seq2Seq)模型依赖固定词表生成摘要,但面对生僻词、专有名词或超出词表的词汇时容易出错。指针网络通过引入“指针”机制,允许模型从输入文本中直接复制词汇,显著提升生成摘要的准确性和流畅性。
解题过程
1. 传统Seq2Seq模型的局限性
-
问题背景:
Seq2Seq模型通常包含编码器(Encoder)和解码器(Decoder)。编码器将输入序列(如原文)编码为上下文向量,解码器基于该向量生成目标序列(如摘要)。但解码器生成词汇时仅依赖预设词表,导致以下问题:- 未登录词(OOV)问题:输入中的生僻词或专有名词可能不在词表中,模型无法生成这些词。
- 信息失真:模型可能被迫用泛化词汇替换原文中的关键实体(如人名、地名)。
-
改进思路:
让模型在生成每个词时,动态选择是“生成”一个词(从词表中)还是“复制”一个词(从输入序列中)。
2. 指针网络的基本结构
指针网络在Seq2Seq基础上增加了一个指针生成器(Pointer Generator),其核心组件包括:
(1)编码器
- 使用双向LSTM或Transformer对输入序列 \(X = (x_1, x_2, ..., x_n)\) 编码,得到每个词的隐藏状态 \(h_i\)。
(2)解码器
- 在每一步 \(t\),解码器基于上一步隐藏状态 \(s_{t-1}\) 和上下文向量 \(c_t\)(通过注意力机制计算)生成当前状态 \(s_t\)。
- 传统解码器仅输出词表上的概率分布 \(P_{\text{vocab}}\),而指针网络额外计算一个复制概率 \(p_{\text{gen}} \in [0,1]\),用于权衡“生成”与“复制”:
\[ p_{\text{gen}} = \sigma(W_c c_t + W_s s_t + W_y y_{t-1} + b) \]
其中 \(W_c, W_s, W_y\) 为可学习参数,\(\sigma\) 为Sigmoid函数。
(3)混合概率分布
- 最终的概率分布由生成概率和复制概率加权组合:
\[ P(w) = p_{\text{gen}} \cdot P_{\text{vocab}}(w) + (1 - p_{\text{gen}}) \cdot \sum_{i: x_i = w} a_{t,i} \]
其中:
- \(P_{\text{vocab}}(w)\) 是词表上 \(w\) 的生成概率;
- \(a_{t,i}\) 是解码步 \(t\) 对输入词 \(x_i\) 的注意力权重;
- \(\sum_{i: x_i = w} a_{t,i}\) 表示所有与 \(w\) 相同的输入词注意力权重之和(即复制概率)。
3. 训练与推理细节
训练目标
- 使用负对数似然损失函数,最小化目标摘要序列的负对数概率:
\[ \mathcal{L} = -\frac{1}{T} \sum_{t=1}^T \log P(w_t^*) \]
其中 \(w_t^*\) 是第 \(t\) 步的真实摘要词。
推理策略
- 在测试时,模型逐词生成摘要:
- 计算 \(p_{\text{gen}}\) 和注意力分布 \(a_t\);
- 若 \(p_{\text{gen}}\) 较大,选择 \(P_{\text{vocab}}\) 中概率最高的词;
- 若 \(p_{\text{gen}}\) 较小,选择注意力权重最高的输入词(直接复制)。
4. 解决重复与冗余问题
- 问题:指针网络可能重复复制输入中的相同词汇,导致摘要冗余。
- 解决方案:
- 覆盖机制(Coverage Mechanism):记录历史步的注意力权重之和,避免重复关注已覆盖的输入词。具体地,在计算注意力时加入覆盖向量:
\[ a_{t,i} \propto \exp(v^T \tanh(W_h h_i + W_s s_t + W_c c_{t-1,i})) \]
其中 $ c_{t-1,i} = \sum_{k=1}^{t-1} a_{k,i} $ 是词 $ x_i $ 的历史注意力权重累积。
总结
指针网络通过动态切换生成与复制模式,有效缓解了OOV问题,特别适合需要保留原文关键信息的任务(如摘要)。结合覆盖机制后,能进一步抑制重复生成,提升摘要质量。该思想后续被扩展至预训练模型(如BART、PEGASUS)中,成为文本生成的重要基础技术。