Softmax回归的原理与多分类计算过程
题目描述
Softmax回归(或称多类逻辑回归)是逻辑回归的扩展,用于解决多分类问题(类别数 \(K \geq 3\))。给定一个输入向量 \(\mathbf{x}\),Softmax回归会计算其属于每个类别的概率,并选择概率最大的类别作为预测结果。核心在于通过Softmax函数将线性模型的输出转化为概率分布。
解题过程
1. 模型定义
假设有 \(K\) 个类别,每个类别对应一个参数向量 \(\mathbf{w}_k\)(其中 \(k = 1, 2, \dots, K\))。对于输入特征向量 \(\mathbf{x} \in \mathbb{R}^d\),模型先计算每个类别的线性得分:
\[z_k = \mathbf{w}_k^\top \mathbf{x} + b_k \]
这里 \(b_k\) 是偏置项。为简化表达,通常将偏置并入权重向量,即令 \(\mathbf{x} \leftarrow [\mathbf{x}; 1]\),\(\mathbf{w}_k \leftarrow [\mathbf{w}_k; b_k]\),此时 \(z_k = \mathbf{w}_k^\top \mathbf{x}\)。
2. Softmax函数:从得分到概率
Softmax函数将 \(K\) 个得分 \(\{z_1, z_2, \dots, z_K\}\) 转换为概率分布:
\[P(y=k \mid \mathbf{x}) = \frac{e^{z_k}}{\sum_{j=1}^{K} e^{z_j}} = \frac{e^{\mathbf{w}_k^\top \mathbf{x}}}{\sum_{j=1}^{K} e^{\mathbf{w}_j^\top \mathbf{x}}} \]
- 分母是所有类别得分的指数和,确保概率之和为 1。
- 指数函数 \(e^{z_k}\) 保证概率非负,且放大得分差异(得分高的类别概率更接近 1)。
3. 损失函数:交叉熵损失
训练目标是最大化真实标签的预测概率,通常使用交叉熵损失。设真实标签用 one-hot 向量表示(例如 \(y = [0, 0, 1, 0]\) 表示属于第 3 类),损失函数为:
\[L(\mathbf{W}) = -\sum_{i=1}^{N} \sum_{k=1}^{K} y_{ik} \log P(y_i = k \mid \mathbf{x}_i) \]
其中:
- \(N\) 是样本数量,\(\mathbf{W} = [\mathbf{w}_1, \dots, \mathbf{w}_K]\) 是所有权重参数。
- \(y_{ik} = 1\) 当样本 \(i\) 的真实类别为 \(k\),否则为 0。
- 实际计算时,对于每个样本 \(i\),只需计算其真实类别 \(k\) 对应的 \(-\log P(y_i = k \mid \mathbf{x}_i)\)。
4. 梯度下降优化
通过梯度下降最小化损失函数。需计算损失对权重 \(\mathbf{w}_j\) 的梯度。对于单个样本 \((\mathbf{x}, y)\)(真实类别为 \(k^*\)):
- 先计算概率向量 \(\mathbf{p} = [p_1, \dots, p_K]\),其中 \(p_j = P(y=j \mid \mathbf{x})\)。
- 梯度公式为:
\[\frac{\partial L}{\partial \mathbf{w}_j} = (p_j - \mathbb{I}_{j = k^*}) \mathbf{x} \]
这里 \(\mathbb{I}_{j = k^*}\) 是指示函数(当 \(j = k^*\) 时值为 1,否则为 0)。
- 直观解释:如果模型对样本 \(\mathbf{x}\) 的预测概率 \(p_j\) 高于真实需求(即 \(j \neq k^*\) 但 \(p_j > 0\)),则梯度为负,推动 \(\mathbf{w}_j\) 远离 \(\mathbf{x}\);对于真实类别 \(j = k^*\),梯度为 \((p_{k^*} - 1)\mathbf{x}\),推动 \(\mathbf{w}_{k^*}\) 靠近 \(\mathbf{x}\)。
5. 正则化
为防止过拟合,常在损失函数中加入 L2 正则化项:
\[L_{\text{reg}} = L(\mathbf{W}) + \frac{\lambda}{2} \sum_{k=1}^{K} \|\mathbf{w}_k\|^2 \]
此时梯度需额外加上 \(\lambda \mathbf{w}_j\)。
6. 预测阶段
对测试样本 \(\mathbf{x}\),计算其属于每个类别的概率 \(p_k\),并选择概率最大的类别:
\[\hat{y} = \arg\max_{k} p_k \]
关键点总结
- Softmax函数将线性得分映射为概率分布,适用于多分类。
- 交叉熵损失衡量预测概率与真实分布的差异。
- 梯度下降通过调整权重,使真实类别的概率接近 1,其他类别概率接近 0。
- Softmax回归常作为神经网络的最终输出层,与反向传播结合使用。