Skip to content

第 11 章 动手实验:调参与观察

本章目标:给出一套可操作的实验方案,通过亲手改参数、看结果,建立对超参数的直觉认知。


11.1 快速运行

确保安装了 PyTorch,然后直接运行:

bash
cd poems/
python main.py

预期输出:

正在加载数据...
词表大小: 2439 | 训练样本数: 11814

使用设备: mps        ← MacBook Pro 会自动选 MPS

模型参数量: 2,795,143

开始训练:20 轮 | 批大小 32 | 学习率 0.001

[==================================================] epoch  1/20  loss: 4.823156
[==================================================] epoch  2/20  loss: 3.941203
...
[==================================================] epoch 20/20  loss: 2.187634

── 生成示例(七言绝句)──

【春】
春风吹柳绿千条,
山色空蒙雨亦奇。
...

训练 20 轮在 MPS 设备上约需 3–5 分钟,CPU 约需 15–20 分钟


11.2 实验一:改变 hidden_size

修改 config.py

python
HIDDEN_SIZE = 128   # 默认 512
hidden_size参数量训练时间(MPS)最终 loss
128~60 万~1 分钟~2.8
256~130 万~2 分钟~2.5
512(默认)~279 万~4 分钟~2.2

观察hidden_size 越大,模型容量越强,loss 越低,但训练越慢。当 hidden_size 很小时,生成的诗重复率明显升高。


11.3 实验二:改变 seq_len

python
SEQ_LEN = 8    # 默认 24
seq_len样本数训练稳定性生成连贯性
8更多(每首诗切更多片段)非常稳定较差(记忆太短)
24(默认)11,814稳定较好
48更少偶有震荡好(但一首短诗可能不够长)

观察seq_len 过小,模型来不及学到字与字之间的远程联系;过大则序列样本减少,训练数据不足。


11.4 实验三:调节 temperature

模型训练好后,不需要重新训练,只改 DEFAULT_TEMPERATURE 即可:

python
DEFAULT_TEMPERATURE = 0.5   # 保守
DEFAULT_TEMPERATURE = 1.0   # 默认
DEFAULT_TEMPERATURE = 1.5   # 大胆

建议用同一个起始字(如"山")分别生成 5 首,对比:

  • T=0.5:每次生成的诗是否高度相似?用字是否集中在少数高频字?
  • T=1.5:是否出现更罕见的字?是否偶尔出现语义不通的句子?

11.5 扩展思考

完成上述实验后,可以进一步探索:

换用 LSTM 或 GRU

model.py 中的 nn.RNN 替换为 nn.LSTMnn.GRU(接口几乎相同),对比训练曲线和生成质量。LSTM 通过遗忘门/输入门/输出门缓解梯度消失,在长序列上通常比 RNN 更稳定。

python
# 只需改一行
self.rnn = nn.LSTM(input_size=embedding_size, hidden_size=hidden_size,
                   num_layers=num_layers, batch_first=True,
                   dropout=dropout if num_layers > 1 else 0.0)

注意:LSTM 的 hidden state 是一个元组 (h, c),forward 方法需要相应调整。

添加特殊标记

在词表中加入 <BOS>(句子开始)和 <EOS>(句子结束)标记,让模型学会在合适位置停止生成,而不是依赖外部的行数/字数控制。

扩充语料库

313 首唐诗对于一个语言模型来说非常有限。可以尝试添加更多诗词(如全唐诗数据集,约 5 万首),观察语料规模对生成质量的影响。


全书总结

本册从语言模型的概率基础出发,经过词嵌入、RNN、训练原理、推理策略等理论章节,再到 6 个代码模块的逐行拆解,最终完成了一个能"写古诗"的端到端深度学习项目。

核心收获:

  1. 语言模型 = 逐字估计条件概率
  2. RNN 通过隐藏状态实现序列记忆
  3. 梯度裁剪 + Dropout 是 RNN 训练的标配稳定手段
  4. 推理时 hidden 必须持续传递
  5. temperature 是控制生成风格的简单有效工具

基于 MIT 协议发布