深度学习中优化器的SGDW(SGD with Decoupled Weight Decay)算法原理与实现细节
题目描述
在深度学习模型训练中,权重衰减(Weight Decay)是一种常用的正则化技术,用于防止模型过拟合。然而,在优化算法(如Adam)中,权重衰减通常与梯度下降步骤耦合在一起,导致实际衰减效果与梯度幅度相关,这可能会影响训练效果。SGDW算法(SGD with Decoupled Weight Decay)提出将权重衰减与梯度更新步骤解耦,使得权重衰减独立于优化器的学习率调度,从而更稳定、更符合原始L2正则化的意图。本题目将详细解释SGDW的原理、数学推导、与传统权重衰减的区别,以及具体的实现步骤。
解题过程
1. 背景:传统权重衰减的耦合问题
- 在标准SGD优化器中,权重衰减通常直接在梯度更新时添加到梯度中。
更新公式为:
\(\theta_{t+1} = \theta_t - \eta (\nabla L(\theta_t) + \lambda \theta_t)\)
其中,\(\eta\) 是学习率,\(\lambda\) 是权重衰减系数,\(L\) 是损失函数。 - 在自适应优化器(如Adam)中,权重衰减通常以类似方式实现,但与自适应学习率耦合。
这导致权重衰减的实际效果受学习率 \(\eta\) 影响:当 \(\eta\) 较小时,衰减幅度也变小,这可能偏离L2正则化的原始目标。
2. SGDW的核心思想:解耦权重衰减
- SGDW的核心改进是将权重衰减步骤与梯度更新步骤分离。
具体来说:
(1)先计算梯度更新(不包含权重衰减项)。
(2)再添加一个独立的权重衰减项,该衰减项与学习率无关,仅与权重衰减系数 \(\lambda\) 相关。 - 这样做的好处:
- 权重衰减始终保持固定比例,与学习率的变化无关,使正则化更稳定。
- 更符合L2正则化的数学定义:在损失函数中添加 \(\frac{\lambda}{2} \|\theta\|^2\) 项,其梯度应为 \(\lambda \theta\),但不应被学习率缩放。
3. SGDW的算法步骤
假设模型参数为 \(\theta\),学习率 \(\eta\),权重衰减系数 \(\lambda\),当前迭代步 \(t\)。
步骤1:计算损失函数的梯度
计算损失函数 \(L(\theta_t)\) 对参数 \(\theta_t\) 的梯度:
\(g_t = \nabla L(\theta_t)\)。
注意:这里不包含权重衰减项。
步骤2:执行SGD更新(无权重衰减)
使用标准SGD更新参数:
\(\theta_{t+1/2} = \theta_t - \eta \cdot g_t\)。
这一步是纯梯度下降,没有正则化。
步骤3:独立应用权重衰减
在更新后的参数上直接应用权重衰减:
\(\theta_{t+1} = (1 - \lambda) \cdot \theta_{t+1/2}\)。
注意:衰减系数 \(\lambda\) 直接作用于参数,不乘以学习率 \(\eta\)。
这体现了“解耦”:衰减幅度只由 \(\lambda\) 控制,与 \(\eta\) 无关。
完整更新公式:
\[\theta_{t+1} = (1 - \lambda) (\theta_t - \eta \cdot g_t) \]
展开后为:
\[\theta_{t+1} = (1-\lambda)\theta_t - \eta (1-\lambda) g_t \]
但注意,实际实现中通常分两步进行,以明确解耦逻辑。
4. 与传统权重衰减的对比
- 传统耦合权重衰减(以SGD为例):
\(\theta_{t+1} = \theta_t - \eta (g_t + \lambda \theta_t) = (1 - \eta \lambda) \theta_t - \eta g_t\)。
衰减项 \(\eta \lambda \theta_t\) 与学习率 \(\eta\) 相乘,导致衰减量受 \(\eta\) 影响。 - SGDW解耦权重衰减:
\(\theta_{t+1} = (1 - \lambda) \theta_t - \eta g_t\)(忽略高阶小项 \(\eta \lambda g_t\))。
衰减项 \(\lambda \theta_t\) 与 \(\eta\) 无关。 - 关键区别:在SGDW中,无论学习率如何调整,权重衰减的强度始终保持为 \(\lambda\),使得正则化效果更稳定、可预测。
5. SGDW的变体:与动量结合
SGDW常与动量(Momentum)结合使用,形成SGDW with Momentum。
更新步骤变为:
(1)更新动量:\(m_t = \beta m_{t-1} + g_t\)(\(\beta\) 是动量系数)。
(2)梯度更新:\(\theta_{t+1/2} = \theta_t - \eta \cdot m_t\)。
(3)权重衰减:\(\theta_{t+1} = (1 - \lambda) \cdot \theta_{t+1/2}\)。
这保持了动量加速梯度的优点,同时解耦权重衰减。
6. SGDW的实际实现示例(PyTorch风格伪代码)
def sgdw_step(params, lr, weight_decay, momentum=0.9, grad_clip=None):
for param in params:
if param.grad is None:
continue
grad = param.grad.data
# 梯度裁剪(可选)
if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(param, grad_clip)
# 动量更新
if momentum != 0:
if 'momentum_buffer' not in param.state:
param.state['momentum_buffer'] = torch.zeros_like(param.data)
buf = param.state['momentum_buffer']
buf.mul_(momentum).add_(grad) # buf = momentum * buf + grad
grad = buf
# 梯度下降更新
param.data.add_(-lr, grad) # param = param - lr * grad
# 解耦权重衰减
param.data.mul_(1 - weight_decay) # param = param * (1 - weight_decay)
注意:权重衰减在梯度更新之后独立应用,且不参与动量计算。
7. SGDW的优势与适用场景
- 优势:
- 解耦设计使权重衰减独立于优化器动态,更接近理论L2正则化。
- 在自适应优化器(如AdamW,Adam的SGDW变体)中表现尤其好,可提高泛化性能。
- 对学习率调度(如warmup、衰减)更鲁棒。
- 适用场景:
- 需要强正则化的任务(如图像分类、语言模型)。
- 当使用自适应优化器且出现过拟合时,SGDW(或AdamW)通常是更好的选择。
8. 扩展:AdamW(SGDW思想在Adam中的应用)
AdamW是SGDW思想在Adam优化器上的直接推广:
- 标准Adam的权重衰减耦合在梯度中:
\(m_t = \beta_1 m_{t-1} + (1-\beta_1)(g_t + \lambda \theta_t)\)(不准确,实际实现时通常加在更新步骤)。 - AdamW将其解耦:先计算自适应梯度更新,再独立应用权重衰减:
\(\theta_{t+1} = (1 - \lambda) \theta_t - \eta \cdot \text{AdamUpdate}(g_t)\)。
AdamW已成为训练Transformer等模型的默认优化器。
总结
SGDW通过将权重衰减与梯度更新步骤解耦,解决了传统耦合权重衰减中衰减量受学习率影响的问题,使正则化更稳定有效。其核心在于独立应用权重衰减,不乘以学习率。这种思想可推广到其他优化器(如AdamW),是深度学习优化中的一个重要改进。