深度学习中优化器的AdamP与SAM(Sharpness-Aware Minimization)结合算法原理与实现细节
题目描述
AdamP是一种结合了自适应学习率(Adam风格)和路径方向修正(Pathwise Direction Correction)的优化器,旨在解决传统Adam在特定任务中可能出现的收敛问题(如泛化性能下降)。而SAM(Sharpness-Aware Minimization)是一种优化框架,通过同时最小化损失值和损失函数的尖锐度(Sharpness)来提升模型泛化能力。本题将详细讲解如何将AdamP与SAM结合,形成AdamP-SAM算法,并分析其原理与实现细节。
1. 背景知识:AdamP与SAM的核心思想
(1)AdamP的改进动机
- 传统Adam的缺陷:在训练后期,自适应学习率可能导致参数在局部极小值附近震荡,影响泛化性能。
- AdamP的解决方案:在参数更新时,对Adam的更新方向进行投影,剔除与参数本身方向相关的分量(避免“径向路径”问题),使优化更稳定。
(2)SAM的优化目标
- 尖锐度定义:损失函数在参数邻域内的最大值与当前点的差值,反映损失曲面的平坦程度。
- SAM的优化形式:
\[ \min_{\mathbf{w}} \max_{\|\epsilon\|_2 \leq \rho} L(\mathbf{w} + \epsilon) \]
即同时优化参数\(\mathbf{w}\)和邻域内最坏情况的损失。
2. AdamP-SAM算法的分步推导
步骤1:SAM的双步更新框架
SAM的原始实现分为两步:
- 内层最大化:寻找使损失最大的扰动\(\epsilon\):
\[ \epsilon_t = \rho \cdot \frac{\nabla L(\mathbf{w}_t)}{\|\nabla L(\mathbf{w}_t)\|_2} \]
- 外层最小化:基于扰动后的梯度更新参数:
\[ \mathbf{w}_{t+1} = \mathbf{w}_t - \eta \cdot \nabla L(\mathbf{w}_t + \epsilon_t) \]
步骤2:将AdamP嵌入SAM的更新过程
传统SAM使用SGD或动量法进行外层最小化,而AdamP-SAM将外层更新替换为AdamP:
- 内层最大化(与原始SAM相同):
\[ \epsilon_t = \rho \cdot \frac{\nabla L(\mathbf{w}_t)}{\|\nabla L(\mathbf{w}_t)\|_2 + \delta} \quad (\delta为微小常数,防止除零) \]
- 外层最小化(使用AdamP):
- 计算扰动后的梯度:\(g_t = \nabla L(\mathbf{w}_t + \epsilon_t)\)
- 计算Adam的一阶矩(动量)和二阶矩(自适应学习率):
\[ m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t \]
\[ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \]
- 偏差修正:
\[ \hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \]
- AdamP的关键步骤:对更新方向进行投影修正。
- 计算原始更新向量:\(u_t = \hat{m}_t / (\sqrt{\hat{v}_t} + \delta)\)
- 计算投影分量(剔除与参数径向重叠的部分):
\[ \text{proj}_t = \frac{\langle u_t, \mathbf{w}_t \rangle}{\|\mathbf{w}_t\|_2^2 + \delta} \cdot \mathbf{w}_t \]
- 修正后的更新方向:$u_t' = u_t - \text{proj}_t$
- 参数更新:
\[ \mathbf{w}_{t+1} = \mathbf{w}_t - \eta \cdot u_t' \]
3. 算法实现细节与超参数选择
(1)超参数说明
- \(\rho\)(SAM的邻域半径):通常取0.05~0.1,控制尖锐度优化的强度。
- \(\eta\)(学习率):需略低于标准Adam(因SAM的梯度幅度更大)。
- \(\beta_1, \beta_2\):建议默认取0.9和0.999。
(2)代码实现要点(PyTorch伪代码)
class AdamP_SAM(torch.optim.Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), rho=0.05):
defaults = dict(lr=lr, betas=betas, rho=rho)
super().__init__(params, defaults)
def step(self, closure=None):
loss = closure() if closure else None
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
# 内层最大化:计算扰动epsilon
grad_norm = torch.norm(grad)
epsilon = group['rho'] * grad / (grad_norm + 1e-12)
# 在扰动点计算梯度(需二次反向传播)
p.data.add_(epsilon)
loss_perturbed = closure()
p.data.sub_(epsilon)
grad_perturbed = torch.autograd.grad(loss_perturbed, p)[0]
# AdamP更新步骤
state = self.state[p]
if len(state) == 0:
state['m'] = torch.zeros_like(p)
state['v'] = torch.zeros_like(p)
state['step'] = 0
m, v = state['m'], state['v']
beta1, beta2 = group['betas']
state['step'] += 1
m.mul_(beta1).add_(grad_perturbed, alpha=1-beta1)
v.mul_(beta2).add_(grad_perturbed**2, alpha=1-beta2)
m_hat = m / (1 - beta1**state['step'])
v_hat = v / (1 - beta2**state['step'])
u = m_hat / (v_hat.sqrt() + 1e-8)
# 投影修正
proj = (u.flatten() @ p.data.flatten()) / (p.data.norm()**2 + 1e-8)
u_corrected = u - proj * p.data
p.data.add_(-group['lr'] * u_corrected)
return loss
4. 算法优势与适用场景
- 优势:
- 结合SAM的泛化提升能力与AdamP的稳定收敛特性。
- 尤其适用于视觉、自然语言处理中需要强泛化能力的任务(如对抗训练、少样本学习)。
- 缺点:
- 计算开销大(每个step需两次前向+两次反向传播)。
- 超参数调优更复杂(需平衡\(\rho\)与学习率)。
总结
AdamP-SAM通过将自适应学习率修正与尖锐度最小化结合,在保持AdamP稳定性的同时提升了模型泛化能力。其核心在于:
- 用SAM框架寻找扰动梯度以平滑损失曲面;
- 用AdamP的投影机制避免无效的径向更新。
实际应用中需根据任务复杂度权衡计算成本与性能收益。