Skip to content

第十五章:Transformer 架构

2017 年,Google 团队发表了划时代的论文 "Attention Is All You Need",提出了 Transformer 架构。它完全抛弃了循环结构,仅使用注意力机制来建模序列中的依赖关系,彻底改变了 NLP 的研究范式。


为什么需要 Transformer?

RNN/LSTM 的局限

问题说明
无法并行计算必须按时间步逐步处理,训练慢
长距离依赖虽然 LSTM 缓解了梯度消失,但仍然有限
信息瓶颈即使有注意力,编码器仍然需要逐步处理

Transformer 的优势

优势说明
完全并行所有位置可以同时计算,训练速度快
长距离依赖通过自注意力直接建模任意两个位置的关系
可扩展性可以通过增加层数和参数量来提升性能

整体架构

Transformer 采用编码器-解码器(Encoder-Decoder)架构:

┌─────────────────────────────────────────────────────────────┐
│                      Transformer                            │
│                                                             │
│  ┌─────────────────────┐    ┌─────────────────────┐       │
│  │      编码器          │    │      解码器          │       │
│  │  ┌───────────────┐  │    │  ┌───────────────┐  │       │
│  │  │  多头自注意力   │  │    │  │  掩码多头自注意力│  │       │
│  │  └───────┬───────┘  │    │  └───────┬───────┘  │       │
│  │          ↓          │    │          ↓          │       │
│  │  ┌───────────────┐  │    │  ┌───────────────┐  │       │
│  │  │  前馈神经网络   │  │    │  │  多头交叉注意力  │  │       │
│  │  └───────┬───────┘  │    │  └───────┬───────┘  │       │
│  │          ↓          │    │          ↓          │       │
│  │    × N 层堆叠       │    │  ┌───────────────┐  │       │
│  │                     │    │  │  前馈神经网络   │  │       │
│  └─────────────────────┘    │  └───────┬───────┘  │       │
│           │                 │          ↓          │       │
│           │                 │    × N 层堆叠       │       │
│           │                 └─────────────────────┘       │
│           ↓                          ↓                     │
│  ┌─────────────┐            ┌─────────────┐              │
│  │  输入嵌入    │            │  输出嵌入    │              │
│  │ + 位置编码   │            │ + 位置编码   │              │
│  └─────────────┘            └─────────────┘              │
└─────────────────────────────────────────────────────────────┘

编码器(Encoder)

编码器由 N 个相同的层堆叠而成(原论文中 N=6)。每一层包含两个子层:

子层 1:多头自注意力(Multi-Head Self-Attention)

自注意力让序列中的每个位置都能关注到序列中的所有其他位置:

MultiHead(Q,K,V)=Concat(head1,...,headh)WOheadi=Attention(QWiQ,KWiK,VWiV)Attention(Q,K,V)=Softmax(QKTdk)V

