深度学习中的随机权重平均(Stochastic Weight Averaging, SWA)算法原理与优化过程
字数 2994 2025-12-11 06:28:39
深度学习中的随机权重平均(Stochastic Weight Averaging, SWA)算法原理与优化过程
题目描述
随机权重平均(SWA)是一种在深度学习训练后期使用的优化技术。其核心思想是:在模型训练接近收敛时,不再使用最终的单一权重点,而是对训练过程中访问到的多个权重点进行平均,从而得到一个更平坦、泛化能力更强的解。SWA 通常能显著提升模型的测试性能,且计算开销很小。请详细讲解 SWA 的动机、算法步骤、理论依据及其实现细节。
解题过程(循序渐进讲解)
第一步:问题背景与动机
-
传统训练的局限性:
- 标准训练(如SGD)通常在训练损失达到一个较低点后停止,并保存该时刻的权重。
- 然而,深度神经网络的损失曲面非常复杂,存在许多局部最小值。这些最小值中,“平坦”的极小值通常比“尖锐”的极小值具有更好的泛化能力。
- 最终收敛的权重点可能恰好位于一个“尖锐”的极小值内,对训练数据的小扰动敏感,导致测试性能不稳定。
-
SWA的直观想法:
- 如果能在训练后期,沿着损失曲面收集多个不同的权重点(它们都处于低损失区域),并对它们进行平均。
- 平均后的权重有望落在这些点之间的某个区域,这个区域很可能是一个更平坦的极小值盆地,从而获得更好的泛化性能。
第二步:算法核心步骤
假设我们使用 SGD 或 Adam 等优化器进行训练。SWA 的操作分为两个阶段:
-
预热阶段(Warm-up Phase):
- 使用常规优化器(如带动量的SGD)训练模型一定周期(例如总训练周期的75%或自定义周期数),使模型初步收敛。
- 此阶段不使用SWA,只是正常训练。预热阶段结束后,学习率可能已经按照计划下降到了一个较小的值。
-
SWA平均阶段(Averaging Phase):
- 学习率调整:进入SWA阶段后,通常使用一个较高且恒定的学习率(例如0.01)或一个循环学习率(Cyclic LR)。高学习率可以使权重在平坦极小值区域周围“游走”,探索不同的低损失点。
- 权重点采样:在此阶段的训练过程中,以固定的频率(例如每个epoch结束时)记录当前模型的权重。
- 权重更新:不直接使用这些采样点的权重进行预测,而是维护一个运行平均权重(Running Average Weight)。
- 平均公式:假设在第 \(t\) 次采样时,当前模型权重为 \(w_t\),当前的平均权重为 \(w_{swa}\),则更新规则为:
\[ w_{swa} \leftarrow \frac{w_{swa} \cdot n_{models} + w_t}{n_{models} + 1} \]
其中 $ n_{models} $ 是此前已平均的模型数量。实际操作中,可以更高效地写为:
\[ w_{swa} \leftarrow w_{swa} \cdot \alpha + w_t \cdot (1 - \alpha) \]
其中 $ \alpha = \frac{n_{models}}{n_{models} + 1} $。通常初始化 $ w_{swa} = w_1 $(第一个采样点的权重)。
- 推断阶段:
- 训练结束后,使用计算得到的平均权重 \(w_{swa}\) 替换模型权重,用于后续的测试和部署。
第三步:关键细节与理论解释
-
为什么高学习率有效?
- 在训练后期,权重已经接近收敛区域。一个较高的恒定学习率会阻止权重完全收敛到某个尖锐的极小点,而是使其在平坦极小值区域的边界附近振荡。采样这些振荡点进行平均,相当于对这个平坦区域进行“探测”和“平滑”。
-
SWA与集成学习(Ensemble)的区别:
- 集成学习是独立训练多个模型,在推断时对所有模型的输出进行平均(或投票)。这需要存储多个完整模型,计算成本高。
- SWA只对权重进行平均,得到一个单一的模型。存储和计算成本与普通模型无异,是一种高效的“隐式集成”。
-
理论保障:中心化定理
- SWA可以看作是在SGD迭代路径上对权重进行平均。有理论表明,当SGD使用恒定学习率或循环学习率在凸损失曲面(或局部凸区域)中运行时,对迭代点进行平均可以收敛到该区域中心一个更优的解。尽管神经网络是非凸的,但大量实验证明此方法在局部凸区域依然有效。
-
批归一化(BatchNorm)层的特殊处理:
- 如果模型包含BatchNorm层,在训练结束后,不能直接使用平均权重 \(w_{swa}\) 进行推断。
- 原因:BatchNorm层在训练时维护了运行均值和方差。SWA平均的权重来自不同时刻,对应的BatchNorm统计量(均值和方差)并不一致。直接使用 \(w_{swa}\) 和最终的BatchNorm统计量会导致不一致。
- 解决方法:在训练集(或一个大型子集)上,使用平均后的权重 \(w_{swa}\) 对模型进行一次前向传播(不反向传播),目的是重新计算并更新BatchNorm层的运行均值和方差。之后,模型才能用于测试。
第四步:算法流程总结
- 输入:模型 \(M\),训练数据,优化器(如SGD),总训练周期 \(T\),SWA开始周期 \(T_{start}\)(如0.75T)。
- 初始化:\(w_{swa} \leftarrow None\),\(n \leftarrow 0\)。
- For epoch = 1 to \(T\):
- 使用优化器正常训练一个epoch。
- If epoch >= \(T_{start}\) (进入SWA阶段):
- 将优化器学习率调整为较高的恒定值或循环计划。
- 在每个epoch结束时(或每K个迭代后):
- 记录当前权重 \(w_{current}\)。
- If \(w_{swa}\) is None:
- \(w_{swa} \leftarrow w_{current}\)
- Else:
- \(n \leftarrow n + 1\)
- \(w_{swa} \leftarrow w_{swa} \cdot \frac{n}{n+1} + w_{current} \cdot \frac{1}{n+1}\)
- 训练后处理:
- 将模型权重设置为 \(w_{swa}\)。
- 如果模型有BatchNorm层,在训练数据上运行一次前向传播以更新其运行统计量。
- 输出:优化后的模型 \(M\)(权重为 \(w_{swa}\))。
第五步:优势与注意事项
-
优势:
- 几乎零额外成本:只需在训练后期多存储一个平均权重变量,计算开销极小。
- 显著提升泛化能力:在多种任务(图像分类、语义分割、语言建模等)上都能稳定提升1-2个百分点的测试精度。
- 缓解过拟合:通过找到更平坦的解,增强了模型鲁棒性。
-
注意事项:
- 起始时机:SWA必须在模型初步收敛后开始,否则可能平均到较差的权重。
- 学习率策略:SWA阶段的学习率策略是关键,高学习率是“探索”平坦区域的核心。
- BatchNorm处理:务必进行前述的统计量更新,否则性能可能下降。
通过以上步骤,SWA 巧妙地利用训练后期权重的轨迹信息,通过简单的平均操作,引导模型走向泛化性能更优的平坦解区域。