基于图神经网络的图自编码器(Graph Autoencoder, GAE)的图重建与变分图自编码器(Variational Graph Autoencoder, VGAE)的变分下界推导
题目描述
我们讨论基于图神经网络的图自编码器(GAE)及其变分版本(VGAE)。
图自编码器是一种用于图结构数据的无监督表示学习框架。其核心思想是:
- 使用图卷积网络(GCN) 作为编码器,将节点特征和邻接矩阵映射为低维节点嵌入向量。
- 通过解码器 从节点嵌入向量重建图的邻接矩阵(或其他结构信息)。
- 在变分图自编码器中,编码器输出节点嵌入的概率分布,并通过变分推断优化证据下界。
本题目将详细讲解GAE的结构、VGAE的变分下界推导,以及训练过程。
解题过程
步骤1:问题定义与符号说明
设有无向图 \(G = (\mathcal{V}, \mathcal{E})\),其中:
- 节点集合 \(\mathcal{V}\) 包含 \(N\) 个节点。
- 邻接矩阵 \(A \in \{0,1\}^{N \times N}\),若节点 \(i\) 与 \(j\) 有边,则 \(A_{ij} = 1\);对角线元素设为0(无自环)。
- 节点特征矩阵 \(X \in \mathbb{R}^{N \times D}\),每行是节点的 \(D\) 维特征向量。
目标:学习节点的低维嵌入 \(Z \in \mathbb{R}^{N \times d}\)( \(d \ll D\)),使得嵌入能保留图的结构与节点特征信息。
步骤2:图自编码器(GAE)的编码器
GAE使用两层GCN作为编码器:
\[Z = \text{GCN}(X, A) = \tilde{A} \, \text{ReLU}(\tilde{A} X W_0) \, W_1 \]
其中:
- \(\tilde{A} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}}\) 是归一化的对称邻接矩阵(添加自环后的度矩阵归一化)。
- \(D\) 是度矩阵,\(D_{ii} = \sum_j A_{ij}\)。
- \(W_0 \in \mathbb{R}^{D \times H}\)、\(W_1 \in \mathbb{R}^{H \times d}\) 是可训练权重矩阵。
- 输出 \(Z\) 的每一行 \(z_i \in \mathbb{R}^d\) 是节点 \(i\) 的嵌入向量。
步骤3:GAE的解码器与重建目标
解码器通过内积 和sigmoid函数重建邻接矩阵:
\[\hat{A} = \sigma(Z Z^\top) \]
其中 \(\hat{A}_{ij} = \sigma(z_i^\top z_j)\) 表示节点 \(i\) 和 \(j\) 之间存在边的预测概率,\(\sigma\) 是sigmoid函数。
重建损失 采用交叉熵:
\[\mathcal{L}_{\text{GAE}} = -\sum_{(i,j) \in \mathcal{V} \times \mathcal{V}} \left[ A_{ij} \log \hat{A}_{ij} + (1 - A_{ij}) \log (1 - \hat{A}_{ij}) \right] \]
通常只对已观察的边(\(A_{ij}=1\))和采样的非边(\(A_{ij}=0\))计算损失,以提升效率。
步骤4:变分图自编码器(VGAE)的概率模型
VGAE将节点嵌入视为隐变量,并假设其由高斯分布生成:
- 先验分布:\(p(Z) = \prod_{i=1}^N \mathcal{N}(z_i | 0, I)\)。
- 后验分布 由编码器近似为高斯分布:
\[q(Z | X, A) = \prod_{i=1}^N \mathcal{N}(z_i | \mu_i, \text{diag}(\sigma_i^2)) \]
其中 \(\mu_i\) 和 \(\log \sigma_i\) 由GCN编码器输出:
\[\mu = \text{GCN}_\mu(X, A), \quad \log \sigma = \text{GCN}_\sigma(X, A) \]
- 生成模型(解码器):
\[p(A | Z) = \prod_{i=1}^N \prod_{j=1}^N p(A_{ij} | z_i, z_j), \quad p(A_{ij}=1 | z_i, z_j) = \sigma(z_i^\top z_j) \]
步骤5:VGAE的变分下界推导
目标是最大化观测数据 \(A\) 的对数似然 \(\log p(A | X)\)。引入变分分布 \(q(Z | X, A)\) 近似真实后验 \(p(Z | A, X)\),得到证据下界:
\[\log p(A | X) \ge \mathbb{E}_{q(Z|X,A)}[\log p(A | Z)] - \text{KL}[q(Z | X, A) \| p(Z)] \]
- 第一项:重建损失。从 \(q(Z|X,A)\) 采样 \(Z\)(通过重参数化技巧),计算解码器输出与真实邻接矩阵的交叉熵。
- 第二项:KL散度。由于 \(q\) 和 \(p\) 都是高斯分布,可解析计算:
\[\text{KL}[q \| p] = \frac{1}{2} \sum_{i=1}^N \left( \text{tr}(\sigma_i^2) + \mu_i^\top \mu_i - d - \log \det(\text{diag}(\sigma_i^2)) \right) \]
其中 \(\sigma_i^2\) 是方差向量,\(\text{tr}\) 是迹。
总损失:
\[\mathcal{L}_{\text{VGAE}} = -\mathbb{E}_{q(Z|X,A)}[\log p(A | Z)] + \text{KL}[q(Z | X, A) \| p(Z)] \]
训练时最小化 \(\mathcal{L}_{\text{VGAE}}\)。
步骤6:训练步骤
- 编码器前向传播:输入 \(X, A\) 到两个GCN编码器,得到 \(\mu\) 和 \(\log \sigma\)。
- 重参数化采样:对每个节点 \(i\),采样 \(\epsilon_i \sim \mathcal{N}(0, I)\),计算 \(z_i = \mu_i + \sigma_i \odot \epsilon_i\)。
- 解码重建:计算 \(\hat{A} = \sigma(Z Z^\top)\)。
- 计算损失:
- 重建项:交叉熵损失,在训练时通常对正边(\(A_{ij}=1\))和随机采样的负边(\(A_{ij}=0\))计算。
- KL项:按步骤5公式计算。
- 反向传播:通过梯度下降更新GCN权重。
步骤7:关键点与扩展
- GAE 是确定性编码器,直接输出嵌入向量。
- VGAE 是概率性编码器,能学习嵌入的分布,提高泛化性,适合链路预测等任务。
- 解码器使用简单的内积,计算高效,但可能忽略高阶结构;可替换为更复杂的解码器。
- 训练时通常不输入全部非边,而是对每个正边采样若干负边,以降低计算复杂度。
总结
本题目从图自编码器的基本结构出发,详细讲解了编码器(GCN)、解码器(内积+sigmoid)、重建损失,并深入推导了变分图自编码器的证据下界,包括概率假设、KL散度计算、重参数化技巧。VGAE通过变分推断学习节点嵌入的分布,是图表示学习中重要的无监督方法。