Transformer模型中的自注意力机制(Self-Attention)原理与计算步骤
题目描述:
在Transformer模型中,自注意力机制是其核心组件。它允许序列中的每个位置(例如,一个句子中的每个单词)在计算其新的表示时,能够同时关注到序列中的所有其他位置。请你详细解释自注意力机制的计算过程,包括如何生成查询(Query)、键(Key)和值(Value)向量,以及如何通过缩放点积运算得到最终的注意力权重和输出。
解题过程:
自注意力机制的目标是为输入序列中的每个元素生成一个包含全局上下文信息的新表示。我们循序渐进地分解其计算步骤。
第一步:输入表示
假设我们的输入是一个序列,包含 \(n\) 个元素(例如,n个单词)。每个元素用一个 \(d_{model}\) 维的向量表示。因此,整个输入可以表示为一个矩阵 \(X\),其形状为 \(n \times d_{model}\)。
\[X = \begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_n \end{bmatrix} \]
其中,\(x_i\) 是序列中第 \(i\) 个位置的向量。
第二步:生成查询(Q)、键(K)、值(V)矩阵
自注意力机制引入三个可学习的权重矩阵:\(W^Q\)(查询权重),\(W^K\)(键权重),\(W^V\)(值权重)。每个权重矩阵的形状都是 \(d_{model} \times d_k\)(对于Q和K)或 \(d_{model} \times d_v\)(对于V)。在原始Transformer论文中,通常设定 \(d_k = d_v = d_{model}/h\),其中 \(h\) 是注意力头的数量。这里我们先考虑单头注意力(h=1)的情况,所以 \(d_k = d_v = d_{model}\)。
通过将输入矩阵 \(X\) 与这些权重矩阵相乘,我们得到查询矩阵 \(Q\)、键矩阵 \(K\) 和值矩阵 \(V\)。
\[Q = X W^Q, \quad K = X W^K, \quad V = X W^V \]
- \(Q\) 的每一行 \(q_i\) 对应输入 \(x_i\) 的“查询”向量,表示“我想从其他位置获取什么信息”。
- \(K\) 的每一行 \(k_i\) 对应输入 \(x_i\) 的“键”向量,表示“我能向其他位置提供什么信息标识”。
- \(V\) 的每一行 \(v_i\) 对应输入 \(x_i\) 的“值”向量,表示“我实际包含的信息内容”。
第三步:计算注意力分数(Attention Scores)
接下来,我们需要计算每个查询向量与所有键向量之间的相关性分数。这通过计算 \(Q\) 和 \(K\) 的转置的点积来实现。
\[\text{Scores} = Q K^T \]
得到的分数矩阵形状为 \(n \times n\)。其中,第 \(i\) 行、第 \(j\) 列的元素 \(s_{ij}\) 表示第 \(i\) 个查询(对应位置i)与第 \(j\) 个键(对应位置j)之间的相关性分数。分数越高,表示位置i与位置j的相关性越强。
第四步:缩放(Scaling)
为了防止点积的结果过大(特别是当 \(d_k\) 维度较高时),导致Softmax函数的梯度消失,我们需要对分数进行缩放。缩放因子是 \(\sqrt{d_k}\)。
\[\text{Scaled Scores} = \frac{\text{Scores}}{\sqrt{d_k}} = \frac{Q K^T}{\sqrt{d_k}} \]
这一步确保了梯度在训练过程中更稳定。
第五步:应用Softmax函数得到注意力权重(Attention Weights)
对缩放后的分数矩阵的每一行应用Softmax函数。Softmax函数将每一行的分数转换为概率分布,使得每一行的所有元素之和为1,并且每个元素的值在0到1之间。
\[A = \text{softmax}\left( \frac{Q K^T}{\sqrt{d_k}} \right) \]
这里,矩阵 \(A\) 就是注意力权重矩阵,形状也是 \(n \times n\)。元素 \(a_{ij}\) 表示位置i对位置j的注意力权重,即“当计算位置i的新表示时,应该从位置j汲取多少信息”。
第六步:计算加权和输出
最后,我们将注意力权重矩阵 \(A\) 与值矩阵 \(V\) 相乘,得到自注意力层的输出矩阵 \(Z\)。
\[Z = A V \]
输出矩阵 \(Z\) 的形状也是 \(n \times d_v\)。它的每一行 \(z_i\) 是原始输入序列中所有位置值的加权和,权重由位置i对所有位置的注意力权重决定。因此,\(z_i\) 是一个融合了序列全局上下文信息的新表示。
总结:
自注意力机制的核心计算步骤可以总结为以下公式:
\[Z = \text{softmax}\left( \frac{Q K^T}{\sqrt{d_k}} \right) V \]
这个过程使得序列中的每个元素都能够动态地、根据当前上下文信息,有选择地聚合序列中所有元素的信息,从而生成更强大的特征表示。