其中:

  • Q=K=V=X(自注意力,输入序列自身)
  • h 是注意力头数(原论文中 h=8
  • dk 是每个头的 Key 维度
  • WiQ,WiK,WiV 是每个头的投影矩阵
  • WO 是输出投影矩阵

子层 2:前馈神经网络(Feed-Forward Network)

每个位置独立地通过一个两层的前馈网络:

FFN(x)=ReLU(xW1+b1)W2+b2

其中:

  • W1Rdmodel×dffW2Rdff×dmodel
  • dff 是前馈网络的隐藏层维度(原论文中 dff=4×dmodel=2048

残差连接与层归一化

每个子层都使用残差连接(Residual Connection)和层归一化(Layer Normalization):

output=LayerNorm(x+SubLayer(x))

其中:

  • x+SubLayer(x) 是残差连接,帮助梯度流动
  • LayerNorm 稳定训练过程

解码器(Decoder)

解码器同样由 N 个相同的层堆叠而成。每一层包含三个子层:

子层 1:掩码多头自注意力(Masked Multi-Head Self-Attention)

与编码器的自注意力类似,但增加了掩码(Mask),防止解码器"看到"未来的词:

MaskedAttention(Q,K,V)=Softmax(QKTdk+M)V

其中掩码矩阵 M 的定义为:

Mij={0if ijif i<j

为什么需要掩码?

在自回归生成中,解码器在预测第 t 个词时只能看到前 t1 个词,不能"偷看"未来的词:

预测 "爱" 时:
输入:  <SOS> 我
掩码:  可见  可见  不可见  不可见
              爱    自然语言处理

子层 2:多头交叉注意力(Multi-Head Cross-Attention)

交叉注意力让解码器关注编码器的输出:

  • Q 来自解码器
  • K,V 来自编码器的输出

这类似于 Seq2Seq 中的注意力机制,但使用了多头注意力。

子层 3:前馈神经网络

与编码器相同。


位置编码(Positional Encoding)

为什么需要位置编码?

Transformer 没有循环结构,无法感知词语的顺序。例如:

"我 爱 你" 和 "你 爱 我" 在没有位置编码时,自注意力的输出是相同的!

因此需要位置编码来注入位置信息。

正弦位置编码

原论文使用正弦和余弦函数生成位置编码:

PE(pos,2i)=sin(pos100002i/dmodel)PE(pos,2i+1)=cos(pos100002i/dmodel)

其中:

  • pos 是位置索引(0, 1, 2, ...)
  • i 是维度索引(0, 1, ..., dmodel/21
  • dmodel 是模型维度

位置编码的直觉

位置 0: [sin(0/10000⁰), cos(0/10000⁰), sin(0/10000²), ...]
位置 1: [sin(1/10000⁰), cos(1/10000⁰), sin(1/10000²), ...]
位置 2: [sin(2/10000⁰), cos(2/10000⁰), sin(2/10000²), ...]

每个位置都有唯一的位置编码,且不同位置之间的编码具有可区分的模式。

最终输入

输入嵌入与位置编码相加:

input=TokenEmbedding(x)+PositionalEncoding(pos)

缩放点积注意力

为什么需要缩放?

dk 较大时,点积 QKT 的值会很大,导致 Softmax 的梯度很小(进入饱和区)。

缩放因子 dk 确保点积的方差保持在合理范围内:

Attention(Q,K,V)=Softmax(QKTdk)V

数学解释

假设 QK 的每个元素都是独立的均值为 0、方差为 1 的随机变量,那么 QK 的方差为 dk。除以 dk 后,方差恢复为 1。


超参数汇总

超参数符号原论文值说明
模型维度dmodel512输入/输出的向量维度
前馈维度dff2048FFN 的隐藏层维度
注意力头数h8多头注意力的头数
每头维度dk=dv64dmodel/h
编码器层数N6编码器的层数
解码器层数N6解码器的层数
Dropout-0.1Dropout 比率

Transformer 的计算复杂度

自注意力的时间复杂度为 O(n2d),其中 n 是序列长度,d 是模型维度:

操作时间复杂度说明
自注意力O(n2d)每个位置关注所有位置
FFNO(nd2)每个位置独立计算
总计O(n2d+nd2)n<d 时,FFN 主导

长序列的挑战

当序列长度 n 很大时,O(n2) 的复杂度会成为瓶颈。这也是后续各种高效 Transformer(如 Linformer、Longformer)的研究动机。


PyTorch 实现

python
import torch
import torch.nn as nn
import math

class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        
        self.d_model = d_model
        
        # 嵌入层
        self.embedding = nn.Embedding(vocab_size, d_model)
        
        # Transformer
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        
        # 输出层
        self.fc_out = nn.Linear(d_model, vocab_size)
    
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # 嵌入
        src_emb = self.embedding(src) * math.sqrt(self.d_model)
        tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)
        
        # Transformer
        output = self.transformer(src_emb, tgt_emb, src_mask=src_mask, tgt_mask=tgt_mask)
        
        # 输出
        logits = self.fc_out(output)
        return logits

# 使用
model = TransformerModel(vocab_size=30000)
src = torch.randint(0, 30000, (2, 10))  # (batch, src_len)
tgt = torch.randint(0, 30000, (2, 8))   # (batch, tgt_len)
output = model(src, tgt)  # (2, 8, 30000)

小结

组件作用
多头自注意力建模序列内部的依赖关系
前馈神经网络对每个位置进行非线性变换
位置编码注入位置信息
残差连接帮助梯度流动
层归一化稳定训练
掩码防止解码器看到未来信息

Transformer 的设计思想——完全基于注意力、抛弃循环结构——开创了 NLP 的新纪元。接下来两章我们将深入讲解自注意力和位置编码的细节。

AI 知识体系 — 从机器学习到大语言模型