基于图神经网络的零样本图像分类算法
题目描述
在传统的图像分类任务中,模型通常在训练时学习一个固定的、封闭的类别集合。然而,现实世界中存在海量类别,我们无法为所有类别都收集足够多的标注数据。零样本图像分类 的目标是让模型能够识别在训练阶段从未见过的类别。
核心挑战在于,对于这些“未见类”,我们没有其任何训练样本。为了解决这个问题,零样本学习通常引入语义信息 作为桥梁,将视觉特征和类别标签联系起来。这些语义信息可以是类别的属性描述、词向量等。图神经网络擅长处理具有关系结构的数据,因此被用来建模类别之间的复杂语义关系,从而将知识从“已见类”有效地迁移到“未见类”。
本题目将详细讲解一种基于图神经网络的零样本图像分类算法 的核心思想与实现步骤。
解题过程循序渐进讲解
第一步:问题形式化与核心思想
-
定义符号:
- 已见类集合:\(\mathcal{Y}^{s} = \{y_1^s, ..., y_{C^s}^s\}\),共有 \(C^s\) 个类别,每个类别有足够的标注图像。
- 未见类集合:\(\mathcal{Y}^{u} = \{y_1^u, ..., y_{C^u}^u\}\),共有 \(C^u\) 个类别,训练时没有其图像样本。且 \(\mathcal{Y}^{s} \cap \mathcal{Y}^{u} = \varnothing\)。
- 所有类别集合:\(\mathcal{Y} = \mathcal{Y}^{s} \cup \mathcal{Y}^{u}\)。
- 语义信息:每个类别 \(y_i\) 都有一个对应的语义向量 \(\mathbf{a}_i \in \mathbb{R}^d\)。这个向量可以来自于:
- 人工定义的属性:例如,“有羽毛”、“会飞”、“是哺乳动物”等。
- 从大型语料库(如维基百科)学习到的词向量:例如 Word2Vec, GloVe。
- 语言模型(如BERT)生成的描述向量。
-
核心思想:
- 训练一个模型,它能够将输入的图像特征映射到一个语义嵌入空间。
- 在测试时,对于一个属于未见类的图像,模型提取其特征并映射到语义空间,然后与所有类别(包括已见和未见)的语义向量进行比较,通过计算相似度(如余弦相似度)来预测其类别标签。
- 图神经网络的作用是增强类别的语义表示。它不仅仅使用原始的、孤立的语义向量 \(\mathbf{a}_i\),而是将所有这些类别构建成一个图,节点是类别,边表示类别间的关系(如相似性)。GNN通过聚合邻居节点的信息,为每个类别生成一个上下文感知的、更鲁棒的语义表示 \(\mathbf{g}_i\)。这个新表示 \(\mathbf{g}_i\) 比原始的 \(\mathbf{a}_i\) 蕴含了更丰富的结构化知识,有助于在语义空间中进行更准确的匹配。
第二步:构建类别关系图
这是GNN发挥作用的基础。我们需要为所有类别 \(\mathcal{Y}\) 构建一个图 \(G = (V, E)\)。
- 节点:每个类别 \(y_i\) 对应一个图节点 \(v_i\)。节点的初始特征就是其语义向量 \(\mathbf{h}_i^{(0)} = \mathbf{a}_i\)。
- 边:定义节点之间的连接关系。常见的方法有:
- 全连接:所有节点两两相连。虽然简单,但可能引入噪音。
- K-最近邻:计算每对语义向量 \(\mathbf{a}_i\) 和 \(\mathbf{a}_j\) 之间的相似度(如余弦相似度),然后为每个节点只保留与其最相似的K个节点的边。
- 基于知识图谱:如果类别存在于知识图谱(如WordNet)中,则可以直接利用其中的层级(is-a)关系或相关(related-to)关系来定义边。
第三步:设计图神经网络进行语义传播
我们使用GNN来更新每个节点(类别)的特征。这里以常见的图卷积网络 的一层操作为例:
- 消息传递:对于每个节点 \(v_i\),它从其一阶邻居节点 \(\mathcal{N}(i)\) 收集信息。信息通常是邻居节点上一层的特征。
- 聚合:将收集到的邻居信息聚合成一个单一的向量。常用的聚合函数有求和、均值、最大值。
\[ \mathbf{m}_i^{(l)} = \text{AGGREGATE}^{(l)}(\{ \mathbf{h}_j^{(l-1)} : j \in \mathcal{N}(i) \}) \]
- 更新:将节点自身上一层的特征 \(\mathbf{h}_i^{(l-1)}\) 和聚合得到的邻居信息 \(\mathbf{m}_i^{(l)}\) 结合起来,并通过一个可学习的权重矩阵 \(\mathbf{W}^{(l)}\) 和非线性激活函数 \(\sigma\)(如ReLU)来更新当前节点的特征。
\[ \mathbf{h}_i^{(l)} = \sigma ( \mathbf{W}^{(l)} \cdot \text{CONCAT}( \mathbf{h}_i^{(l-1)}, \mathbf{m}_i^{(l)} ) ) \]
这里 `CONCAT` 表示向量拼接。也可以使用更复杂的GNN层,如GAT(图注意力网络),它会为不同的邻居分配不同的注意力权重。
- 堆叠多层:我们将上述操作堆叠 \(L\) 层。经过 \(L\) 层的信息传播后,每个节点最终的特征 \(\mathbf{h}_i^{(L)}\) 就包含了其 \(L\)-hop 邻居的语义信息。我们将其作为该类别增强后的语义表示:\(\mathbf{g}_i = \mathbf{h}_i^{(L)}\)。
第四步:训练视觉→语义映射函数
这是模型训练的核心部分。我们只使用已见类 \(\mathcal{Y}^{s}\) 的图像-标签数据对进行训练。
- 视觉特征提取:使用一个预训练好的卷积神经网络(如ResNet)作为图像编码器 \(f_{vis}(\cdot)\)。输入一张已见类的图像 \(x\),得到其视觉特征向量 \(\mathbf{v} = f_{vis}(x)\)。
- 语义投影:设计一个投影函数 \(f_{proj}: \mathbb{R}^{D_{vis}} \to \mathbb{R}^{d}\),通常是一个或多个全连接层。它将高维视觉特征 \(\mathbf{v}\) 映射到与语义向量相同维度的空间,得到图像的视觉语义嵌入 \(\phi(x) = f_{proj}(\mathbf{v})\)。
- 计算相似度与损失:
- 对于图像 \(x\) 及其真实标签 \(y \in \mathcal{Y}^s\),我们得到其视觉嵌入 \(\phi(x)\)。
- 从GNN中获取其对应类别增强后的语义表示 \(\mathbf{g}_y\)。
- 目标是让 \(\phi(x)\) 和 \(\mathbf{g}_y\) 尽可能接近,同时远离其他类别的语义表示 \(\mathbf{g}_{j} (j \neq y)\)。
- 常用的损失函数是交叉熵损失。我们将 \(\phi(x)\) 与所有已见类的增强语义表示 \(\{ \mathbf{g}_j^s \}\) 计算余弦相似度,得到一个相似度分数向量,然后通过Softmax归一化为概率分布,最后计算与真实标签的交叉熵损失。
\[ p(y_i^s | x) = \frac{\exp(\text{cosine}(\phi(x), \mathbf{g}_i^s) / \tau)}{\sum_{j=1}^{C^s} \exp(\text{cosine}(\phi(x), \mathbf{g}_j^s) / \tau)} \]
\[ \mathcal{L}_{cls} = -\log p(y | x) \]
其中 $ \tau $ 是温度系数。
- 对抗训练:为了增强模型的泛化能力,防止其过拟合到已见类的视觉-语义映射上,有时会引入一个对抗训练 的判别器。判别器试图区分嵌入是来自“已见类”还是“未见类”,而特征提取和投影网络则试图“欺骗”判别器,让两者不可分。这迫使视觉语义嵌入的空间更具一般性,有利于迁移到未见类。
第五步:测试与推理
当模型训练完成后,就可以对包含未见类的图像进行分类了。
- 获取所有类别的增强语义表示:将所有类别(包括已见和未见)的原始语义向量 \(\{ \mathbf{a}_i \}_{i=1}^{C^s + C^u}\) 输入到已经训练好的GNN中。GNN会输出所有类别增强后的语义表示 \(\{ \mathbf{g}_i \}_{i=1}^{C^s + C^u}\)。注意,GNN是在所有类别的语义关系图上训练的,因此它也能为未见类生成有意义的增强表示。
- 处理测试图像:对于一张测试图像 \(x_{test}\),我们通过训练好的视觉编码器 \(f_{vis}\) 和投影函数 \(f_{proj}\) 得到其视觉语义嵌入 \(\phi(x_{test})\)。
- 最近邻搜索:计算 \(\phi(x_{test})\) 与所有类别的增强语义表示 \(\{ \mathbf{g}_i \}\) 之间的余弦相似度。
- 预测标签:选择相似度最高的类别作为预测结果:
\[ \hat{y} = \arg\max_{y_i \in \mathcal{Y}} \text{cosine}( \phi(x_{test}), \mathbf{g}_i ) \]
由于 $ \{ \mathbf{g}_i \} $ 包含了未见类的表示,因此模型可以预测出训练时没见过的类别。
总结
基于GNN的零样本图像分类算法的核心贡献在于,它不仅仅将每个类别视为独立的语义点,而是利用GNN显式地建模了类别之间的语义关联图。这使得:
- 知识迁移:已见类的信息可以通过图结构有效地传播到语义相似的未见类节点上,从而增强了未见类的语义表示。
- 缓解域偏移:GNN生成的增强表示能更好地捕捉类别的本质特征,使得视觉特征到语义空间的投影函数学习得更具泛化性,减少了对已见类视觉特征的过拟合。
这种方法巧妙地将视觉感知、语义理解和图结构推理结合在一起,为解决“认识未知”这一核心挑战提供了有力工具。