深度学习中优化器的Nadam算法原理与实现细节
题目描述
Nadam(Nesterov-accelerated Adaptive Moment Estimation)是结合了Nesterov动量和Adam优化器优势的算法。它通过自适应学习率调整和前瞻性梯度计算,在深度神经网络训练中实现更快的收敛速度和更好的稳定性。题目要求理解Nadam的数学原理、与Adam的关键差异,以及实际实现中的计算步骤。
解题过程
- 算法基础回顾
- Adam优化器:维护一阶矩(均值)\(m_t\) 和二阶矩(未中心化方差)\(v_t\) 的指数移动平均,通过偏差校正后更新参数:
\[ m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t,\quad v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2 \]
\[ \hat{m}_t = \frac{m_t}{1-\beta_1^t},\quad \hat{v}_t = \frac{v_t}{1-\beta_2^t},\quad \theta_t = \theta_{t-1} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \]
- Nesterov动量:在计算梯度前先根据当前动量方向临时更新参数("展望一步"),再用该位置的梯度修正动量,公式为:
\[ m_t = \beta m_{t-1} + \alpha \nabla J(\theta_{t-1} - \beta m_{t-1}),\quad \theta_t = \theta_{t-1} - m_t \]
-
Nadam的核心思想
- 将Adam中的一阶矩估计 \(\hat{m}_t\) 替换为Nesterov动量的形式,即直接使用当前时刻的动量方向进行前瞻性梯度计算。
- 具体来说,在参数更新时,用 \(\hat{m}_t\) 的"未来时刻"估计(即 \(\hat{m}_{t+1}\) 的近似)替代原始Adam中的 \(\hat{m}_t\)。
-
数学推导步骤
- 首先,定义Adam中偏差校正后的一阶矩估计:
\[ \hat{m}_t = \frac{m_t}{1-\beta_1^t} = \frac{\beta_1 m_{t-1} + (1-\beta_1)g_t}{1-\beta_1^t} \]
- 注意到 \(m_t\) 可展开为:
\[ m_t = (1-\beta_1)\sum_{i=1}^t \beta_1^{t-i} g_i \]
故 $ \hat{m}_t $ 是梯度 $ g_i $ 的加权平均。
- Nadam的关键修改:将更新规则中的 \(\hat{m}_t\) 替换为 \(\hat{m}_t^{\text{nesterov}} = \beta_1 \hat{m}_{t+1} + (1-\beta_1) \frac{g_t}{1-\beta_1^t}\),其中 \(\hat{m}_{t+1}\) 是下一时刻的估计。但实际计算时,\(\hat{m}_{t+1}\) 未知,因此通过近似实现:
\[ \hat{m}_t^{\text{nesterov}} = \beta_1 \hat{m}_t + (1-\beta_1) \frac{g_t}{1-\beta_1^t} \]
这里用当前时刻的 $ \hat{m}_t $ 近似 $ \hat{m}_{t+1} $,相当于在计算梯度前先沿动量方向"展望"。
- 完整算法流程
- 初始化:参数 \(\theta_0\),一阶矩 \(m_0 = 0\),二阶矩 \(v_0 = 0\),超参数 \(\alpha, \beta_1, \beta_2, \epsilon\)。
- 循环迭代(每一步 t):
- 计算当前梯度 \(g_t = \nabla J(\theta_{t-1})\)。
- 更新一阶矩:\(m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t\)。
- 更新二阶矩:\(v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2\)。
- 偏差校正:
\[ \hat{m}_t = \frac{m_t}{1-\beta_1^t},\quad \hat{v}_t = \frac{v_t}{1-\beta_2^t} \]
5. 计算Nesterov风格的一阶矩:
\[ \hat{m}_t^{\text{nesterov}} = \beta_1 \hat{m}_t + (1-\beta_1) \frac{g_t}{1-\beta_1^t} \]
6. 更新参数:
\[ \theta_t = \theta_{t-1} - \alpha \frac{\hat{m}_t^{\text{nesterov}}}{\sqrt{\hat{v}_t} + \epsilon} \]
-
与Adam的对比
- Adam的更新项为 \(\frac{\hat{m}_t}{\sqrt{\hat{v}_t}}\),而Nadam使用 \(\frac{\hat{m}_t^{\text{nesterov}}}{\sqrt{\hat{v}_t}}\)。
- 效果:Nadam在损失函数存在较大曲率时(如循环神经网络训练)能更快收敛,因为Nesterov动量减少了振荡。
-
实现注意事项
- 超参数 \(\beta_1\) 通常设为0.9,\(\beta_2\) 为0.999,\(\epsilon\) 为 \(10^{-8}\)。
- 学习率 \(\alpha\) 需根据任务调整,一般从 \(10^{-3}\) 开始尝试。
- 在代码实现中,需注意偏差校正的分母 \(1-\beta_1^t\) 在 \(t=0\) 时为0,因此通常初始化 \(t=1\) 并避免除零错误。