深度学习中优化器的SGD with Decoupled Weight Decay (SGDW) 算法原理与实现细节
我们已讲过多个优化器(如AdamW、LAMB、SGDW等),为了避免重复,这次我将为您讲解SGD with Decoupled Weight Decay (SGDW) 算法,重点剖析其核心思想、与标准SGD with Weight Decay的区别、数学推导及实现细节。
1. 题目描述
SGD with Decoupled Weight Decay (SGDW) 是一种改进的随机梯度下降算法,由Ilya Loshchilov和Frank Hutter在论文《Decoupled Weight Decay Regularization》中提出。该算法揭示了传统SGD中权重衰减(weight decay)与L2正则化在自适应优化器(如Adam)中的等价性不成立的问题,并提出将权重衰减从梯度更新中解耦,直接应用于权重本身,从而在SGD和Adam等优化器中实现更有效的正则化。本题目将详细讲解SGDW的动机、算法步骤、数学原理及代码实现。
2. 问题背景:权重衰减与L2正则化的混淆
在深度学习优化中,正则化用于防止过拟合。传统SGD中,权重衰减(Weight Decay) 通常通过在损失函数中添加L2正则项实现:
\[L_{\text{total}}(\theta) = L(\theta) + \frac{\lambda}{2} \|\theta\|^2_2 \]
其中\(\lambda\)是权重衰减系数。对参数\(\theta\)求梯度时,正则项的梯度为\(\lambda \theta\),因此SGD的更新规则为:
\[\theta_{t+1} = \theta_t - \eta \nabla L(\theta_t) - \eta \lambda \theta_t \]
这里,权重衰减被实现为梯度更新的一部分。
然而,在自适应优化器(如Adam)中,这种实现方式会导致权重衰减与梯度归一化(如动量、二阶矩估计)耦合,从而减弱正则化效果。SGDW的核心思想是:将权重衰减从梯度计算中分离,直接应用于参数更新。
3. 算法原理:解耦权重衰减
3.1 标准SGD with Weight Decay的问题
对于标准SGD,权重衰减等价于L2正则化,因为:
- 梯度包含\(\nabla L(\theta) + \lambda \theta\)。
- 更新为\(\theta_{t+1} = \theta_t - \eta (\nabla L(\theta_t) + \lambda \theta_t) = (1 - \eta \lambda) \theta_t - \eta \nabla L(\theta_t)\)。
这里的衰减因子\((1 - \eta \lambda)\)在每次迭代中直接缩放权重。
但在自适应优化器中,如Adam,梯度会被除以其二阶矩的平方根(自适应学习率),导致权重衰减项也被缩放,从而与学习率\(\eta\)和梯度统计量耦合,失去原有的正则化强度。
3.2 SGDW的更新规则
SGDW将权重衰减解耦为独立步骤:
\[\theta_{t+1} = \theta_t - \eta \nabla L(\theta_t) - \eta \lambda \theta_t \]
注意:这里\(\lambda \theta_t\)是直接加在更新中的,而不是通过梯度。为了更清晰,我们将其重写为两步:
- 计算梯度更新:\(\Delta \theta_t = -\eta \nabla L(\theta_t)\)
- 应用权重衰减:\(\theta_{t+1} = \theta_t + \Delta \theta_t - \eta \lambda \theta_t\)
等价于:
\[\theta_{t+1} = (1 - \eta \lambda) \theta_t - \eta \nabla L(\theta_t) \]
关键点:权重衰减系数\(\lambda\)与学习率\(\eta\)相乘,这意味着衰减强度与学习率成正比。这与传统SGD相同,但解耦后可以更灵活地调整\(\lambda\)。
4. 算法步骤
SGDW的伪代码如下:
输入:
- 初始参数\(\theta_0\)
- 学习率\(\eta\)
- 权重衰减系数\(\lambda\)
- 动量系数\(\beta\)(可选,SGDW通常与动量结合)
- 损失函数\(L(\theta)\)
算法流程:
- 初始化动量变量\(m_0 = 0\)
- 对于每个迭代\(t = 0, 1, 2, \dots\):
a. 采样小批量数据,计算梯度\(g_t = \nabla L(\theta_t)\)
b. 更新动量(如果使用动量):\(m_{t+1} = \beta m_t + (1 - \beta) g_t\)
c. 计算参数更新(使用动量或原始梯度):\(\Delta \theta_t = -\eta m_{t+1}\)(或\(-\eta g_t\))
d. 应用权重衰减:\(\theta_{t+1} = \theta_t + \Delta \theta_t - \eta \lambda \theta_t\)
注意:权重衰减项\(-\eta \lambda \theta_t\)是直接作用于参数,而非通过梯度。这与AdamW中的思想一致,但应用于SGD框架。
5. 与标准SGD with Weight Decay的区别
- 解耦性:
- 标准SGD:权重衰减通过梯度实现,与梯度耦合。
- SGDW:权重衰减作为独立项,直接调整参数。
- 自适应优化器兼容性:
- 在Adam中,标准权重衰减会被自适应学习率缩放,导致正则化效果不稳定。
- SGDW的解耦方式可以无缝融入Adam(即AdamW),保持衰减强度一致。
- 超参数调整:
- SGDW中,\(\lambda\)的调优更稳定,因为其作用不受梯度归一化影响。
6. 数学推导:为什么解耦有效?
考虑损失函数\(L(\theta)\)和L2正则项:
\[L_{\text{reg}}(\theta) = L(\theta) + \frac{\lambda}{2} \|\theta\|^2_2 \]
梯度为:
\[\nabla L_{\text{reg}}(\theta) = \nabla L(\theta) + \lambda \theta \]
标准SGD更新:
\[\theta_{t+1} = \theta_t - \eta (\nabla L(\theta_t) + \lambda \theta_t) = (1 - \eta \lambda) \theta_t - \eta \nabla L(\theta_t) \]
在自适应优化器中,梯度会被除以\(\sqrt{v_t} + \epsilon\)(其中\(v_t\)是二阶矩估计),因此权重衰减项变为\(\eta \lambda \theta_t / (\sqrt{v_t} + \epsilon)\),其强度随\(v_t\)变化,导致正则化效果不稳定。
SGDW直接应用\(-\eta \lambda \theta_t\),避免被自适应学习率缩放,从而保持一致的衰减强度。
7. 实现细节
以PyTorch风格实现SGDW(带动量):
import torch
def sgdw_optimizer(params, lr=0.01, weight_decay=0.01, momentum=0.9):
"""
SGDW优化器实现。
params: 模型参数(可迭代)
lr: 学习率
weight_decay: 权重衰减系数
momentum: 动量系数
"""
velocities = [torch.zeros_like(p) for p in params]
def step():
for p, v in zip(params, velocities):
if p.grad is None:
continue
# 1. 计算动量更新
v.mul_(momentum).add_(p.grad, alpha=1 - momentum)
# 2. 计算梯度更新(不含权重衰减)
update = -lr * v
# 3. 应用权重衰减(解耦)
p.data.add_(update)
p.data.add_(-lr * weight_decay, p.data)
return step
# 示例使用
model_params = [torch.randn(10, requires_grad=True)]
opt_step = sgdw_optimizer(model_params, lr=0.01, weight_decay=0.001, momentum=0.9)
# 模拟训练循环
for _ in range(100):
loss = model_params[0].sum() # 示例损失
loss.backward()
opt_step()
model_params[0].grad.zero_()
关键点:
- 权重衰减通过
p.data.add_(-lr * weight_decay, p.data)直接应用,而非添加到梯度。 - 动量更新与标准SGD相同,但衰减步骤独立。
8. 与AdamW的关系
SGDW是解耦权重衰减思想在SGD中的应用,而AdamW将其扩展到Adam优化器。AdamW的更新规则为:
\[m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t \]
\[ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \]
\[ \hat{m}_t = m_t / (1 - \beta_1^t), \quad \hat{v}_t = v_t / (1 - \beta_2^t) \]
\[ \theta_{t+1} = \theta_t - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_t \right) \]
注意:这里\(\lambda \theta_t\)是直接加在更新中,而非通过梯度计算。
9. 总结
- SGDW通过将权重衰减从梯度更新中解耦,直接应用于参数,解决了自适应优化器中正则化效果不稳定的问题。
- 其更新规则简单:\(\theta_{t+1} = (1 - \eta \lambda) \theta_t - \eta \nabla L(\theta_t)\),与标准SGD形式相同,但思想上有本质区别。
- 实现时需注意权重衰减项独立于梯度计算,这在PyTorch等框架中可通过直接修改参数实现。
- 该算法是AdamW的前身,强调了解耦权重衰减在深度学习优化中的普适重要性。
通过以上步骤,您应该能理解SGDW的原理、实现及其在优化器设计中的关键作用。