图注意力网络(GAT)的消息传递与节点分类过程
题目描述
图注意力网络(Graph Attention Network, GAT)是一种基于注意力机制的图神经网络架构,用于处理图结构数据。与传统的图卷积网络(GCN)不同,GAT通过自注意力机制为每个节点的邻居分配不同的权重,从而更灵活地捕捉图中节点间的关系。本题将详细讲解GAT的消息传递机制和节点分类过程,包括注意力系数的计算、节点特征的聚合以及多层网络的堆叠方法。
解题过程
1. 图注意力层的基本结构
GAT的核心是图注意力层,它通过注意力机制对邻居节点的特征进行加权聚合。假设图中每个节点 \(i\) 的特征向量为 \(\mathbf{h}_i \in \mathbb{R}^F\),其中 \(F\) 是特征维度。图注意力层的目标是生成每个节点的新特征 \(\mathbf{h}_i' \in \mathbb{R}^{F'}\)。
步骤:
- 线性变换:首先,对每个节点的特征应用一个可学习的权重矩阵 \(\mathbf{W} \in \mathbb{R}^{F' \times F}\),将输入特征映射到新的特征空间:
\[ \mathbf{z}_i = \mathbf{W} \mathbf{h}_i \]
这里,\(\mathbf{z}_i\) 是节点 \(i\) 变换后的特征。
- 注意力系数计算:对于每对相邻节点 \(i\) 和 \(j\)(包括自连接 \(j=i\)),计算注意力系数 \(e_{ij}\),表示节点 \(j\) 对节点 \(i\) 的重要性:
\[ e_{ij} = a(\mathbf{z}_i, \mathbf{z}_j) \]
其中,\(a\) 是一个共享的注意力函数,通常实现为单层前馈神经网络。具体地:
\[ e_{ij} = \text{LeakyReLU}\left(\mathbf{a}^\top [\mathbf{z}_i \| \mathbf{z}_j]\right) \]
这里,\(\mathbf{a} \in \mathbb{R}^{2F'}\) 是注意力机制的参数向量,\(\|\) 表示向量拼接,LeakyReLU 是一个非线性激活函数(斜率设为 0.2)。
- 归一化注意力系数:使用 softmax 函数对注意力系数进行归一化,确保所有邻居的权重之和为 1:
\[ \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})} \]
其中,\(\mathcal{N}(i)\) 是节点 \(i\) 的邻居节点集合(包括自身)。归一化后的 \(\alpha_{ij}\) 即为节点 \(j\) 对节点 \(i\) 的最终注意力权重。
- 特征聚合:将归一化后的注意力权重用于加权求和,得到节点 \(i\) 的新特征:
\[ \mathbf{h}_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} \mathbf{z}_j\right) \]
其中,\(\sigma\) 是一个非线性激活函数(如 ELU 或 ReLU)。如果使用多头注意力(multi-head attention),则会将多个注意力头的输出拼接或平均作为最终特征。
2. 多头注意力机制
为了稳定学习过程并增强表达能力,GAT 通常采用多头注意力。假设有 \(K\) 个注意力头,每个头独立计算一组注意力权重和聚合特征。对于节点 \(i\),其输出特征为:
- 拼接方式(用于中间层):
\[ \mathbf{h}_i' = \|_{k=1}^K \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^k \mathbf{W}^k \mathbf{h}_j\right) \]
其中,\(\alpha_{ij}^k\) 和 \(\mathbf{W}^k\) 是第 \(k\) 个注意力头的参数。
- 平均方式(用于输出层):
\[ \mathbf{h}_i' = \sigma\left(\frac{1}{K} \sum_{k=1}^K \sum_{j \in \mathcal{N}(i)} \alpha_{ij}^k \mathbf{W}^k \mathbf{h}_j\right) \]
3. 节点分类过程
在节点分类任务中,GAT 通过堆叠多个图注意力层来构建深度网络。以两层 GAT 为例:
- 第一层:使用多头注意力(例如 \(K=8\))和 ELU 激活函数,将输入特征转换为隐藏特征。
- 第二层:使用单头注意力或平均多头注意力,输出维度等于类别数 \(C\),并通过 softmax 函数生成类别概率:
\[ \mathbf{y}_i = \text{softmax}\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} \mathbf{W} \mathbf{h}_j\right) \]
其中,\(\mathbf{y}_i \in \mathbb{R}^C\) 是节点 \(i\) 的预测概率分布。
4. 训练与优化
GAT 通过最小化交叉熵损失函数进行训练:
\[\mathcal{L} = -\sum_{i \in \mathcal{Y}} \sum_{c=1}^C y_{ic} \log \hat{y}_{ic} \]
其中,\(\mathcal{Y}\) 是带标签的节点集合,\(y_{ic}\) 是真实标签的 one-hot 编码,\(\hat{y}_{ic}\) 是预测概率。优化器(如 Adam)用于更新权重矩阵 \(\mathbf{W}\) 和注意力参数 \(\mathbf{a}\)。
关键点说明
- 自注意力机制:GAT 允许节点为不同邻居分配不同权重,无需依赖预定义的图结构(如度矩阵)。
- 计算效率:注意力系数可以并行计算,适用于大规模图数据。
- 归纳能力:GAT 能够泛化到未见过的新图,适用于动态图或迁移学习场景。
通过以上步骤,GAT 能够有效学习图中节点的表示,并在节点分类、链接预测等任务中取得优异性能。