深度学习中的优化器之SGD with Polyak Averaging算法原理与实现细节
字数 1376 2025-11-13 22:38:14
深度学习中的优化器之SGD with Polyak Averaging算法原理与实现细节
题目描述
SGD with Polyak Averaging是一种结合随机梯度下降与参数平均化的优化技术。其核心思想是在模型训练过程中,对历史参数值进行加权平均,而非直接使用最新的参数。这种方法能有效平滑优化路径,提升模型在测试集上的泛化能力,特别适用于非凸优化问题中振荡收敛的场景。
解题过程循序渐进讲解
1. 基础SGD的局限性分析
- 标准SGD每次迭代按负梯度方向更新参数:
\(\theta_{t+1} = \theta_t - \eta \nabla f_t(\theta_t)\)
其中 \(\eta\) 为学习率,\(\nabla f_t\) 为当前批次梯度。 - 问题:在损失函数曲面的平坦区域或鞍点附近,梯度估计的噪声会导致参数在最优解附近振荡,影响收敛稳定性。
2. Polyak Averaging的核心思想
- 维护参数的指数移动平均值(Exponential Moving Average, EMA):
\(\theta_{\text{avg},t} = \beta \cdot \theta_{\text{avg},t-1} + (1-\beta) \cdot \theta_t\)
其中 \(\beta \in [0,1)\) 为衰减率,控制历史参数的权重。 - 物理意义:通过加权平均抑制参数更新中的高频振荡,使优化轨迹更平滑。
3. 算法具体步骤
(1)初始化参数 \(\theta_0\),平均参数 \(\theta_{\text{avg},0} = \theta_0\),设定衰减率 \(\beta\)(常取0.99)
(2)对于每轮迭代 \(t=1,2,\dots,T\):
- 采样小批量数据,计算梯度 \(g_t = \nabla f_t(\theta_{t-1})\)
- 更新参数:\(\theta_t = \theta_{t-1} - \eta g_t\)
- 更新平均参数:\(\theta_{\text{avg},t} = \beta \cdot \theta_{\text{avg},t-1} + (1-\beta) \cdot \theta_t\)
(3)训练完成后,使用平均参数 \(\theta_{\text{avg},T}\) 作为最终模型参数
4. 关键参数的作用
- 衰减率 \(\beta\):
- \(\beta\) 越大,历史参数权重越高,平滑效果越强但响应延迟增加
- 典型取值0.9-0.999,需通过验证集调整
- 学习率 \(\eta\):
- 需与 \(\beta\) 协同调节,较大 \(\beta\) 可配合稍大学习率
5. 实现细节与代码示例
import torch
import torch.nn as nn
class SGDWithPolyakAveraging:
def __init__(self, model, lr=0.01, beta=0.99):
self.model = model
self.lr = lr
self.beta = beta
self.step_count = 0
# 初始化平均参数(深拷贝)
self.avg_params = [p.data.clone() for p in model.parameters()]
def step(self):
self.step_count += 1
with torch.no_grad():
for i, param in enumerate(self.model.parameters()):
if param.grad is None:
continue
# SGD更新
param.data -= self.lr * param.grad
# Polyak平均更新
self.avg_params[i] = (
self.beta * self.avg_params[i] +
(1 - self.beta) * param.data
)
def swap_to_averaged(self):
"""将模型参数替换为平均参数(推理时调用)"""
for i, param in enumerate(self.model.parameters()):
param.data.copy_(self.avg_params[i])
6. 算法优势分析
- 收敛稳定性:通过平均操作抑制梯度噪声影响
- 泛化提升:平滑的参数轨迹降低过拟合风险
- 兼容性:可与动量法、自适应学习率等方法结合使用
7. 实际应用注意事项
- 训练阶段仍使用原始参数计算梯度,仅最后推理时切换为平均参数
- 在分布式训练中需同步所有节点的平均参数
- 对于周期性学习率调度,需调整 \(\beta\) 以适应学习率变化