深度学习中的元学习(Meta-Learning)算法原理与MAML(Model-Agnostic Meta-Learning)框架
题目描述
元学习(Meta-Learning)是让模型学会如何学习的方法,其核心目标是通过在多个相关任务上训练,使模型能够快速适应新任务。MAML(Model-Agnostic Meta-Learning)是一种经典的元学习算法,它不依赖特定模型结构,而是通过优化模型初始参数,使得从该参数出发,仅需少量梯度更新就能在新任务上达到高性能。例如,在少样本分类任务中,MAML的目标是让模型通过少量标注样本(如5个样本)快速学习新类别。
解题过程
-
问题定义
- 假设存在一个任务分布 \(p(\mathcal{T})\),每个任务 \(\mathcal{T}_i\) 包含训练集(支持集)和测试集(查询集)。
- 目标:找到一组模型初始参数 \(\theta \,使得对任意新任务 \( \mathcal{T}_i\),从 \(\theta\) 开始,经过一步或几步梯度下降后,模型在 \(\mathcal{T}_i\) 上的损失最小。
-
MAML的双层优化结构
- 内层更新(任务特定适应):
对于每个任务 \(\mathcal{T}_i\),从初始参数 \(\theta\) 出发,使用支持集计算损失 \(\mathcal{L}_{\mathcal{T}_i}(f_\theta)\),并通过梯度下降更新得到任务特定参数:
- 内层更新(任务特定适应):
\[ \theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta) \]
其中 $ \alpha $ 为内层学习率。
- 外层更新(元优化):
使用查询集计算所有任务在适应后参数 \(\theta_i'\) 上的损失之和,并优化初始参数 \(\theta\):
\[ \min_\theta \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'}) \]
通过梯度下降更新 $ \theta $:
\[ \theta \leftarrow \theta - \beta \nabla_\theta \sum_{\mathcal{T}_i} \mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'}) \]
其中 $ \beta $ 为外层学习率。
- 梯度计算的关键点
- 外层梯度需考虑内层更新对 \(\theta\) 的依赖,因此需要计算二阶导数(Hessian矩阵)。具体地:
\[ \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'}) = \nabla_{\theta_i'} \mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'}) \cdot \nabla_\theta (\theta_i') \]
其中 $ \nabla_\theta (\theta_i') = I - \alpha \nabla_\theta^2 \mathcal{L}_{\mathcal{T}_i}(f_\theta) $。
- 为简化计算,MAML常使用一阶近似(FOMAML),忽略二阶项,直接假设 \(\nabla_\theta (\theta_i') \approx I\)。
-
算法实现步骤
- 随机初始化参数 \(\theta\)。
- 循环以下步骤直至收敛:
a. 采样一批任务 \(\{\mathcal{T}_i\}\)。
b. 对每个任务,计算内层更新后的参数 \(\theta_i'\)。
c. 计算所有任务在 \(\theta_i'\) 上的查询损失之和。
d. 通过反向传播计算梯度并更新 \(\theta\)。
-
应用示例:5样本分类
- 每个任务包含5张新类别图片(支持集)和15张查询图片。
- 内层更新:用支持集计算损失,更新 \(\theta\) 到 \(\theta_i'\)。
- 外层更新:用查询集评估 \(\theta_i'\) 的泛化能力,优化 \(\theta\) 以提升跨任务适应性。
总结
MAML通过双层优化实现了模型初始参数的元学习,使其成为快速适应新任务的强基线。其核心思想是让梯度更新方向兼顾多任务共性,而非仅优化单一任务性能。