深度学习中优化器的AdamW算法原理与权重衰减机制
字数 2260 2025-10-30 11:52:21
深度学习中优化器的AdamW算法原理与权重衰减机制
题目描述
AdamW算法是Adam优化器的一个改进版本,专门解决了Adam中权重衰减(L2正则化)与自适应学习率结合时出现的问题。你需要理解标准Adam优化器的工作原理,识别其与L2正则化结合时的问题,掌握AdamW如何通过解耦权重衰减来修正这个问题,并了解其具体的实现步骤和优势。
解题过程
-
回顾标准Adam优化器
Adam(Adaptive Moment Estimation)结合了动量(Momentum)和RMSProp的思想。它通过计算梯度的一阶矩(均值)和二阶矩(未中心化的方差)的指数移动平均值来为每个参数自适应地调整学习率。- 对于模型参数 θ,在时间步 t:
- 计算当前小批量的梯度:g_t = ∇θ L(θ_{t-1})
- 更新一阶矩估计(动量):m_t = β₁ * m_{t-1} + (1 - β₁) * g_t
- 更新二阶矩估计(梯度平方):v_t = β₂ * v_{t-1} + (1 - β₂) * g_t² (这里的平方是逐元素运算)
- 偏差校正(因为m和v初始为0,初期会偏向0):m̂_t = m_t / (1 - β₁^t), v̂_t = v_t / (1 - β₂^t)
- 参数更新:θ_t = θ_{t-1} - α * m̂_t / (√v̂_t + ε) (α是学习率,ε是防止除零的小常数)
- 对于模型参数 θ,在时间步 t:
-
识别标准Adam与L2正则化结合的问题
- 在深度学习中,L2正则化(权重衰减)是防止过拟合的常用技术。在原始的SGD(随机梯度下降)中,权重衰减是直接加到权重更新项上的:θ_t = θ_{t-1} - α * g_t - α * λ * θ_{t-1} (λ是权重衰减系数)。这等价于在损失函数中添加了 (λ/2) * ||θ||² 项。
- 然而,当人们将这种"在损失函数中添加L2项"的做法直接套用到Adam优化器时,就产生了问题。此时的梯度 g_t 变成了 ∇θ [L(θ) + (λ/2) * ||θ||²] = ∇θ L(θ) + λθ。
- 在标准Adam中,这个包含了权重衰减项的梯度 g_t 会被用于计算自适应学习率(即除以 √v̂_t)。这意味着权重衰减的大小也会被自适应学习率所缩放。这导致了权重衰减的效果不再是单纯的"乘以λ",而是变成了"乘以(λ/√v̂_t)",其实际衰减量会随着参数的历史梯度(v̂_t)而变化。对于梯度较大的参数,其自适应学习率较小,导致权重衰减的效果被削弱;对于梯度较小的参数,其自适应学习率较大,权重衰减的效果被增强。这并非我们进行L2正则化的本意,我们的本意是对所有权重进行稳定、一致的衰减。
-
AdamW的解决方案:解耦权重衰减
- AdamW的核心思想是将权重衰减与基于损失函数的梯度更新分离开来。它不再将权重衰减项混入损失函数中(即不混入梯度g_t的计算中),而是将其作为一个独立的项,在应用了自适应学习率之后,直接加到参数更新中。
- 具体来说,AdamW的修改非常精妙:
- 计算梯度时,只使用原始损失函数的梯度:g_t = ∇θ L(θ_{t-1})。权重衰减项不参与梯度的计算。
- 在参数更新步骤中,除了应用由Adam计算出的自适应更新量(-α * m̂_t / (√v̂_t + ε))之外,额外地、直接地减去一个权重衰减项(α * λ * θ_{t-1})。
-
AdamW算法的完整步骤
初始化:参数θ₀,一阶矩m₀=0,二阶矩v₀=0,时间步t=0
循环(直到收敛):
t = t + 1- 计算当前小批量下的损失函数梯度(不包含L2项):g_t = ∇θ L(θ_{t-1})
- 更新一阶矩估计:m_t = β₁ * m_{t-1} + (1 - β₁) * g_t
- 更新二阶矩估计:v_t = β₂ * v_{t-1} + (1 - β₂) * g_t²
- 计算一阶矩偏差校正:m̂_t = m_t / (1 - β₁^t)
- 计算二阶矩偏差校正:v̂_t = v_t / (1 - β₂^t)
- 更新参数:θ_t = θ_{t-1} - α * [ m̂_t / (√v̂_t + ε) + λ θ_{t-1} ]
- 注意关键区别:权重衰减项 λ θ_{t-1} 是直接加在括号里的,它与自适应梯度项 m̂_t / (√v̂_t + ε) 是分离的,并且共同乘以全局学习率α。
-
AdamW的优势分析
- 正确的衰减效果:由于权重衰减是直接应用的,没有被自适应学习率缩放,因此它对所有权重都施加了一个稳定、一致的衰减力(力的大小正比于αλ)。这更符合L2正则化的原始目标。
- 通常更好的泛化性能:在实践中,尤其是在训练大型深度学习模型(如Transformer、ResNet等)时,AdamW相比标准的"Adam+L2-in-loss"通常能获得更好的测试集精度(即更好的泛化能力)。
- 成为现代深度学习库的默认或推荐选项:正因为其优势,如PyTorch的
torch.optim.AdamW已经成为许多任务中的首选优化器。
总结来说,AdamW通过将权重衰减从梯度计算中解耦出来,并将其作为一个独立的加法项应用于参数更新,巧妙地修正了标准Adam在处理L2正则化时的问题,从而在实践中实现了更稳定、更有效的优化和正则化效果。