第 11 章 动手实验:调参与观察
本章目标:给出一套可操作的实验方案,通过亲手改参数、看结果,建立对超参数的直觉认知。
11.1 快速运行
确保安装了 PyTorch,然后直接运行:
bash
cd poems/
python main.py预期输出:
正在加载数据...
词表大小: 2439 | 训练样本数: 11814
使用设备: mps ← MacBook Pro 会自动选 MPS
模型参数量: 2,795,143
开始训练:20 轮 | 批大小 32 | 学习率 0.001
[==================================================] epoch 1/20 loss: 4.823156
[==================================================] epoch 2/20 loss: 3.941203
...
[==================================================] epoch 20/20 loss: 2.187634
── 生成示例(七言绝句)──
【春】
春风吹柳绿千条,
山色空蒙雨亦奇。
...训练 20 轮在 MPS 设备上约需 3–5 分钟,CPU 约需 15–20 分钟。
11.2 实验一:改变 hidden_size
修改 config.py:
python
HIDDEN_SIZE = 128 # 默认 512| hidden_size | 参数量 | 训练时间(MPS) | 最终 loss |
|---|---|---|---|
| 128 | ~60 万 | ~1 分钟 | ~2.8 |
| 256 | ~130 万 | ~2 分钟 | ~2.5 |
| 512(默认) | ~279 万 | ~4 分钟 | ~2.2 |
观察:hidden_size 越大,模型容量越强,loss 越低,但训练越慢。当 hidden_size 很小时,生成的诗重复率明显升高。
11.3 实验二:改变 seq_len
python
SEQ_LEN = 8 # 默认 24| seq_len | 样本数 | 训练稳定性 | 生成连贯性 |
|---|---|---|---|
| 8 | 更多(每首诗切更多片段) | 非常稳定 | 较差(记忆太短) |
| 24(默认) | 11,814 | 稳定 | 较好 |
| 48 | 更少 | 偶有震荡 | 好(但一首短诗可能不够长) |
观察:seq_len 过小,模型来不及学到字与字之间的远程联系;过大则序列样本减少,训练数据不足。
11.4 实验三:调节 temperature
模型训练好后,不需要重新训练,只改 DEFAULT_TEMPERATURE 即可:
python
DEFAULT_TEMPERATURE = 0.5 # 保守
DEFAULT_TEMPERATURE = 1.0 # 默认
DEFAULT_TEMPERATURE = 1.5 # 大胆建议用同一个起始字(如"山")分别生成 5 首,对比:
T=0.5:每次生成的诗是否高度相似?用字是否集中在少数高频字?T=1.5:是否出现更罕见的字?是否偶尔出现语义不通的句子?
11.5 扩展思考
完成上述实验后,可以进一步探索:
换用 LSTM 或 GRU
将 model.py 中的 nn.RNN 替换为 nn.LSTM 或 nn.GRU(接口几乎相同),对比训练曲线和生成质量。LSTM 通过遗忘门/输入门/输出门缓解梯度消失,在长序列上通常比 RNN 更稳定。
python
# 只需改一行
self.rnn = nn.LSTM(input_size=embedding_size, hidden_size=hidden_size,
num_layers=num_layers, batch_first=True,
dropout=dropout if num_layers > 1 else 0.0)注意:LSTM 的 hidden state 是一个元组
(h, c),forward 方法需要相应调整。
添加特殊标记
在词表中加入 <BOS>(句子开始)和 <EOS>(句子结束)标记,让模型学会在合适位置停止生成,而不是依赖外部的行数/字数控制。
扩充语料库
313 首唐诗对于一个语言模型来说非常有限。可以尝试添加更多诗词(如全唐诗数据集,约 5 万首),观察语料规模对生成质量的影响。
全书总结
本册从语言模型的概率基础出发,经过词嵌入、RNN、训练原理、推理策略等理论章节,再到 6 个代码模块的逐行拆解,最终完成了一个能"写古诗"的端到端深度学习项目。
核心收获:
- 语言模型 = 逐字估计条件概率
- RNN 通过隐藏状态实现序列记忆
- 梯度裁剪 + Dropout 是 RNN 训练的标配稳定手段
- 推理时
hidden必须持续传递 temperature是控制生成风格的简单有效工具