Skip to content

第 3 章 循环神经网络(RNN)

本章目标:理解 RNN 如何通过隐藏状态对序列建立记忆,掌握其计算图结构,并了解梯度消失/爆炸问题。


3.1 为什么普通全连接网络不适合序列

设想用全连接网络(MLP)来处理古诗序列,会遇到两个根本问题:

问题一:输入长度固定。MLP 的输入层神经元数量在定义时就确定了,而一首诗可以有 20 个字,也可以有 100 个字,无法用同一个 MLP 处理变长序列。

问题二:无法共享参数。"春"出现在第 1 个位置和第 10 个位置,应该用相同的语义表示,但 MLP 中第 1 位和第 10 位对应完全不同的权重,无法复用。

RNN 通过引入隐藏状态(Hidden State) 和跨时间步的权重共享解决了这两个问题。


3.2 RNN 的核心公式

RNN 在每个时间步 t 执行同一个操作:

ht=tanh(Whht1+Wxxt+b)

其中:

  • xt:当前时间步的输入(嵌入向量),形状 (E,)
  • ht1:上一时间步传来的隐藏状态,形状 (H,)
  • ht:当前时间步更新后的隐藏状态,形状 (H,)
  • Wx,Wh,b:所有时间步共享的可学习参数
  • tanh:激活函数,将输出压缩到 (1,1)

直觉ht 是一个"记忆向量",它编码了从序列开头到第 t 个字的全部历史信息。每一步都在旧记忆 ht1 的基础上,融合新输入 xt,更新出新记忆 ht


3.3 按时间展开(Unrolling)

RNN 的循环结构可以展开成一条线性计算图:

x₁    x₂    x₃    x₄    x₅
 │     │     │     │     │
 ▼     ▼     ▼     ▼     ▼
h₀ → [RNN] → [RNN] → [RNN] → [RNN] → [RNN] → h₅
         ↓         ↓         ↓         ↓         ↓
        y₁        y₂        y₃        y₄        y₅
  • 每个 [RNN] 方块执行相同的公式,使用相同的权重
  • h0 通常初始化为全零向量
  • 每个位置的输出 yt=Woutht,再接 softmax 即可得到下一个字的概率分布

语言模型中的对齐方式

输入 x:  春   眠   不   觉   晓
输出 y:  眠   不   觉   晓   处

yt 是对 xt+1 的预测,即给定前 t 个字,预测第 t+1 个字。这正是 PoemDatasety = poem_id[i+1:] 的含义。


3.4 多层堆叠 RNN

单层 RNN 的表达能力有限。将多个 RNN 层叠加,上一层的输出作为下一层的输入,可以捕捉更高层次、更抽象的序列模式:

输入 x ──→ RNN Layer 1 ──→ RNN Layer 2 ──→ 输出 y
              ↕ h¹                ↕ h²
           (低层特征)         (高层特征)

本项目使用 2 层 RNNnum_layers=2),在 PyTorch 中只需指定 num_layers=2 即可自动堆叠。


3.5 维度详解:output 和 hidden 的区别

PyTorch 的 RNN 返回两个张量,很多同学会在这里卡住:

python
output, hidden = rnn(embedded, hx)

它们有什么区别?

output — 所有时间步的最后一层隐藏状态

output.shape = (N, L, H)
  • 轴 0(N):批次中的第几条样本
  • 轴 1(L):序列的第几个时间步
  • 轴 2(H):该时间步的隐藏向量(hidden_size 维)

output 包含每一个时间步的结果,因为语言模型需要在每个位置都预测下一个字,所以我们用 output 接上 Linear 层输出 logits。

hidden — 序列结束后所有层的隐藏状态

hidden.shape = (num_layers, N, H)
  • 轴 0(num_layers):第几层 RNN
  • 轴 1(N):批次中的第几条样本
  • 轴 2(H):该层最终隐藏向量

hidden 只保存最后一个时间步的隐藏状态,且按层堆叠。它的作用是把"记忆"传递给下一次调用。

用图理解两者的关系(以 2 层 RNN、序列长 4 为例)

时间步:        t=1      t=2      t=3      t=4
               │        │        │        │
Layer 1:  h¹₀→[RNN]→h¹₁→[RNN]→h¹₂→[RNN]→h¹₃→[RNN]→h¹₄
                ↓        ↓        ↓        ↓
Layer 2:  h²₀→[RNN]→h²₁→[RNN]→h²₂→[RNN]→h²₃→[RNN]→h²₄
                ↓        ↓        ↓        ↓
              out₁     out₂     out₃     out₄

output  = [out₁, out₂, out₃, out₄]          # shape: (N, L=4, H)  ← 所有时间步
hidden  = [h¹₄, h²₄]                        # shape: (layers=2, N, H) ← 最后时间步,所有层

常见误解

output[:, -1, :](output 最后时间步的最后一层)与 hidden[-1](hidden 最后一层)在数值上是相等的。两者都指向 h42——但它们的用途不同:output 用于逐步预测,hidden 用于传递状态。


3.5 梯度消失与梯度爆炸

展开后的 RNN 本质上是一个很深的网络(深度 = 序列长度)。反向传播时,梯度需要沿时间步反向流动:

Lh1=LhTt=2Ththt1

每一步的 htht1 都包含 Whtanh 的乘积:

  • Wh<1:连乘后梯度趋近 0 → 梯度消失,早期时间步的参数几乎得不到更新
  • Wh>1:连乘后梯度趋近无穷 → 梯度爆炸,训练发散

这是 RNN 的经典痛点。本项目通过梯度裁剪(见第 4 章)应对梯度爆炸;梯度消失则可通过换用 LSTM/GRU 缓解(见第 11 章扩展思考)。


小结

  • RNN 通过共享权重和隐藏状态处理变长序列
  • 核心公式:ht=tanh(Whht1+Wxxt+b)
  • 展开图清晰展示了序列上的前向传播流程
  • 多层堆叠能捕捉更高层次的序列特征
  • 长序列训练面临梯度消失/爆炸,需要额外手段应对

基于 MIT 协议发布