基于对比学习的句子表示学习算法:InfoNCE损失函数详解
算法题目描述
InfoNCE(Noise-Contrastive Estimation 的一种变体,全称 Information Noise-Contrastive Estimation)是一种用于自监督对比学习的损失函数,广泛应用于句子表示学习。其核心思想是通过对比正样本对(语义相似的句子对)和负样本对(语义不相关的句子对)来学习高质量的句子嵌入表示,使得相似句子在表示空间中距离更近,不相似句子距离更远。InfoNCE损失函数通过构造一个多分类任务来实现这一目标,是SimCLR、MoCo等经典对比学习框架的基础。
解题过程循序渐进讲解
步骤1:理解对比学习的基本框架
在句子表示学习中,对比学习的目标是学习一个编码器 \(f(\cdot)\)(通常为BERT、RoBERTa等预训练模型),将输入句子 \(x\) 映射为一个固定维度的向量表示 \(h = f(x)\)。为了训练这个编码器,我们需要:
- 数据增强:为每个句子 \(x\) 生成其增强版本 \(x^+\)(例如回译、删除词语、同义词替换等),构成正样本对 \((x, x^+)\)。
- 负样本构建:从同一个批次(batch)中随机选择其他句子或其增强版本作为负样本 \(x^-\),与 \(x\) 构成负样本对。
- 对比目标:优化编码器,使正样本对的表示向量相似度尽可能高,负样本对的相似度尽可能低。
步骤2:定义相似度度量
首先,我们需要一个函数来衡量两个句子表示向量的相似度。通常使用余弦相似度(cosine similarity):
\[\text{sim}(h_i, h_j) = \frac{h_i \cdot h_j}{\|h_i\| \|h_j\|} \]
其中 \(h_i, h_j\) 是经过L2归一化后的向量,因此相似度范围在[-1, 1]之间,值越大表示越相似。
步骤3:构建InfoNCE损失函数
假设我们有一个批次包含 \(N\) 个句子,每个句子通过数据增强得到两个视图,总共 \(2N\) 个样本。对于每个样本 \(i\),其增强版本 \(j(i)\) 是它的正样本,批次中其他 \(2(N-1)\) 个样本都是负样本。
InfoNCE损失函数为:
\[\mathcal{L}_i = -\log \frac{\exp(\text{sim}(h_i, h_{j(i)}) / \tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\text{sim}(h_i, h_k) / \tau)} \]
其中:
- \(h_i\) 是锚点样本(anchor)的表示。
- \(h_{j(i)}\) 是正样本的表示。
- 分母对所有样本(包括正样本和负样本)求和,但排除了 \(k=i\) 自身(避免与自身对比,因为自身是完全相同的样本,没有对比意义)。
- \(\tau\) 是一个温度超参数(temperature),控制相似度分布的平滑程度。较小的 \(\tau\) 会放大相似度差异,使损失更关注困难的负样本。
- \(\mathbb{1}_{[k \neq i]}\) 是指示函数,当 \(k \neq i\) 时为1,否则为0。
步骤4:直观理解InfoNCE
InfoNCE可以看作一个 \(2N-1\) 类的分类任务:给定锚点样本 \(h_i\),需要从 \(2N-1\) 个候选样本(1个正样本 + \(2N-2\) 个负样本)中正确识别出正样本 \(h_{j(i)}\)。分子鼓励正样本的相似度得分高,分母惩罚所有负样本的相似度得分高。对数损失使得模型能够区分正样本和所有负样本。
步骤5:温度参数 \(\tau\) 的作用
温度参数 \(\tau\) 是InfoNCE的关键超参数:
- 当 \(\tau\) 很小时(如0.05),指数项 \(\exp(\text{sim} / \tau)\) 会放大相似度的差异。模型会更关注那些与锚点相似度较高的困难负样本(hard negatives),从而学习更精细的判别特征,但训练可能不稳定。
- 当 \(\tau\) 较大时(如1.0),相似度差异被平滑,模型对所有负样本一视同仁,容易训练但可能无法学到足够好的表示。
通常,\(\tau\) 需要通过实验调优,常见值在0.05到0.2之间。
步骤6:对称损失计算
在实际训练中,为了使每个样本都得到充分学习,通常计算对称的InfoNCE损失。即对于每个正样本对 \((i, j)\),计算两次损失:一次以 \(i\) 为锚点,\(j\) 为正样本;另一次以 \(j\) 为锚点,\(i\) 为正样本。最终损失是所有样本损失的平均:
\[\mathcal{L} = \frac{1}{2N} \sum_{i=1}^{2N} \mathcal{L}_i \]
步骤7:算法训练流程
- 输入一个批次 \(N\) 个原始句子。
- 对每个句子应用两次数据增强,得到 \(2N\) 个增强句子。
- 用编码器 \(f(\cdot)\) 计算所有增强句子的表示向量 \(h_1, h_2, ..., h_{2N}\),并进行L2归一化。
- 对于每个样本 \(i\),在批次中找到其对应的正样本索引 \(j(i)\)。
- 使用InfoNCE公式计算每个样本的损失 \(\mathcal{L}_i\),并求平均得到总损失。
- 通过反向传播更新编码器参数,最小化总损失。
步骤8:与互信息(Mutual Information)的联系
InfoNCE的名称来源于其理论背景:它最大化锚点样本 \(x\) 与其正样本 \(x^+\) 的表示向量之间的互信息下界。具体来说,最小化InfoNCE损失等价于最大化 \(I(h_i; h_{j(i)})\) 的下界,即鼓励表示向量捕获两个视图之间的共享信息(语义内容),而忽略无关的噪声(具体增强方式)。这使得学习到的表示具有很好的泛化性。
总结
InfoNCE通过构建一个多分类对比任务,利用正负样本对的相似度比较,驱动编码器学习判别性的句子表示。其核心优势在于理论严谨(与互信息最大化关联)、实践有效,已成为句子表示学习中的标准损失函数。掌握InfoNCE的关键在于理解其对比机制、温度参数的作用以及对称损失的计算方式,从而能够将其应用于各种自监督表示学习场景。