深度学习中优化器的SGD with Gradient Projection(带梯度投影的随机梯度下降)算法原理与实现细节
题目描述:
SGD with Gradient Projection 是一种改进的随机梯度下降算法,核心思想是在参数更新前对梯度进行投影操作,将梯度约束在可行域内。这种方法特别适用于求解带约束的优化问题(如参数需满足球面约束、非负约束等),或在训练中保持参数的理论性质(如概率分布的归一化约束)。与传统梯度裁剪直接缩放梯度不同,梯度投影通过数学投影将梯度映射到约束空间的切平面,确保更新后的参数仍满足约束条件。
解题过程:
- 问题建模
假设需要最小化损失函数 \(L(\theta)\),且参数 \(\theta\) 需满足约束条件 \(\theta \in C\),其中 \(C\) 是可行域(例如 \(\|\theta\|_2 \leq 1\))。目标转化为求解带约束优化问题:
\[ \min_{\theta \in C} L(\theta) \]
直接使用SGD更新 \(\theta_{t+1} = \theta_t - \eta \nabla L(\theta_t)\) 可能破坏约束,需引入投影步骤。
- 投影操作定义
投影函数 \(\Pi_C\) 将任意参数 \(\theta\) 映射到可行域 \(C\) 中距离最近的点:
\[ \Pi_C(\theta) = \arg\min_{z \in C} \|z - \theta\|_2^2 \]
例如,若 \(C = \{\theta \mid \|\theta\|_2 \leq 1\}\),则投影为 \(\Pi_C(\theta) = \frac{\theta}{\max(1, \|\theta\|_2)}\)。
- 梯度投影算法步骤
- 步骤1:计算当前梯度
在迭代 \(t\) 时,采样小批量数据计算梯度 \(g_t = \nabla L(\theta_t)\)。 - 步骤2:投影梯度到切空间
若约束空间 \(C\) 是凸集,先将梯度投影到 \(\theta_t\) 处切空间:
- 步骤1:计算当前梯度
\[ \tilde{g}_t = P_{T_C(\theta_t)}(g_t) \]
其中 $ T_C(\theta_t) $ 是 $ C $ 在 $ \theta_t $ 处的切锥。对于简单约束(如球面约束),可直接计算投影梯度。
- 步骤3:更新参数
使用投影后的梯度更新参数:
\[ \theta_{t+1} = \Pi_C(\theta_t - \eta \tilde{g}_t) \]
这里先沿投影梯度方向更新,再将结果投影回可行域 $ C $。
- 实例:球面约束的投影梯度下降
若约束为 \(\|\theta\|_2 = 1\),则投影梯度计算为:
\[ \tilde{g}_t = (I - \theta_t \theta_t^T) g_t \]
此操作移除梯度在 \(\theta_t\) 方向的分量,确保更新沿切平面进行。参数更新后通过归一化投影回球面:
\[ \theta_{t+1} = \frac{\theta_t - \eta \tilde{g}_t}{\|\theta_t - \eta \tilde{g}_t\|_2} \]
-
与梯度裁剪的区别
- 梯度裁剪直接缩放梯度范数以限制更新幅度,但可能改变梯度方向。
- 梯度投影保持梯度的几何意义,确保参数始终满足约束,适合理论要求严格的场景(如正交化、概率 simplex 约束)。
-
实现细节
- 约束需满足凸性以保证投影唯一性。
- 投影计算需高效:例如对非负约束 \(\theta \geq 0\),投影为 \(\max(0, \theta)\)。
- 在深度学习框架中,可通过自定义梯度函数实现投影操作。
总结:梯度投影法通过将梯度约束在可行域的切空间,兼顾优化效率与理论约束,是处理带约束优化问题的有效工具。