深度学习中的元学习(Meta-Learning)小样本学习(Few-Shot Learning)中的原型网络(Prototype Networks)算法原理与度量学习机制
题目描述
在机器学习中,我们常常面临数据稀缺的问题,尤其是在某些领域(如医疗影像、罕见事件检测)中,收集大量标注数据成本高昂。小样本学习(Few-Shot Learning, FSL)旨在让模型仅通过极少数(例如,每类仅1或5个)标注样本就能快速学习新任务。元学习(Meta-Learning)是实现小样本学习的一种主流框架,其核心思想是让模型学会“如何学习”,即在大量不同但相关的任务上进行训练,使其能够快速适应新任务。
原型网络(Prototype Networks)是元学习中小样本学习的一个经典算法。它基于一个直观的想法:每个类别可以通过其样本在特征空间中的均值(即“原型”)来表示。对新样本进行分类时,只需计算其与各类原型的距离,选择最近的原型对应的类别作为预测结果。该算法简洁高效,且天然地融入了度量学习(Metric Learning)的思想,通过训练优化特征嵌入函数,使得同类样本在嵌入空间中聚集,不同类样本相互远离。
解题过程(原理与机制循序渐进讲解)
步骤1:问题形式化——N-Way K-Shot 学习任务
原型网络处理的是典型的N-way K-shot分类任务。
- N-way:每个任务包含N个不同的类别。
- K-shot:每个类别提供K个带标签的样本作为支持集(Support Set)。
- 查询集(Query Set):同一批类别中,另一些未标记的样本,用于评估模型分类性能。
例如,一个5-way 1-shot任务:支持集包含5个类别,每个类别1个样本(共5个);查询集包含同一5个类别中另一些样本(如每个类别5个,共25个),模型需要预测这些查询样本的标签。
步骤2:核心思想——类原型计算
原型网络的核心理念是为每个类别计算一个“原型”(Prototype),作为该类在特征空间中的代表点。
- 设有一个特征嵌入函数 \(f_\phi\)(通常是一个神经网络,如CNN),参数为 \(\phi\),它将输入图像 \(x\) 映射到一个 \(D\) 维特征向量 \(f_\phi(x) \in \mathbb{R}^D\)。
- 对于一个包含N个类别的任务,其支持集记为 \(S = \{ (x_i, y_i) \}_{i=1}^{N \times K}\),其中 \(y_i \in \{1, 2, ..., N\}\)。
- 对于每个类别 \(c\),其原型 \(\mathbf{p}_c\) 定义为该类所有支持样本特征向量的均值:
\[\mathbf{p}_c = \frac{1}{|S_c|} \sum_{(x_i, y_i) \in S_c} f_\phi(x_i) \]
其中 \(S_c\) 是支持集中标签为 \(c\) 的样本集合。
直观理解:在训练良好的嵌入空间中,同类样本的特征应该很接近,因此它们的均值(原型)能够很好地代表该类别的中心。
步骤3:距离度量与分类决策
对于查询集 \(Q\) 中的一个样本 \(x\),我们通过特征嵌入函数得到其表示 \(f_\phi(x)\)。分类决策基于该查询特征与所有类原型之间的某种距离(通常是欧氏距离)。
- 计算查询特征与每个原型 \(\mathbf{p}_c\) 的平方欧氏距离:
\[d(f_\phi(x), \mathbf{p}_c) = \| f_\phi(x) - \mathbf{p}_c \|^2_2 \]
- 然后,通过一个softmax函数将这些距离转化为属于各个类别的概率:
\[p_\phi(y = c | x) = \frac{\exp(-d(f_\phi(x), \mathbf{p}_c))}{\sum_{c'=1}^{N} \exp(-d(f_\phi(x), \mathbf{p}_{c'}))} \]
即,距离越近,属于该类的概率越高。预测时,选择概率最高的类别 \(\arg\max_c p_\phi(y=c|x)\)。
步骤4:训练目标——最小化负对数似然损失
模型的目标是学习一个良好的特征嵌入函数 \(f_\phi\),使得同类样本聚集、异类样本分离。这通过最小化查询集上的分类损失来实现。
- 对于一个元训练任务,其损失函数定义为查询集上所有样本的负对数似然损失(即交叉熵损失)之和:
\[\mathcal{L}(\phi) = -\sum_{(x_j, y_j) \in Q} \log p_\phi(y = y_j | x_j) \]
- 注意,支持集仅用于计算原型,不直接参与损失计算;损失仅基于查询集的预测。
步骤5:元训练过程(Episode-based Training)
原型网络采用与测试环境一致的“分集式训练”(Episode-based Training),这是元学习的典型方式。
- 构建元训练集:从大规模数据集(如miniImageNet)中采样大量N-way K-shot任务。每个任务称为一个“episode”。
- 单个episode训练步骤:
- 采样一个任务:随机选择N个类别,每类随机选取K个样本作为支持集,另选一批样本作为查询集。
- 使用当前嵌入网络 \(f_\phi\) 计算每个类别的原型 \(\mathbf{p}_c\)。
- 计算查询集样本的预测概率和损失 \(\mathcal{L}(\phi)\)。
- 通过梯度下降(如SGD或Adam)更新网络参数 \(\phi\),最小化损失。
- 重复:在大量episode上迭代训练,使模型学会提取对分类任务泛化性强的特征。
步骤6:度量学习的视角
原型网络本质上是度量学习的一种形式:
- 学习度量空间:嵌入函数 \(f_\phi\) 将原始数据映射到一个度量空间(通常是欧氏空间)。好的嵌入应该使得在该空间中,简单的最近邻分类(此处是到原型的距离)就能取得好效果。
- 与孪生网络、匹配网络的关系:孪生网络(Siamese Networks)通过成对样本的对比学习相似度;匹配网络(Matching Networks)使用注意力机制将查询样本与整个支持集加权比较。原型网络则更进一步,通过计算类别原型,将比较简化为与N个原型的距离计算,效率更高,且对噪声更鲁棒(因为原型是均值,平滑了单个样本的噪声)。
步骤7:关键优势与变体
- 简洁高效:无需复杂的注意力机制或二阶优化,计算成本低。
- 零样本学习扩展:原型可以来自类别语义描述(如词向量)而非图像,轻松扩展到零样本学习(Zero-Shot Learning)。
- 变体:
- 高斯原型网络:将原型视为高斯分布的均值,同时学习协方差矩阵,以更精细地建模类别分布。
- 少样本回归:将原型思想扩展到回归任务,如预测样本的属性值。
- 半监督原型网络:利用未标注的支持样本(如额外的未标注图像)通过软分配或一致性正则化改进原型计算。
总结
原型网络通过一个简单而强大的思想——在嵌入空间中为每个类别计算一个原型,然后基于距离进行最近邻分类——成功解决了小样本学习问题。它将度量学习自然地融入到元学习框架中,通过大量episode训练,使模型学习到可迁移的特征表示。其优雅的设计、高效的计算以及良好的性能,使其成为小样本学习领域的一个基础性算法。