基于噪声对比估计(Noise-Contrastive Estimation, NCE)的神经语言模型训练算法详解
题目描述
在自然语言处理中,神经语言模型(如早期的神经网络语言模型,NNLM)通常需要预测给定上下文时下一个单词的概率分布。这个分布的大小等于整个词表的大小V(通常是数万甚至数十万)。传统的训练方法采用Softmax函数对V个词的概率进行归一化,这需要计算所有V个词的得分(logits),然后计算归一化概率,计算成本为O(V),在V很大时非常昂贵。
噪声对比估计 是一种解决此问题的训练技巧。其核心思想是:不直接计算整个词表的概率分布,而是将语言模型训练任务转换成一个二分类任务。具体来说,对于每个真实的训练样本(真实数据),我们将其视为“正样本”;同时,从某个简单的噪声分布(如一元词频分布)中采样出k个噪声词,构成“负样本”。然后训练模型去区分当前样本是来自真实数据分布还是噪声分布。通过这种方式,模型在避免计算整个词表归一化概率的情况下,依然能学到有用的词表示。
解题过程循序渐进讲解
第一步:理解传统Softmax的瓶颈
假设我们有一个标准的神经网络语言模型。给定上下文词序列 \(c\)(例如前n-1个词),模型输出一个隐藏表示,然后通过一个输出权重矩阵 \(W\)(形状为 \(d \times V\),d是隐藏层维度)计算出V个候选词的得分(logits)\(s = W^T h\)。为了得到概率分布,我们需要对所有得分进行Softmax归一化:
\[P(w | c) = \frac{\exp(s_w)}{\sum_{j=1}^{V} \exp(s_j)} \]
其中分母的求和计算量是O(V),是训练时的主要瓶颈。
第二步:噪声对比估计的核心思想
NCE的核心洞察是:我们并不需要精确计算完整的概率分布 \(P(w | c)\),只需要让模型学会给真实数据样本比噪声样本更高的分数。为此,NCE将问题重构为一个二分类任务:
- 正样本:从真实数据分布 \(P_d(w | c)\) 中采样的(上下文c,目标词w)对。
- 负样本:从某个简单的、容易采样的噪声分布 \(Q(w)\) 中采样k个噪声词 \(w_1, w_2, ..., w_k\)。通常 \(Q(w)\) 选择为一元词频分布(unigram distribution),即每个词在训练语料中出现的频率。
对于每个训练样本(上下文c,真实词 \(w_t\)),我们将其与k个从 \(Q(w)\) 中独立采样的噪声词 \(w_{n1}, w_{n2}, ..., w_{nk}\) 组合在一起。然后训练模型来区分:哪个词是来自真实数据分布 \(P_d\) 的?哪些是来自噪声分布 \(Q\) 的?
第三步:定义二分类概率
模型需要为任意一个(上下文c,词w)对计算一个“得分”,这个得分反映了w来自真实数据分布 \(P_d\) 而非噪声分布 \(Q\) 的可能性。我们用 \(P_{\theta}(D=1 | w, c)\) 表示给定(c, w)对,它来自真实数据分布的概率;用 \(P_{\theta}(D=0 | w, c) = 1 - P_{\theta}(D=1 | w, c)\) 表示它来自噪声分布的概率。
NCE假设真实数据分布的对数概率可以通过一个参数化的函数 \(s_{\theta}(w, c)\) 加上一个对数配分函数的常数来建模,但NCE的一个巧妙之处在于,它将这个对数配分函数也作为一个可学习的参数 \(b_c\)(依赖于上下文c)。具体地,定义:
\[\log P_{\theta}(w | c) = s_{\theta}(w, c) - b_c \]
然后,根据概率的比值,我们定义二分类概率:
\[P_{\theta}(D=1 | w, c) = \sigma( \log P_{\theta}(w | c) - \log Q(w) ) = \sigma( s_{\theta}(w, c) - b_c - \log Q(w) ) \]
其中 \(\sigma(x) = 1/(1+\exp(-x))\) 是sigmoid函数。
这个公式直观解释是:如果模型给的真实数据概率 \(P_{\theta}(w|c)\) 相对于噪声概率 \(Q(w)\) 的比值越大,sigmoid的输出越接近1,即模型越相信(w, c)来自真实分布。
第四步:构建损失函数
对于一个训练样本(上下文c,真实词 \(w_t\))和k个噪声词 \(\{w_{ni}\}_{i=1}^k\):
- 对于正样本 \(w_t\),我们希望 \(P_{\theta}(D=1 | w_t, c)\) 接近1。
- 对于每个噪声样本 \(w_{ni}\),我们希望 \(P_{\theta}(D=1 | w_{ni}, c)\) 接近0(等价地, \(P_{\theta}(D=0 | w_{ni}, c)\) 接近1)。
因此,损失函数是这些二分类交叉熵损失的总和。标准的NCE损失函数定义为:
\[J_{NCE}(\theta) = - \left[ \log P_{\theta}(D=1 | w_t, c) + \sum_{i=1}^{k} \log P_{\theta}(D=0 | w_{ni}, c) \right] \]
代入sigmoid公式:
\[P_{\theta}(D=1 | w, c) = \sigma( s_{\theta}(w, c) - b_c - \log Q(w) ) \]
\[ P_{\theta}(D=0 | w, c) = 1 - P_{\theta}(D=1 | w, c) = \sigma( -(s_{\theta}(w, c) - b_c - \log Q(w)) ) \]
所以损失函数为:
\[J_{NCE}(\theta) = - \left[ \log \sigma( s_{\theta}(w_t, c) - b_c - \log Q(w_t) ) + \sum_{i=1}^{k} \log \sigma( -(s_{\theta}(w_{ni}, c) - b_c - \log Q(w_{ni}) ) \right] \]
在这个损失函数中,我们只需要为1个正样本词和k个噪声词(k通常远小于V,比如k=5~100)计算得分 \(s_{\theta}(w, c)\),而不需要为整个词表的V个词计算得分。这大大降低了计算复杂度,从O(V)降到了O(k)。
第五步:训练过程与参数更新
- 前向传播:对于给定的上下文c,模型计算出隐藏表示h。
- 采样噪声词:从预定义的噪声分布 \(Q(w)\) 中采样k个词。\(Q(w)\) 通常取训练语料中的一元词频分布,但实践中为了提高高频噪声词的采样概率,会对频率取3/4次幂(类似Word2Vec的负采样技巧)。
- 计算得分:只为真实词 \(w_t\) 和k个噪声词,从输出嵌入矩阵中查找对应的输出向量,并计算与h的点积(或其它相似度函数),得到得分 \(s_{\theta}(w, c)\)。
- 计算损失:如上所述,计算NCE损失。
- 反向传播与更新:计算损失关于模型参数 \(\theta\) 和上下文相关的偏置 \(b_c\) 的梯度,并更新它们。
随着训练的进行,模型不仅学会了为真实词分配更高的分数,而且学到的输出词向量(即 \(s_{\theta}(w, c)\) 计算中使用的词嵌入)也能捕获有意义的语义信息。一旦模型训练完成,我们可以恢复近似的归一化概率:\(P_{\theta}(w | c) \approx \exp(s_{\theta}(w, c)) / Z(c)\),其中 \(Z(c)\) 可以通过在少量词上估计得到,但在很多下游任务(如获取词向量)中,甚至不需要这个明确的概率。
第六步:NCE与负采样(Negative Sampling)的关系
NCE是一个更通用的框架。Word2Vec的负采样 可以看作是NCE的一个特例。在负采样中,它做出了一个关键简化:固定了偏置项 \(b_c = \log Z(c)\),并假设 \(Z(c) = 1\)。这样,二分类概率简化为:
\[P(D=1 | w, c) = \sigma( s_{\theta}(w, c) - \log Q(w) ) \]
在Word2Vec的Skip-gram模型中,上下文c是中心词,目标w是周围词,得分 \(s_{\theta}(w, c)\) 是中心词向量和周围词向量的点积。负采样损失为:
\[J_{NS} = -\log \sigma( v_c \cdot v_w ) - \sum_{i=1}^{k} \log \sigma( -v_c \cdot v_{n_i} ) \]
这正是NCE在上述简化假设下的结果。因此,负采样是一种更简单的、近似版的NCE,其目标是学习好的词向量,而非精确的语言模型概率。
总结
噪声对比估计 通过将多分类问题转化为一系列二分类问题,巧妙地规避了Softmax中归一化分母的巨大计算开销。它通过区分真实数据和噪声样本来训练模型,使得模型在仅计算少量(k+1个)词得分的情况下,就能有效地学习语言模型的参数和高质量的词表示。这种思想不仅加速了训练,也催生了像Word2Vec负采样这样影响深远的技术。