自注意力机制(Self-Attention)中的多头注意力(Multi-Head Attention)原理与实现细节
字数 1350 2025-11-06 12:40:14
自注意力机制(Self-Attention)中的多头注意力(Multi-Head Attention)原理与实现细节
题目描述
多头注意力是Transformer模型的核心组件,它通过并行运行多个自注意力机制来捕捉输入序列中不同子空间的特征关系。本题目要求深入理解多头注意力的设计动机、计算步骤以及如何通过多头的并行处理增强模型的表达能力。
解题过程
-
设计动机
- 单一自注意力机制可能仅聚焦于一种依赖模式(如局部语法结构),但实际任务中需要同时捕捉多种关系(如全局语义、指代关联等)。
- 多头注意力将输入投影到多个子空间,在每个子空间中独立计算注意力,最后合并结果,从而扩展模型的关注维度。
-
输入投影与头拆分
- 设输入序列为 \(X \in \mathbb{R}^{n \times d_{\text{model}}}\)(\(n\) 为序列长度,\(d_{\text{model}}\) 为模型维度)。
- 使用三组可学习的权重矩阵 \(W^Q_i, W^K_i, W^V_i \in \mathbb{R}^{d_{\text{model}} \times d_k}\)(\(i\) 为头索引),将输入分别投影为查询(Query)、键(Key)、值(Value):
\[ Q_i = X W^Q_i, \quad K_i = X W^K_i, \quad V_i = X W^V_i \]
- 通常设 \(d_k = d_v = d_{\text{model}} / h\)(\(h\) 为头数),确保总计算量接近单头注意力。
- 单头注意力计算
- 对每个头 \(i\),计算缩放点积注意力:
\[ \text{Attention}(Q_i, K_i, V_i) = \text{softmax}\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i \]
- 缩放因子 \(\sqrt{d_k}\) 防止点积过大导致梯度消失。
- 多头输出合并
- 将各头的输出拼接为 \(\text{Concat}(\text{head}_1, \dots, \text{head}_h) \in \mathbb{R}^{n \times (h \cdot d_v)}\)。
- 通过可学习矩阵 \(W^O \in \mathbb{R}^{(h \cdot d_v) \times d_{\text{model}}}\) 线性投影,恢复原始维度:
\[ \text{MultiHead}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O \]
-
代码实现示例(PyTorch风格)
import torch import torch.nn as nn class MultiHeadAttention(nn.Module): def __init__(self, d_model, h): super().__init__() self.d_model = d_model self.h = h self.d_k = d_model // h self.W_Q = nn.Linear(d_model, d_model) # 拆分为h个头 self.W_K = nn.Linear(d_model, d_model) self.W_V = nn.Linear(d_model, d_model) self.W_O = nn.Linear(d_model, d_model) def forward(self, X): batch_size, n, d_model = X.shape # 投影后重塑为 (batch_size, n, h, d_k) Q = self.W_Q(X).view(batch_size, n, self.h, self.d_k).transpose(1, 2) K = self.W_K(X).view(batch_size, n, self.h, self.d_k).transpose(1, 2) V = self.W_V(X).view(batch_size, n, self.h, self.d_k).transpose(1, 2) # 缩放点积注意力 scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5) attn_weights = torch.softmax(scores, dim=-1) head_output = torch.matmul(attn_weights, V) # (batch_size, h, n, d_k) # 合并多头输出 output = head_output.transpose(1, 2).contiguous().view(batch_size, n, d_model) return self.W_O(output) -
优势分析
- 并行性:各头计算独立,适合GPU加速。
- 多样性:不同头可学习到不同关注模式(如一个头关注局部依赖,另一个头关注长程依赖)。
- 可解释性:可通过可视化注意力权重分析模型关注点。
总结
多头注意力通过分头投影、独立计算、结果合并的机制,增强了模型捕捉复杂依赖关系的能力,是Transformer实现高效并行计算和强大表达能力的关键。