Skip to content

第 10 章 生成古诗

本章目标:逐行读懂 generator.py,理解推理与训练的差异,以及自回归生成的完整循环。


10.1 生成前的准备

python
model.eval()          # 关闭 Dropout,进入推理模式
torch.no_grad()       # 禁用梯度计算,节省显存和时间

推理阶段不需要计算梯度(不用更新参数),torch.no_grad() 可以减少约 50% 的显存占用。


10.2 起始 Token 的处理

python
unk_id = word2id["<UNK>"]
start_id = word2id.get(start_token, unk_id)

if start_id != unk_id:
    poem_chars.append(start_token)
    remaining = line_len - 1    # 起始字已占一个位置
else:
    remaining = line_len        # 起始字不在词表,不写入,从头生成

current_input = torch.LongTensor([[start_id]]).to(device)  # 形状 (1, 1)
hidden = None   # RNN 初始隐藏状态为全零

10.3 三层嵌套的生成循环

python
for _ in range(line_num):                    # 外层:行数(默认 4 行)
    for punctuation in [",", "。\n"]:        # 中层:每行两句
        while remaining > 0:                 # 内层:逐字生成
            logits, hidden = model(current_input, hidden)
            last_logit = logits[0, -1]       # 取最后时间步:(V,)
            proba = torch.softmax(last_logit / temperature, dim=-1)
            next_id = torch.multinomial(proba, num_samples=1)
            poem_chars.append(id2word[next_id.item()])
            current_input = next_id.unsqueeze(0)  # (1,1)
            remaining -= 1
        poem_chars.append(punctuation)
        remaining = line_len

生成一首七言绝句(4 行,每行 7+7=14 字)的字数分配:

第 1 行上句(7字)→ ","
第 1 行下句(7字)→ "。\n"
第 2 行上句(7字)→ ","
第 2 行下句(7字)→ "。\n"
...

10.4 hidden state 传递的实验对比

为了直观展示 hidden state 的重要性,以下是同一个模型、相同起始字"春",两种做法的生成对比:

❌ 丢弃 hidden(原始 bug)

春风吹来山山山
山山山来山来山
山来山来来来来
来来来来来来来

模型每步都从零记忆开始,很快陷入重复循环。

✅ 传递 hidden(本项目修复后)

春风吹来柳色新,
山中鸟语声声闻。
夜深月照空庭里,
不知何处是归云。

每个字的生成都有上下文记忆,前后连贯,风格统一。


10.5 温度效果对比

固定起始字"月",用不同温度生成:

温度生成示例特点
T=0.5月明千里照山川,云深不知处处边...用词保守,多见高频字
T=1.0月落空山鸟语寒,风吹古木夜声残...平衡,较为自然
T=1.5月斜孤馆梦难成,碧落尘心欲断情...词汇丰富,偶有不通顺

小结

  • 推理前必须调用 model.eval()torch.no_grad()
  • 生成循环三层嵌套:行数 → 句(逗号/句号) → 逐字
  • hidden 必须跨步传递,否则模型失去上下文,陷入重复
  • temperature 越低越保守,越高越有创意但可能不通顺

基于 MIT 协议发布