第 3 章 循环神经网络(RNN)
本章目标:理解 RNN 如何通过隐藏状态对序列建立记忆,掌握其计算图结构,并了解梯度消失/爆炸问题。
3.1 为什么普通全连接网络不适合序列
设想用全连接网络(MLP)来处理古诗序列,会遇到两个根本问题:
问题一:输入长度固定。MLP 的输入层神经元数量在定义时就确定了,而一首诗可以有 20 个字,也可以有 100 个字,无法用同一个 MLP 处理变长序列。
问题二:无法共享参数。"春"出现在第 1 个位置和第 10 个位置,应该用相同的语义表示,但 MLP 中第 1 位和第 10 位对应完全不同的权重,无法复用。
RNN 通过引入隐藏状态(Hidden State) 和跨时间步的权重共享解决了这两个问题。
3.2 RNN 的核心公式
RNN 在每个时间步
其中:
:当前时间步的输入(嵌入向量),形状 :上一时间步传来的隐藏状态,形状 :当前时间步更新后的隐藏状态,形状 :所有时间步共享的可学习参数 :激活函数,将输出压缩到
直觉:
3.3 按时间展开(Unrolling)
RNN 的循环结构可以展开成一条线性计算图:
x₁ x₂ x₃ x₄ x₅
│ │ │ │ │
▼ ▼ ▼ ▼ ▼
h₀ → [RNN] → [RNN] → [RNN] → [RNN] → [RNN] → h₅
↓ ↓ ↓ ↓ ↓
y₁ y₂ y₃ y₄ y₅- 每个
[RNN]方块执行相同的公式,使用相同的权重 通常初始化为全零向量 - 每个位置的输出
,再接 softmax 即可得到下一个字的概率分布
语言模型中的对齐方式:
输入 x: 春 眠 不 觉 晓
输出 y: 眠 不 觉 晓 处PoemDataset 中 y = poem_id[i+1:] 的含义。
3.4 多层堆叠 RNN
单层 RNN 的表达能力有限。将多个 RNN 层叠加,上一层的输出作为下一层的输入,可以捕捉更高层次、更抽象的序列模式:
输入 x ──→ RNN Layer 1 ──→ RNN Layer 2 ──→ 输出 y
↕ h¹ ↕ h²
(低层特征) (高层特征)本项目使用 2 层 RNN(num_layers=2),在 PyTorch 中只需指定 num_layers=2 即可自动堆叠。
3.5 维度详解:output 和 hidden 的区别
PyTorch 的 RNN 返回两个张量,很多同学会在这里卡住:
output, hidden = rnn(embedded, hx)它们有什么区别?
output — 所有时间步的最后一层隐藏状态
output.shape = (N, L, H)- 轴 0(N):批次中的第几条样本
- 轴 1(L):序列的第几个时间步
- 轴 2(H):该时间步的隐藏向量(hidden_size 维)
output 包含每一个时间步的结果,因为语言模型需要在每个位置都预测下一个字,所以我们用 output 接上 Linear 层输出 logits。
hidden — 序列结束后所有层的隐藏状态
hidden.shape = (num_layers, N, H)- 轴 0(num_layers):第几层 RNN
- 轴 1(N):批次中的第几条样本
- 轴 2(H):该层最终隐藏向量
hidden 只保存最后一个时间步的隐藏状态,且按层堆叠。它的作用是把"记忆"传递给下一次调用。
用图理解两者的关系(以 2 层 RNN、序列长 4 为例)
时间步: t=1 t=2 t=3 t=4
│ │ │ │
Layer 1: h¹₀→[RNN]→h¹₁→[RNN]→h¹₂→[RNN]→h¹₃→[RNN]→h¹₄
↓ ↓ ↓ ↓
Layer 2: h²₀→[RNN]→h²₁→[RNN]→h²₂→[RNN]→h²₃→[RNN]→h²₄
↓ ↓ ↓ ↓
out₁ out₂ out₃ out₄
output = [out₁, out₂, out₃, out₄] # shape: (N, L=4, H) ← 所有时间步
hidden = [h¹₄, h²₄] # shape: (layers=2, N, H) ← 最后时间步,所有层常见误解
output[:, -1, :](output 最后时间步的最后一层)与 hidden[-1](hidden 最后一层)在数值上是相等的。两者都指向
3.5 梯度消失与梯度爆炸
展开后的 RNN 本质上是一个很深的网络(深度 = 序列长度)。反向传播时,梯度需要沿时间步反向流动:
每一步的
- 若
:连乘后梯度趋近 0 → 梯度消失,早期时间步的参数几乎得不到更新 - 若
:连乘后梯度趋近无穷 → 梯度爆炸,训练发散
这是 RNN 的经典痛点。本项目通过梯度裁剪(见第 4 章)应对梯度爆炸;梯度消失则可通过换用 LSTM/GRU 缓解(见第 11 章扩展思考)。
小结
- RNN 通过共享权重和隐藏状态处理变长序列
- 核心公式:
- 展开图清晰展示了序列上的前向传播流程
- 多层堆叠能捕捉更高层次的序列特征
- 长序列训练面临梯度消失/爆炸,需要额外手段应对