深度学习中优化器的SGD with Layer-wise Adaptive Moments (LAMB) 算法原理与实现细节
题目描述
LAMB(Layer-wise Adaptive Moments)是一种结合了自适应学习率与逐层归一化的优化算法,专为大规模深度学习模型(如BERT)设计。它通过逐层调整参数更新步长,有效解决了训练大模型时学习率选择困难、收敛速度慢的问题。LAMB的核心思想是对每一层参数进行自适应学习率调整,同时引入信任比率(trust ratio)来确保更新方向的稳定性。
解题过程
1. 问题背景
- 传统优化器(如Adam)在训练大模型时,可能因层间梯度分布差异导致收敛不稳定。
- 学习率需精细调整:过大易发散,过小则收敛慢。LAMB通过逐层归一化更新步长,允许使用更大的全局学习率。
2. LAMB的核心思想
- 逐层自适应:对每一层参数独立计算自适应学习率,避免层间梯度尺度差异的影响。
- 信任比率:通过比较参数更新前后的范数比例,控制更新步长,确保方向稳定性。
3. 算法步骤分解
设第 \(t\) 步时,待优化参数为 \(\theta_t\),梯度为 \(g_t\)。
步骤1:计算一阶矩和二阶矩(类似Adam)
\[m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t \]
\[v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \]
- \(m_t, v_t\) 分别为梯度的一阶矩(动量)和二阶矩(自适应学习率基础)。
- \(\beta_1, \beta_2\) 为衰减率(通常取0.9和0.999)。
步骤2:偏差校正
\[\hat{m}_t = \frac{m_t}{1 - \beta_1^t} \]
\[\hat{v}_t = \frac{v_t}{1 - \beta_2^t} \]
- 校正初期估计偏差,使更新更稳定。
步骤3:计算自适应学习率更新
\[\Delta_t = \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \]
- \(\epsilon\) 为数值稳定性常数(如1e-6)。
- 此时 \(\Delta_t\) 为未归一化的参数更新方向。
步骤4:逐层归一化与信任比率
- 将参数按层分组,对每一层 \(l\) 计算:
\[ \text{trust_ratio}_l = \frac{\|\theta_t^l\|}{\|\Delta_t^l + \lambda \theta_t^l\|} \]
- \(\theta_t^l\) 为第 \(l\) 层的参数向量。
- \(\lambda\) 为权重衰减系数。
- 分子为参数范数,分母为更新后参数的近似范数。信任比率衡量更新方向的可靠性。
步骤5:应用信任比率更新参数
\[\theta_{t+1}^l = \theta_t^l - \eta \cdot \text{trust_ratio}_l \cdot \Delta_t^l \]
- \(\eta\) 为全局学习率。
- 信任比率接近1时,更新步长由自适应学习率主导;偏离1时,自动缩放步长以避免震荡。
4. 关键创新点
- 层间解耦:每层独立归一化,适应不同层的梯度特性。
- 信任比率:动态调整步长,避免梯度爆炸/消失,支持大学习率训练。
5. 实现细节(PyTorch伪代码)
import torch
def lamb_update(params, grads, m, v, t, eta=0.01, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.01):
updates = {}
for layer_id, param in params.items():
# 更新一阶矩和二阶矩
m[layer_id] = beta1 * m[layer_id] + (1 - beta1) * grads[layer_id]
v[layer_id] = beta2 * v[layer_id] + (1 - beta2) * grads[layer_id]**2
# 偏差校正
m_hat = m[layer_id] / (1 - beta1**t)
v_hat = v[layer_id] / (1 - beta2**t)
# 计算自适应更新方向
delta = m_hat / (torch.sqrt(v_hat) + eps)
# 添加权重衰减
delta += weight_decay * param
# 计算信任比率
param_norm = torch.norm(param)
delta_norm = torch.norm(delta)
trust_ratio = param_norm / (delta_norm + eps) if param_norm > 0 and delta_norm > 0 else 1.0
# 应用更新
updates[layer_id] = eta * trust_ratio * delta
return updates
6. 总结
LAMB通过逐层自适应和信任比率机制,显著提升大模型训练的稳定性和收敛速度,特别适合Transformer类模型。其核心是将全局学习率与层局部特性解耦,实现更精细的优化控制。