基于多示例学习(Multi-Instance Learning, MIL)的文本分类算法详解
1. 题目描述
在多示例学习(MIL)框架中,一个训练样本不再是一个简单的、有明确标签的数据点,而是一个“包”。每个“包”包含多个“示例”,而标签是赋予整个包的,而非包内单个示例。核心假设是:如果包是正类,则包内至少包含一个或多个关键的正示例;如果包是负类,则包内所有示例都是负例。
将此框架应用于文本分类,一种常见场景是弱监督文本分类,特别是文档级情感分类或主题分类。例如,我们只有整个文档(包)的情感标签(如正面/负面),但不知道具体是哪个句子(示例)承载了该情感。目标是训练一个模型,既能正确分类新文档,也能在一定程度上识别出文档内的关键信息片段。
2. 解题过程循序渐进讲解
我们将围绕一个典型的基于神经网络的MIL文本分类模型展开,其核心是“包-示例”的聚合机制。
步骤一:问题形式化与输入表示
- 问题定义: 假设我们有N个训练文档。第i个文档被表示为一个包 \(B_i = \{ x_{i1}, x_{i2}, ..., x_{iK} \}\),其中 \(x_{ij}\) 是文档中的第j个句子(或段落)的向量表示,K是句子数(可以不同)。包的标签为 \(Y_i \in \{0, 1\}\),1代表正面/目标主题,0代表负面/非目标主题。
- 示例编码: 对包内的每个句子(示例)进行编码,将其转换为一个固定维度的向量表示 \(h_{ij}\)。
- 方法: 使用一个句子级编码器,如一个简单的RNN(GRU/LSTM)、CNN或Transformer的CLS向量。例如,使用BERT的[CLS]向量作为句子表示:
\[ h_{ij} = \text{BERT}_{\text{sentence}}(x_{ij}) \]
* **目标**: 得到每个句子的语义表示向量 $ h_{ij} \in \mathbb{R}^d $。
步骤二:包内示例的评分与注意力聚合
这是MIL的核心。我们需要从包内的多个示例向量 \(\{ h_{i1}, ..., h_{iK} \}\) 中,聚合出一个能代表整个包的向量表示 \(H_i\)。关键是,聚合过程应能自动“关注”到那些可能是正类的关键句子。
- 示例评分: 为包内每个示例计算一个“重要性”得分,表示该示例是“关键正例”的可能性。
- 实现: 引入一个注意力网络。首先,将每个示例向量 \(h_{ij}\) 通过一个前馈神经网络(通常是一个MLP)和一个非线性激活函数(如tanh),映射到一个标量能量值 \(e_{ij}\)。
\[ e_{ij} = w^T \cdot (\tanh(V \cdot h_{ij}^T + b)) + c \]
其中,$ V, w, b, c $ 是可学习参数。
- 注意力权重计算: 对包内所有示例的能量值进行Softmax归一化,得到每个示例的注意力权重 \(a_{ij}\),权重之和为1。
- 公式:
\[ a_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{K} \exp(e_{ik})} \]
* **解释**: Softmax确保了模型必须将所有注意力权重分配给包内的句子。对于一个正类包,模型会倾向于给那些表达正面情感的关键句子分配高权重;对于一个负类包,由于假设所有句子都是负例,注意力权重可能会相对均匀分布,或者集中在某些“典型”的负例句子上。
- 包表示向量聚合: 将包内所有示例的向量表示,按其注意力权重进行加权求和,得到最终的包表示 \(H_i\)。
- 公式:
\[ H_i = \sum_{j=1}^{K} a_{ij} \cdot h_{ij} \]
* **结果**: $ H_i $ 是一个融合了包内所有句子信息,但侧重于关键句子的文档级表示。
步骤三:包级分类与损失计算
- 分类: 将聚合得到的包表示向量 \(H_i\) 送入一个分类器(如一个简单的线性层 + Sigmoid激活函数),预测整个包的标签 \(\hat{Y}_i\)。
- 公式:
\[ \hat{Y}_i = \sigma(W_c \cdot H_i + b_c) \]
其中,$ W_c, b_c $ 是分类器参数,$ \sigma $ 是Sigmoid函数。
- 损失函数: 使用标准分类损失(如二分类交叉熵损失)来衡量包级预测与真实包标签之间的差距。
- 公式:
\[ \mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} [Y_i \log(\hat{Y}_i) + (1 - Y_i) \log(1 - \hat{Y}_i)] \]
- 端到端训练: 整个模型(示例编码器、注意力网络、分类器)通过反向传播算法,利用包级标签进行端到端训练。模型会自动学会:1) 如何编码句子;2) 如何识别关键句子(通过注意力机制);3) 如何基于关键信息分类整个文档。
步骤四:推断与解释
- 新文档分类: 给定一个新文档,将其分割为句子,通过训练好的模型进行处理,得到最终的包级预测概率 \(\hat{Y}_{\text{new}}\)。
- 关键片段识别(可解释性): MIL模型的一个显著优势是能够提供解释。我们可以查看模型在推断过程中为每个句子分配的注意力权重 \(a_{ij}\)。权重最高的那几个句子,就是模型认为对做出“正面”(或“目标类别”)判断贡献最大的关键句子。这对于情感分析、虚假新闻检测、医学报告分析等需要理由的场景非常有用。
总结: 基于多示例学习的文本分类算法,将弱监督的文档分类问题巧妙地转化为“包-示例”学习问题。通过引入注意力聚合机制,模型能够在仅有文档级标签的情况下,自动定位文档内的关键证据片段,并基于此做出分类决策。它有效地桥接了句子级理解和文档级任务之间的鸿沟,增强了模型的可解释性。