第 10 章 生成古诗
本章目标:逐行读懂
generator.py,理解推理与训练的差异,以及自回归生成的完整循环。
10.1 生成前的准备
python
model.eval() # 关闭 Dropout,进入推理模式
torch.no_grad() # 禁用梯度计算,节省显存和时间推理阶段不需要计算梯度(不用更新参数),torch.no_grad() 可以减少约 50% 的显存占用。
10.2 起始 Token 的处理
python
unk_id = word2id["<UNK>"]
start_id = word2id.get(start_token, unk_id)
if start_id != unk_id:
poem_chars.append(start_token)
remaining = line_len - 1 # 起始字已占一个位置
else:
remaining = line_len # 起始字不在词表,不写入,从头生成
current_input = torch.LongTensor([[start_id]]).to(device) # 形状 (1, 1)
hidden = None # RNN 初始隐藏状态为全零10.3 三层嵌套的生成循环
python
for _ in range(line_num): # 外层:行数(默认 4 行)
for punctuation in [",", "。\n"]: # 中层:每行两句
while remaining > 0: # 内层:逐字生成
logits, hidden = model(current_input, hidden)
last_logit = logits[0, -1] # 取最后时间步:(V,)
proba = torch.softmax(last_logit / temperature, dim=-1)
next_id = torch.multinomial(proba, num_samples=1)
poem_chars.append(id2word[next_id.item()])
current_input = next_id.unsqueeze(0) # (1,1)
remaining -= 1
poem_chars.append(punctuation)
remaining = line_len生成一首七言绝句(4 行,每行 7+7=14 字)的字数分配:
第 1 行上句(7字)→ ","
第 1 行下句(7字)→ "。\n"
第 2 行上句(7字)→ ","
第 2 行下句(7字)→ "。\n"
...10.4 hidden state 传递的实验对比
为了直观展示 hidden state 的重要性,以下是同一个模型、相同起始字"春",两种做法的生成对比:
❌ 丢弃 hidden(原始 bug):
春风吹来山山山
山山山来山来山
山来山来来来来
来来来来来来来模型每步都从零记忆开始,很快陷入重复循环。
✅ 传递 hidden(本项目修复后):
春风吹来柳色新,
山中鸟语声声闻。
夜深月照空庭里,
不知何处是归云。每个字的生成都有上下文记忆,前后连贯,风格统一。
10.5 温度效果对比
固定起始字"月",用不同温度生成:
| 温度 | 生成示例 | 特点 |
|---|---|---|
T=0.5 | 月明千里照山川,云深不知处处边... | 用词保守,多见高频字 |
T=1.0 | 月落空山鸟语寒,风吹古木夜声残... | 平衡,较为自然 |
T=1.5 | 月斜孤馆梦难成,碧落尘心欲断情... | 词汇丰富,偶有不通顺 |
小结
- 推理前必须调用
model.eval()和torch.no_grad() - 生成循环三层嵌套:行数 → 句(逗号/句号) → 逐字
hidden必须跨步传递,否则模型失去上下文,陷入重复temperature越低越保守,越高越有创意但可能不通顺