自注意力机制中的掩码(Mask)技术与因果注意力(Causal Attention)的实现
1. 题目描述
在深度学习的自注意力机制(Self-Attention)中,掩码(Mask) 是一种关键技术,用于控制注意力权重的计算范围。特别是在处理序列数据时,掩码能实现两种关键功能:
- 填充掩码(Padding Mask):屏蔽输入序列中的无效填充位置,防止模型关注无意义的填充符号(如
<PAD>)。 - 因果掩码(Causal Mask):在解码器或自回归生成任务中,确保当前位置只能关注过去位置,屏蔽未来信息,保持生成过程的因果性。
本题将详细讲解掩码技术的原理,并重点推导因果注意力在Transformer解码器中的实现过程,包括掩码的数学形式、计算步骤及其在训练与推理中的作用。
2. 自注意力机制回顾
自注意力的核心是查询(Query)、键(Key)、值(Value) 的三元组计算。给定输入序列矩阵 \(X \in \mathbb{R}^{n \times d}\)(\(n\) 为序列长度,\(d\) 为特征维度),通过线性变换得到 \(Q, K, V\):
\[Q = X W_Q, \quad K = X W_K, \quad V = X W_V \]
注意力权重 \(A\) 通过缩放点积计算:
\[A = \text{softmax}\left( \frac{Q K^T}{\sqrt{d_k}} \right) \in \mathbb{R}^{n \times n} \]
输出为加权和:
\[\text{Output} = A V \]
其中 \(A\) 的每个元素 \(A_{ij}\) 表示位置 \(i\) 对位置 \(j\) 的注意力强度。掩码的目标是修改 \(A\) 的计算,屏蔽某些位置的注意力权重。
3. 填充掩码(Padding Mask)
3.1 问题背景
- 在批量训练中,序列长度可能不同,通常用填充符号(如0)将序列补齐到相同长度。
- 如果不加处理,模型会计算填充位置与其他位置的注意力,引入噪声。
3.2 掩码构造
- 定义一个掩码矩阵 \(M_{\text{pad}} \in \mathbb{R}^{n \times n}\),其中:
\[ M_{\text{pad}}(i,j) = \begin{cases} 0 & \text{如果位置 } j \text{ 是有效输入(非填充)} \\ -\infty & \text{如果位置 } j \text{ 是填充} \end{cases} \]
- 实际实现中,通常对输入序列的填充位置进行标记,然后扩展为矩阵形式。
3.3 应用方式
在计算注意力权重时,将掩码加到缩放点积结果上:
\[A = \text{softmax}\left( \frac{Q K^T}{\sqrt{d_k}} + M_{\text{pad}} \right) \]
加 \(-\infty\) 的位置经过softmax后权重接近0,实现了屏蔽效果。
4. 因果掩码(Causal Mask)
4.1 问题背景
- 在自回归生成任务(如机器翻译、文本生成)中,解码器在生成当前位置的输出时,只能依赖已生成的过去位置,不能看到未来信息。
- 例如,生成第 \(t\) 个词时,只能使用第 \(1, 2, \dots, t-1\) 个词。
4.2 掩码构造
- 因果掩码是一个下三角矩阵 \(M_{\text{causal}} \in \mathbb{R}^{n \times n}\):
\[ M_{\text{causal}}(i,j) = \begin{cases} 0 & \text{如果 } j \le i \quad \text{(允许关注过去及当前)} \\ -\infty & \text{如果 } j > i \quad \text{(屏蔽未来)} \end{cases} \]
例如,当 \(n=4\) 时:
\[ M_{\text{causal}} = \begin{bmatrix} 0 & -\infty & -\infty & -\infty \\ 0 & 0 & -\infty & -\infty \\ 0 & 0 & 0 & -\infty \\ 0 & 0 & 0 & 0 \end{bmatrix} \]
4.3 应用方式
将因果掩码加到注意力计算中:
\[A = \text{softmax}\left( \frac{Q K^T}{\sqrt{d_k}} + M_{\text{causal}} \right) \]
这样,位置 \(i\) 的注意力权重只会分配给位置 \(j \le i\),实现了单向注意力。
5. 结合填充掩码与因果掩码
在Transformer解码器中,通常需要同时应用两种掩码:
- 填充掩码:屏蔽批次中不同序列的填充位置。
- 因果掩码:确保自回归性质。
5.1 合并掩码
将两种掩码相加(元素级):
\[M = M_{\text{pad}} + M_{\text{causal}} \]
- 如果某个位置既是填充又是未来,其掩码值为 \(-\infty\),softmax后权重为0。
- 如果某个位置是有效过去,掩码值为0,正常计算注意力。
5.2 计算步骤
完整实现流程:
- 输入:序列矩阵 \(X\)(包含填充标记)。
- 线性变换:计算 \(Q, K, V\)。
- 缩放点积:计算 \(S = \frac{Q K^T}{\sqrt{d_k}}\)。
- 添加掩码:
- 根据填充标记生成 \(M_{\text{pad}}\)(形状 \(n \times n\))。
- 生成下三角矩阵 \(M_{\text{causal}}\)。
- 计算 \(S_{\text{masked}} = S + M_{\text{pad}} + M_{\text{causal}}\)。
- Softmax与输出:
\[ A = \text{softmax}(S_{\text{masked}}), \quad \text{Output} = A V \]
6. 训练与推理中的因果注意力
6.1 训练阶段
- 使用完整序列进行训练,但通过因果掩码确保每个位置只能看到左侧上下文。
- 例如,输入序列为
[A, B, C, D],预测目标为[B, C, D, E]。
在计算位置C的注意力时,掩码确保它只能关注[A, B],无法看到D。 - 这种机制使得模型可以并行计算所有位置的输出,同时保持自回归性质。
6.2 推理阶段(自回归生成)
- 初始输入为起始符(如
<SOS>)。 - 每次生成一个词,将新词添加到输入序列末尾。
- 重新计算注意力时,由于因果掩码的存在,新位置只能看到已生成的所有词。
- 重复直到生成结束符(如
<EOS>)。
7. 数学示例说明
假设序列长度 \(n=3\),特征维度 \(d_k=1\),简化计算:
- 输入 \(X = [x_1, x_2, x_3]\),无填充。
- 计算 \(Q = K = [1, 2, 3]\)(标量化简)。
- 缩放点积:
\[ S = Q K^T = \begin{bmatrix} 1 & 2 & 3 \\ 2 & 4 & 6 \\ 3 & 6 & 9 \end{bmatrix} \]
- 加因果掩码:
\[ S + M_{\text{causal}} = \begin{bmatrix} 1 & -\infty & -\infty \\ 2 & 4 & -\infty \\ 3 & 6 & 9 \end{bmatrix} \]
- Softmax(按行):
\[ A = \begin{bmatrix} \text{softmax}([1, -\infty, -\infty]) \\ \text{softmax}([2, 4, -\infty]) \\ \text{softmax}([3, 6, 9]) \end{bmatrix} = \begin{bmatrix} 1 & 0 & 0 \\ 0.12 & 0.88 & 0 \\ 0.00 & 0.02 & 0.98 \end{bmatrix} \]
可见,位置1只关注自身,位置2关注前两个位置,位置3关注全部。
8. 核心总结
- 填充掩码:防止注意力分散到无效填充位置,提升模型鲁棒性。
- 因果掩码:通过下三角矩阵强制实现自回归生成,是Transformer解码器的核心组件。
- 实现本质:在注意力权重计算前,将掩码矩阵(\(-\infty\) 与 0)加到缩放点积结果上,通过softmax将屏蔽位置权重压至0。
- 意义:掩码技术使得Transformer能够灵活处理变长序列并执行序列生成任务,奠定了其在NLP、语音等序列建模领域的基石。
通过以上步骤,掩码技术将自注意力的双向全局计算转化为受控的、符合任务需求的形式,是理解Transformer架构及其应用的关键。