长短期记忆网络(LSTM)中的三个门控机制及其作用
题目描述
长短期记忆网络(LSTM)是循环神经网络(RNN)的一种特殊变体,专门设计用来解决长期依赖问题,即模型需要学习并记住相隔较远的时间步之间的信息。与您已了解的梯度消失问题和LSTM的解决机制不同,本题将聚焦于L LSTM内部最核心的创新:三个门控机制。我们将详细探讨遗忘门、输入门和输出门各自的功能、计算过程,以及它们如何协同工作来有选择地更新和传递细胞状态。
解题过程
LSTM的关键在于其细胞状态(Cell State),可以将其理解为一个贯穿时间的“传送带”。信息可以相对不变地在这条传送带上流动。LSTM通过三个被称为“门”的结构来精确控制哪些信息应该被添加到细胞状态,哪些应该被移除。每个门都是一个神经网络层,通常由Sigmoid激活函数和点乘操作组成,能够有选择地让信息通过(0表示“完全不让任何信息通过”,1表示“让所有信息通过”)。
第一步:理解遗忘门(Forget Gate)
遗忘门是第一个门,它决定我们需要从上一个时间步的细胞状态 \(C_{t-1}\) 中丢弃(遗忘)哪些信息。
- 输入:遗忘门接收两个输入——当前时间步的输入 \(x_t\) 和上一个时间步的隐藏状态 \(h_{t-1}\)。
- 计算:它将这两个输入向量拼接起来,并通过一个全连接层(通常是带有偏置的线性变换),然后应用Sigmoid激活函数 \(\sigma\)。Sigmoid函数将输出值压缩到0和1之间。
- 公式:\(f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)\)
- 其中:
- \(f_t\) 是遗忘门的输出向量,其每个元素的值都在[0, 1]区间内。
- \(W_f\) 是遗忘门对应的权重矩阵。
- \([h_{t-1}, x_t]\) 表示将两个向量拼接(concat)成一个更长的向量。
- \(b_f\) 是偏置项。
- 作用:\(f_t\) 向量中的每个值,将逐点乘以(point-wise multiplication,符号为 \(\odot\) )上一个细胞状态 \(C_{t-1}\) 的对应位置。如果一个位置的 \(f_t\) 值接近0,那么 \(C_{t-1}\) 中对应位置的信息就会被“遗忘”(乘以一个接近0的数);如果接近1,则该信息会被几乎完整地保留。
第二步:理解输入门(Input Gate)和候选细胞状态
输入门决定我们要将哪些新信息存入细胞状态。这个过程分为两步。
-
输入门(决定更新哪些部分):
- 与遗忘门类似,输入门也接收 \(h_{t-1}\) 和 \(x_t\)。
- 公式:\(i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)\)
- \(i_t\) 是一个由0和1组成的向量,用于控制后续的候选值中有多少信息会被允许加入到细胞状态中。
-
候选细胞状态(提供新的候选值):
- 同时,我们会创建一个新的候选值向量 \(\tilde{C}_t\),这些值是可能被加入到细胞状态中的备选内容。这里我们使用tanh激活函数(输出范围在-1到1之间),以引入非线性并帮助调节数值。
- 公式:\(\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)\)
-
更新细胞状态:
- 现在,我们可以将旧的细胞状态 \(C_{t-1}\) 更新为新的细胞状态 \(C_t\)。
- 公式:\(C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t\)
- 解读:
- \(f_t \odot C_{t-1}\):这是“遗忘”阶段。我们选择性地丢弃旧状态中的信息。
- \(i_t \odot \tilde{C}_t\):这是“记忆”阶段。我们选择性地添加新的候选值。
- 将这两部分相加,就得到了更新后的、包含了当前时间步相关信息的细胞状态 \(C_t\)。
第三步:理解输出门(Output Gate)
输出门基于更新后的细胞状态,决定下一个隐藏状态 \(h_t\) 应该是什么。隐藏状态 \(h_t\) 通常包含用于当前时间步预测的输出信息,并作为下一个时间步的输入之一。
-
输出门(决定输出哪些部分):
- 输出门同样接收 \(h_{t-1}\) 和 \(x_t\)。
- 公式:\(o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)\)
- \(o_t\) 将控制细胞状态 \(C_t\) 中有多少信息会被输出。
-
计算隐藏状态(输出):
- 首先,我们将细胞状态 \(C_t\) 通过一个tanh激活函数(将其值规范到-1和1之间)。
- 然后,将这个规范化的 \(C_t\) 与输出门 \(o_t\) 进行逐点相乘。
- 公式:\(h_t = o_t \odot \tanh(C_t)\)
- 解读:输出门 \(o_t\) 像一个“过滤器”,决定细胞状态中的哪些特征将作为这个时间步的最终输出 \(h_t\)。这个 \(h_t\) 会被用于预测,并传递给下一个时间步。
总结协同工作流程
在每一个时间步t,LSTM单元按顺序执行以下操作:
- “忘记”:用遗忘门 \(f_t\) 决定从长期状态 \(C_{t-1}\) 中丢弃什么。
- “记忆”:用输入门 \(i_t\) 决定将哪些新信息(来自候选状态 \(\tilde{C}_t\))存储到长期状态中。
- “更新”:结合前两步,计算新的长期状态 \(C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t\)。
- “输出”:用输出门 \(o_t\) 基于新的长期状态 \(C_t\) ,计算并输出当前的短期状态/隐藏状态 \(h_t = o_t \odot \tanh(C_t)\)。
通过这三个精巧的门控机制,LSTM能够有效地学习长期依赖关系,在长序列数据上表现出色。