Skip to content

GRU(门控循环单元)

GRU(Gated Recurrent Unit)是 LSTM 的简化版本:保留门控思想,删除 Cell State,合并为 2 个门,参数更少,训练更快,在多数任务上效果与 LSTM 相当。

1. 设计目标

GRU 由 Cho 等人在 2014 年提出,目标是:

  • ✅ 保留 LSTM 的长期依赖建模能力
  • 减少结构复杂性(2门 vs 3门,1状态 vs 2状态)
  • 提升训练与推理效率

2. GRU vs LSTM:结构对比

对比项LSTMGRU
门控数量3(遗忘 + 输入 + 输出)2(更新 + 重置)
内部状态ht + ct(双状态)只有 ht(单状态)
参数量4× 隐藏层大小3× 隐藏层大小
表达能力更强(复杂长文本)接近,略弱
推理速度
工程难度复杂简单

3. 两个门控机制

更新门 zt(Update Gate)— 替代 LSTM 的遗忘门+输入门

更新门

zt=σ(Wzxt+Uzht1+bz)
  • 输出 [0,1]:控制保留多少旧记忆 vs 写入多少新信息
  • zt1:主要保留历史状态(记住远处信息)
  • zt0:主要用新候选状态(忘掉旧信息,专注当前)

重置门 rt(Reset Gate)— 控制历史信息的影响

重置门

rt=σ(Wrxt+Urht1+br)
  • 控制计算候选状态时,历史 ht1 的参与程度
  • rt0:完全忽略历史状态(从零开始处理当前输入)
  • rt1:充分利用历史状态

4. 完整公式

候选隐藏状态

候选状态

h~t=tanh(Whxt+Uh(rtht1)+bh)
  • rtht1:重置门控制历史信息的保留比例

最终输出

最终输出

ht=(1zt)ht1+zth~t

直觉理解

  • (1zt):从旧记忆中保留多少
  • zt:从新候选中接受多少
  • 通过"线性插值"混合历史与当前

5. 公式汇总

组件公式作用
更新门zt=σ(Wzxt+Uzht1)控制新旧信息比例
重置门rt=σ(Wrxt+Urht1)控制历史影响程度
候选状态h~t=tanh(Whxt+Uh(rtht1))融合当前输入和过滤后的历史
最终状态ht=(1zt)ht1+zth~t线性插值混合新旧信息

6. PyTorch 实现

基础 GRU

python
import torch
import torch.nn as nn

class GRUModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
            dropout=0.3 if num_layers > 1 else 0
        )
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.3)
    
    def forward(self, x):
        # output: [batch, seq_len, hidden_dim]
        # h_n:    [num_layers, batch, hidden_dim]
        output, h_n = self.gru(x)
        
        # 取最后时间步用于分类
        last_out = output[:, -1, :]
        return self.fc(self.dropout(last_out))

# 使用
model = GRUModel(input_dim=100, hidden_dim=256, output_dim=3)
x = torch.randn(32, 20, 100)
print(model(x).shape)  # [32, 3]

文本分类完整示例

python
import torch
import torch.nn as nn

class TextGRU(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim, 
                 num_layers=2, bidirectional=True):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.gru = nn.GRU(
            embed_dim, hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=0.3
        )
        
        # 双向时输出维度 * 2
        fc_input = hidden_dim * 2 if bidirectional else hidden_dim
        self.fc = nn.Linear(fc_input, output_dim)
        self.dropout = nn.Dropout(0.3)
    
    def forward(self, x):
        # x: [batch, seq_len](词索引)
        embedded = self.dropout(self.embedding(x))  # [batch, seq, embed]
        output, h_n = self.gru(embedded)
        
        if self.gru.bidirectional:
            # 拼接双向最后隐藏状态
            h = torch.cat([h_n[-2], h_n[-1]], dim=1)  # [batch, hidden*2]
        else:
            h = h_n[-1]   # [batch, hidden]
        
        return self.fc(self.dropout(h))

# 训练循环
model = TextGRU(vocab_size=20000, embed_dim=128, 
                hidden_dim=256, output_dim=5)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# 单步训练
texts = torch.randint(0, 20000, (32, 50))   # [batch=32, seq=50]
labels = torch.randint(0, 5, (32,))

optimizer.zero_grad()
logits = model(texts)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
print(f"Loss: {loss.item():.4f}")

7. 适用场景

场景是否适合
短文本分类(评论、微博)✅ 非常适合
嵌入式/手机端推理✅ 轻量高效
实时流数据处理✅ 低延迟
小规模语言建模✅ 可用
大规模预训练模型❌ 用 Transformer
超长序列(>500步)❌ 用 Transformer

8. GRU 的局限性

局限原因
长距离依赖仍有限序列极长时仍有信息遗忘
无法并行时间步依赖限制 GPU 利用率
不支持大规模预训练生态不如 Transformer 完善
超长序列表现差100+ 步后记忆衰减

9. 总结

特性GRU
提出时间2014(Cho 等人)
核心机制更新门 + 重置门
状态数量1个(只有 ht
参数量LSTM 的 3/4
最大优势简单高效,资源友好
最大局限超长序列不如 Transformer

GRU 是 LSTM 和 Transformer 之间的性价比之选——在不需要大模型预训练的场景下,是轻量序列建模的首选。

AI 知识体系 — 从机器学习到大语言模型