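1. Core Value and Use Cases of the Transformer Architecture
The 2017 paper "Attention Is All You Need" changed the rules of natural language processing. The Transformer it introduced was the first sequence transduction model built entirely on attention, with no recurrence or convolution. On machine translation it outperformed the strongest RNN- and LSTM-based systems of the time while being much easier to parallelize during training. Today, from BERT to GPT-3, nearly every mainstream NLP model is built on the Transformer.
PyTorch is one of the most popular deep learning frameworks, and its dynamic computation graph and intuitive API make a model of this complexity straightforward to implement. This article builds a standard Transformer from scratch, including the full Encoder-Decoder structure, multi-head attention, and positional encoding. The resulting architecture can be used directly for machine translation and is a good starting point for understanding variants such as BERT.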
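2. Environment Setup and Basic Configuration
2.1 Setting Up the PyTorch Environment
Python 3.8+ and PyTorch 1.10+ are recommended; newer releases are better optimized for the operations a Transformer relies on. conda makes it quick to create an isolated environment:
```bash
conda create -n transformer python=3.8
conda activate transformer
pip install torch torchtext torchdata -f https://download.pytorch.org/whl/cu113/torch_stable.html
```
Note: for GPU acceleration, install the PyTorch build that matches your CUDA version (the command above targets CUDA 11.3). You can check whether a GPU is usable with torch.cuda.is_available().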
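2.2 Defining Key Hyperparameters
We start by defining the model's core parameters, which directly determine its capacity and performance:
```python
import math

import torch


class Config:
    def __init__(self):
        self.src_vocab_size = 5000    # source vocabulary size
        self.tgt_vocab_size = 5000    # target vocabulary size
        self.d_model = 512            # embedding (model) dimension
        self.n_head = 8               # number of attention heads
        self.num_encoder_layers = 6   # number of encoder layers
        self.num_decoder_layers = 6   # number of decoder layers
        self.d_ff = 2048              # feed-forward hidden dimension
        self.dropout = 0.1            # dropout rate
        self.max_seq_len = 100        # maximum sequence length
```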
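As a quick sanity check on these hyperparameters, the sketch below works out the per-head dimension and an approximate parameter count in plain Python. It assumes the layer layout built later in this article (four d_model x d_model projections per attention block, two LayerNorms per encoder layer and three per decoder layer, separate source and target embeddings). The helper names `linear_params` and `count_parameters` are illustrative only and are not part of the model code.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


def count_parameters(d_model=512, n_head=8, d_ff=2048, n_enc=6, n_dec=6,
                     src_vocab=5000, tgt_vocab=5000):
    assert d_model % n_head == 0
    attention = 4 * linear_params(d_model, d_model)   # W_q, W_k, W_v, output projection
    feed_forward = linear_params(d_model, d_ff) + linear_params(d_ff, d_model)
    layer_norm = 2 * d_model                          # scale and shift vectors
    encoder_layer = attention + feed_forward + 2 * layer_norm
    decoder_layer = 2 * attention + feed_forward + 3 * layer_norm
    embeddings = (src_vocab + tgt_vocab) * d_model
    output_proj = linear_params(d_model, tgt_vocab)
    return (n_enc * encoder_layer + n_dec * decoder_layer
            + embeddings + output_proj)


d_k = 512 // 8              # dimension handled by each attention head
total = count_parameters()
print(d_k)     # 64
print(total)   # 51823496
```

With these defaults the model has roughly 52 million trainable parameters; the positional encoding is a fixed buffer and adds none.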
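3. Implementing the Core Components
3.1 Positional Encoding
The Transformer needs position information injected explicitly, because self-attention by itself has no notion of token order. We generate the positional encoding from a combination of sine and cosine functions:
```python
class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, max_len=100):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 10000^(-2i / d_model), computed in log space for numerical stability
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
        pe = pe.unsqueeze(0)                           # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model); add the encoding for the first seq_len positions
        return x + self.pe[:, :x.size(1)]
```
Tip: the positional encoding is not trained, so register_buffer registers it as a buffer rather than a trainable parameter. It still moves with the model across devices and is saved in the state_dict.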
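The class above implements PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). As a minimal sketch, the same formula can be checked for a single position in plain Python; the function name `positional_encoding_row` is made up for this example, and d_model is assumed to be even.

```python
import math


def positional_encoding_row(pos, d_model):
    """Sinusoidal encoding for a single position (d_model assumed even)."""
    row = []
    for i in range(0, d_model, 2):
        # Same quantity as div_term in the PyTorch class: 10000^(-i / d_model)
        freq = math.exp(-math.log(10000.0) * i / d_model)
        row.append(math.sin(pos * freq))   # even dimension
        row.append(math.cos(pos * freq))   # odd dimension
    return row


row0 = positional_encoding_row(0, 8)
row1 = positional_encoding_row(1, 8)
print(row0)   # position 0: sin(0) = 0 and cos(0) = 1 alternate
print(row1)   # position 1: low dimensions change fast, high dimensions slowly
```

Each pair of dimensions oscillates at a different frequency, which is what lets the model tell positions apart.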
3.2 Multi-Head Attention
This is the central component of the Transformer. It lets the model attend to different positions of the input sequence at the same time, with each head working in its own subspace:
```python
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model, n_head, dropout=0.1):
        super().__init__()
        assert d_model % n_head == 0, "d_model must be divisible by n_head"
        self.d_k = d_model // n_head
        self.n_head = n_head
        self.w_q = torch.nn.Linear(d_model, d_model)
        self.w_k = torch.nn.Linear(d_model, d_model)
        self.w_v = torch.nn.Linear(d_model, d_model)
        self.fc = torch.nn.Linear(d_model, d_model)
        self.dropout = torch.nn.Dropout(dropout)
        # Plain Python float, so nothing has to be moved between devices per call
        self.scale = math.sqrt(self.d_k)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        # Project, then split into heads: (batch, seq, d_model) -> (batch, n_head, seq, d_k)
        q = self.w_q(q).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        k = self.w_k(k).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        v = self.w_v(v).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        # Scaled dot-product scores: (batch, n_head, q_len, k_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        if mask is not None:
            # mask is True/1 where attention is allowed and False/0 where it is blocked;
            # finfo.min (rather than -1e9) also fits in float16 under mixed precision
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        # Apply the attention weights, then merge the heads back together
        output = torch.matmul(attn, v).transpose(1, 2).contiguous()
        output = output.view(batch_size, -1, self.n_head * self.d_k)
        return self.fc(output)
```
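The reshaping in the forward pass is the part that most often goes wrong, so the following sketch traces only the tensor shapes through the head split, the score computation, and the merge. It uses random data and no learned projections; the sizes are the defaults from the configuration above and the variable names are illustrative.

```python
import math

import torch

batch_size, seq_len, d_model, n_head = 2, 5, 512, 8
d_k = d_model // n_head

x = torch.randn(batch_size, seq_len, d_model)

# Split the last dimension into heads: (2, 5, 512) -> (2, 8, 5, 64)
heads = x.view(batch_size, -1, n_head, d_k).transpose(1, 2)

# Scaled dot-product scores per head: (2, 8, 5, 5)
scores = torch.matmul(heads, heads.transpose(-2, -1)) / math.sqrt(d_k)
weights = torch.softmax(scores, dim=-1)

# Merge the heads back: (2, 8, 5, 64) -> (2, 5, 512)
merged = heads.transpose(1, 2).contiguous().view(batch_size, -1, n_head * d_k)

print(heads.shape, scores.shape, merged.shape)
```

Splitting and merging are exact inverses, so `merged` equals `x`; the call to `contiguous()` is needed because `transpose` returns a non-contiguous view that `view` cannot reshape.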
3.3 Position-wise Feed-Forward Network
Every encoder and decoder layer contains a fully connected feed-forward network that is applied to each position independently:
```python
class FeedForward(torch.nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = torch.nn.Linear(d_model, d_ff)
        self.dropout = torch.nn.Dropout(dropout)
        self.linear2 = torch.nn.Linear(d_ff, d_model)

    def forward(self, x):
        # Expand to d_ff, apply ReLU and dropout, then project back to d_model
        return self.linear2(self.dropout(torch.relu(self.linear1(x))))
```
4. Implementing the Encoder and Decoder
4.1 Encoder Layer Structure
Each encoder layer consists of a multi-head attention sublayer and a feed-forward sublayer. Each sublayer is wrapped in a residual connection followed by layer normalization:
```python
class EncoderLayer(torch.nn.Module):
    def __init__(self, d_model, n_head, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)

    def forward(self, x, mask):
        # Self-attention sublayer: residual connection, then LayerNorm
        attn_output = self.self_attn(x, x, x, mask)
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)
        # Feed-forward sublayer: residual connection, then LayerNorm
        ffn_output = self.ffn(x)
        x = x + self.dropout2(ffn_output)
        return self.norm2(x)
```
4.2 Decoder Layer Structure
The decoder layer is more complex than the encoder layer because it has two attention sublayers: masked self-attention over the target, and attention over the encoder output:
```python
class DecoderLayer(torch.nn.Module):
    def __init__(self, d_model, n_head, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.enc_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model)
        self.norm3 = torch.nn.LayerNorm(d_model)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.dropout3 = torch.nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        # Self-attention sublayer (attends to the target tokens generated so far)
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)
        # Encoder-decoder attention sublayer (queries from the decoder, keys/values from the source)
        attn_output = self.enc_attn(x, enc_output, enc_output, src_mask)
        x = x + self.dropout2(attn_output)
        x = self.norm2(x)
        # Feed-forward sublayer
        ffn_output = self.ffn(x)
        x = x + self.dropout3(ffn_output)
        return self.norm3(x)
```
5. Assembling the Full Transformer
5.1 Embedding and Output Layers
```python
class Transformer(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        # Separate token embeddings for the source and target vocabularies
        self.encoder_embedding = torch.nn.Embedding(config.src_vocab_size, config.d_model)
        self.decoder_embedding = torch.nn.Embedding(config.tgt_vocab_size, config.d_model)
        self.pos_encoding = PositionalEncoding(config.d_model, config.max_seq_len)
        self.encoder_layers = torch.nn.ModuleList([
            EncoderLayer(config.d_model, config.n_head, config.d_ff, config.dropout)
            for _ in range(config.num_encoder_layers)
        ])
        self.decoder_layers = torch.nn.ModuleList([
            DecoderLayer(config.d_model, config.n_head, config.d_ff, config.dropout)
            for _ in range(config.num_decoder_layers)
        ])
        # Projects decoder states to logits over the target vocabulary
        self.fc_out = torch.nn.Linear(config.d_model, config.tgt_vocab_size)
        self.dropout = torch.nn.Dropout(config.dropout)
```
5.2 Forward Pass
```python
    # Method of the Transformer class from 5.1
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # Encoder: embed, add positions, then run through every encoder layer
        src_embedded = self.dropout(self.pos_encoding(self.encoder_embedding(src)))
        enc_output = src_embedded
        for layer in self.encoder_layers:
            enc_output = layer(enc_output, src_mask)
        # Decoder: same embedding steps, with each layer also attending to enc_output
        tgt_embedded = self.dropout(self.pos_encoding(self.decoder_embedding(tgt)))
        dec_output = tgt_embedded
        for layer in self.decoder_layers:
            dec_output = layer(dec_output, enc_output, src_mask, tgt_mask)
        # Logits of shape (batch, tgt_len, tgt_vocab_size)
        return self.fc_out(dec_output)
```
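6. Training Techniques and Optimization Strategies
6.1 Generating Masks
The Transformer needs two kinds of masks:
- A source padding mask (prevents attention to padding tokens)
- A target causal mask (prevents the decoder from seeing future tokens)
Both follow the convention used by MultiHeadAttention: True marks positions that may be attended to, and positions where the mask is 0/False are blocked.
```python
def create_padding_mask(seq, pad_idx):
    # (batch, seq_len) -> (batch, 1, 1, seq_len), broadcastable over heads and query positions
    return (seq != pad_idx).unsqueeze(1).unsqueeze(2)


def create_lookahead_mask(size):
    # Lower-triangular (size, size) matrix: position i may attend to positions <= i
    return torch.tril(torch.ones(size, size)).bool()
```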
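For the decoder's self-attention the two masks are normally combined, so that a position is visible only if it is both a real token and not in the future. The sketch below shows one way to do this, assuming the convention that True means "may attend". The function names `padding_mask` and `causal_mask` and the padding index 0 are chosen for this example only.

```python
import torch


def padding_mask(seq, pad_idx):
    # True where the token is real: (batch, seq_len) -> (batch, 1, 1, seq_len)
    return (seq != pad_idx).unsqueeze(1).unsqueeze(2)


def causal_mask(size):
    # True on and below the diagonal: position i may attend to positions <= i
    return torch.tril(torch.ones(size, size)).bool()


pad_idx = 0                           # assumed padding index for this example
tgt = torch.tensor([[5, 7, 9, 0]])    # one sequence whose last position is padding

# Broadcasting (1, 1, 1, 4) & (4, 4) gives (1, 1, 4, 4)
combined = padding_mask(tgt, pad_idx) & causal_mask(tgt.size(1))
print(combined.shape)
print(combined[0, 0].int())
```

In the printed matrix each row is a query position: row 0 sees only itself, and the last column is zero everywhere because no position should attend to padding.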
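6.2 Learning Rate Scheduler
Transformers are usually trained with a warmup schedule: the learning rate rises linearly for warmup_steps steps and then decays with the inverse square root of the step number. This scheduler returns the learning rate directly and ignores the optimizer's initial lr, and scheduler.step() should be called after every optimizer step (per batch, not per epoch):
```python
class TransformerScheduler(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, d_model, warmup_steps):
        # Set attributes before super().__init__(), which already calls get_lr() once
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        super().__init__(optimizer)

    def get_lr(self):
        step = self.last_epoch + 1   # start at 1, since 0 ** -0.5 is undefined
        # Linear warmup, then inverse-square-root decay
        return [
            (self.d_model ** -0.5) * min(step ** -0.5, step * self.warmup_steps ** -1.5)
            for _ in self.base_lrs
        ]
```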
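To get a feel for the numbers this schedule produces, the formula can be evaluated in plain Python. The sketch below assumes d_model = 512 and warmup_steps = 4000 (the warmup value used in the original paper); `noam_lr` is an illustrative name, not part of the scheduler class.

```python
def noam_lr(step, d_model=512, warmup_steps=4000):
    """Learning rate at a given step (step >= 1)."""
    return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)


for step in (1, 1000, 4000, 16000):
    print(step, f"{noam_lr(step):.7f}")

peak = noam_lr(4000)   # the two branches of min() meet at step == warmup_steps
```

The rate peaks exactly at the end of warmup, at (d_model * warmup_steps) ** -0.5, which is a little under 7e-4 for these settings. Quadrupling the step count after that halves the rate.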
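6.3 Label Smoothing
Label smoothing keeps the model from becoming overconfident in the training labels. A small share of the probability mass is moved from the correct class to the other classes, and padding positions are excluded from the loss:
```python
class LabelSmoothingLoss(torch.nn.Module):
    def __init__(self, smoothing=0.1, pad_idx=0):
        super().__init__()
        self.smoothing = smoothing
        self.pad_idx = pad_idx

    def forward(self, pred, target):
        # pred: (N, n_class) logits, target: (N,) class indices
        n_class = pred.size(-1)
        log_pred = torch.log_softmax(pred, dim=-1)
        with torch.no_grad():
            # Spread the smoothing mass over all classes except the target and padding
            true_dist = torch.full_like(log_pred, self.smoothing / (n_class - 2))
            true_dist.scatter_(1, target.unsqueeze(1), 1 - self.smoothing)
            true_dist[:, self.pad_idx] = 0
            # Rows whose target is padding contribute nothing
            pad_mask = target == self.pad_idx
            true_dist.masked_fill_(pad_mask.unsqueeze(1), 0)
        loss = -torch.sum(true_dist * log_pred, dim=-1)
        # Average over real tokens only, so padding does not dilute the loss
        n_tokens = (~pad_mask).sum().clamp(min=1)
        return loss.sum() / n_tokens
```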
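What the smoothed target looks like for a single token is easiest to see with small numbers. This plain-Python sketch assumes 5 classes with index 0 reserved for padding, and divides the smoothing mass among the classes that are neither the target nor padding so that the distribution still sums to 1. The function name `smoothed_distribution` is illustrative.

```python
def smoothed_distribution(target, n_class, smoothing=0.1, pad_idx=0):
    """Target distribution for one token; all zeros if the token is padding."""
    if target == pad_idx:
        return [0.0] * n_class
    # The smoothing mass is shared by every class except the target and padding
    share = smoothing / (n_class - 2)
    dist = [share] * n_class
    dist[target] = 1.0 - smoothing
    dist[pad_idx] = 0.0
    return dist


dist = smoothed_distribution(target=2, n_class=5)
print([round(p, 4) for p in dist])   # [0.0, 0.0333, 0.9, 0.0333, 0.0333]
print(round(sum(dist), 6))           # 1.0
```

The correct class keeps 0.9 of the mass, the three remaining non-padding classes share the other 0.1, and the padding class never receives any.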
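7. Training and Validation
7.1 Example Training Loop
The decoder is trained with teacher forcing: it receives the target sequence without its last token and is asked to predict the same sequence shifted one position to the left. The loop assumes batch-first tensors of shape (batch, seq_len), matching the model above:
```python
def train(model, iterator, optimizer, criterion, clip, pad_idx):
    model.train()
    epoch_loss = 0
    for batch in iterator:
        src = batch.src               # (batch, src_len)
        tgt = batch.tgt[:, :-1]       # decoder input: drop the last token
        tgt_out = batch.tgt[:, 1:]    # prediction target: drop the first token
        # Padding mask for the source; padding and causal masks combined for the target
        src_mask = create_padding_mask(src, pad_idx)
        tgt_mask = (create_padding_mask(tgt, pad_idx)
                    & create_lookahead_mask(tgt.size(1)).to(tgt.device))
        optimizer.zero_grad()
        output = model(src, tgt, src_mask, tgt_mask)
        # Flatten to (batch * tgt_len, vocab) and (batch * tgt_len,) for the criterion
        loss = criterion(output.reshape(-1, output.shape[-1]),
                         tgt_out.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(iterator)
```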
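The one-position shift between decoder input and prediction target is easy to get wrong, so here is a minimal illustration with plain Python lists standing in for a batch-first tensor. The special-token ids (SOS = 1, EOS = 2, PAD = 0) and the word ids are invented for this example.

```python
SOS, EOS, PAD = 1, 2, 0   # assumed special-token ids for this example

# A padded batch of two target sentences, shape (batch=2, seq_len=5)
tgt_batch = [
    [SOS, 11, 12, 13, EOS],
    [SOS, 21, 22, EOS, PAD],
]

decoder_input = [row[:-1] for row in tgt_batch]    # tensor equivalent: tgt[:, :-1]
decoder_target = [row[1:] for row in tgt_batch]    # tensor equivalent: tgt[:, 1:]

for inp, out in zip(decoder_input, decoder_target):
    print(inp, "->", out)
```

At every position the decoder sees the tokens up to and including that position and is trained to emit the next one. The slice has to run along the sequence dimension; slicing the first dimension of a batch-first tensor would drop a whole sentence instead.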
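7.2 Validation and Inference
At inference time, beam search usually gives better output than greedy decoding. The version below handles one source sentence at a time (batch size 1), starts from the start-of-sequence token, and ranks hypotheses by raw cumulative log-probability, which favors shorter outputs; length normalization is a common refinement:
```python
def beam_search(model, src, src_mask, max_len, beam_size, sos_idx, eos_idx):
    model.eval()
    with torch.no_grad():
        # Each beam is (token list, cumulative log-probability)
        beams = [([sos_idx], 0.0)]
        completed = []
        for _ in range(max_len):
            new_beams = []
            for seq, score in beams:
                # Build the decoder input and its causal mask
                tgt = torch.LongTensor(seq).unsqueeze(0).to(src.device)
                tgt_mask = create_lookahead_mask(tgt.size(1)).to(src.device)
                # Full forward pass; the source is re-encoded each step for simplicity
                output = model(src, tgt, src_mask, tgt_mask)
                log_probs = torch.log_softmax(output[:, -1, :], dim=-1)
                topk_probs, topk_idx = log_probs.topk(beam_size, dim=-1)
                # Expand this beam with its top-k next tokens
                for i in range(beam_size):
                    new_seq = seq + [topk_idx[0, i].item()]
                    new_score = score + topk_probs[0, i].item()
                    new_beams.append((new_seq, new_score))
            # Keep the best beam_size candidates and set finished ones aside
            new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_size]
            beams = []
            for seq, score in new_beams:
                (completed if seq[-1] == eos_idx else beams).append((seq, score))
            if not beams:
                break
        # Choose the best hypothesis among finished and unfinished sequences
        candidates = beams + completed
        return sorted(candidates, key=lambda x: x[1], reverse=True)[0][0]
```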
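To see why keeping several hypotheses matters, the sketch below runs the same search over a toy "model" whose next-token probabilities depend only on the previous token. The probability table is made up so that the greedy first choice leads to a worse overall sequence, and the function is a simplified stand-in for the PyTorch version above, not a drop-in replacement.

```python
import math

SOS, EOS, A, B = "<sos>", "<eos>", "a", "b"

# Toy "model": next-token probabilities depend only on the previous token
NEXT_PROBS = {
    SOS: {A: 0.6, B: 0.4},
    A: {EOS: 0.4, A: 0.3, B: 0.3},
    B: {EOS: 0.9, A: 0.1},
}


def beam_search(beam_size, max_len=5):
    beams = [([SOS], 0.0)]   # (tokens, cumulative log-probability)
    completed = []
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            for token, prob in NEXT_PROBS[seq[-1]].items():
                candidates.append((seq + [token], score + math.log(prob)))
        # Keep the best beam_size candidates, then set finished ones aside
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:beam_size]:
            (completed if seq[-1] == EOS else beams).append((seq, score))
        if not beams:
            break
    best_seq, best_score = max(completed + beams, key=lambda c: c[1])
    return best_seq, math.exp(best_score)


greedy_seq, greedy_prob = beam_search(beam_size=1)   # beam size 1 is greedy decoding
beam_seq, beam_prob = beam_search(beam_size=2)
print(greedy_seq, round(greedy_prob, 2))   # ['<sos>', 'a', '<eos>'] 0.24
print(beam_seq, round(beam_prob, 2))       # ['<sos>', 'b', '<eos>'] 0.36
```

Greedy decoding commits to "a" because it is the most likely first token, while a beam of width 2 keeps "b" alive long enough to find the more probable full sequence.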
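8. Performance Optimization Tips
8.1 Mixed-Precision Training
Mixed-precision training is available natively through torch.cuda.amp, which supersedes the amp module of NVIDIA's Apex library:
```python
model = Transformer(config).cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = torch.cuda.amp.GradScaler()

# Inside the training loop:
optimizer.zero_grad()
with torch.cuda.amp.autocast():   # run the forward pass in reduced precision where safe
    output = model(src, tgt, src_mask, tgt_mask)
    loss = criterion(output.reshape(-1, output.shape[-1]), tgt_out.reshape(-1))
scaler.scale(loss).backward()     # scale the loss to avoid float16 gradient underflow
scaler.step(optimizer)            # unscales the gradients, then steps the optimizer
scaler.update()
```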
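8.2 Gradient Accumulation
When GPU memory is limited, accumulating gradients over several small batches simulates a larger batch size. In the snippet, compute_loss is a placeholder for the forward pass and criterion call from the training loop in 7.1:
```python
accumulation_steps = 4
optimizer.zero_grad()
for i, batch in enumerate(iterator):
    # Placeholder helper: forward pass + criterion, as in the training loop of 7.1
    loss = compute_loss(model, batch, criterion)
    # Divide so the accumulated gradient matches the mean over the larger batch
    (loss / accumulation_steps).backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```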
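The division by accumulation_steps is what makes the accumulated gradient equal to the gradient of one large batch. The following plain-Python check demonstrates this for a one-parameter least-squares problem; the data values are arbitrary, and the equality is exact only when all micro-batches have the same size and the loss is a mean over samples.

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over the given samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 15.9]

# One large batch of 8 samples
full_grad = grad_mse(w, xs, ys)

# Four micro-batches of 2, each gradient divided by the number of steps
accumulation_steps = 4
micro = len(xs) // accumulation_steps
accumulated = 0.0
for i in range(accumulation_steps):
    chunk = slice(i * micro, (i + 1) * micro)
    accumulated += grad_mse(w, xs[chunk], ys[chunk]) / accumulation_steps

print(abs(full_grad - accumulated) < 1e-9)   # True
```

With a loss averaged over non-padding tokens, micro-batches that contain different numbers of real tokens make the match approximate rather than exact.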
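8.3 Model Parallelism
For very large models, different layers can be placed on different GPUs. This sketch covers only the encoder stack, expects inputs that are already embedded, and assumes at least one GPU is visible:
```python
class ParallelTransformer(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        n_gpu = torch.cuda.device_count()   # assumes at least one GPU is visible
        # Round-robin placement: layer i lives on GPU i % n_gpu
        self.devices = [torch.device(f'cuda:{i % n_gpu}')
                        for i in range(config.num_encoder_layers)]
        self.encoder_layers = torch.nn.ModuleList([
            EncoderLayer(config.d_model, config.n_head, config.d_ff, config.dropout).to(dev)
            for dev in self.devices
        ])

    def forward(self, x, mask=None):
        for layer, dev in zip(self.encoder_layers, self.devices):
            # Move the activations (and the mask) to the GPU that holds this layer
            x = x.to(dev)
            layer_mask = mask.to(dev) if mask is not None else None
            x = layer(x, layer_mask)
        return x
```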
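9. Troubleshooting Common Problems
9.1 Possible Causes of Non-Convergence
- Poorly chosen learning rate: Transformers usually need a small learning rate (e.g. 0.0001)
- Exploding gradients: add gradient clipping (clip_grad_norm_)
- Initialization problems: make sure parameters are initialized in a reasonable range (e.g. Xavier initialization)
- Mask errors: verify that the attention masks are applied correctly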
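9.2 Repetitive Output at Inference Time
Remedies:
- Raise the temperature so the softmax distribution becomes flatter
- Use top-k or top-p sampling
- Add a repetition penalty
A top-p (nucleus) sampling implementation for logits of shape (batch, vocab_size):
```python
def top_p_sampling(logits, p=0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    # Flag tokens once the cumulative probability exceeds p ...
    sorted_indices_to_remove = cumulative_probs > p
    # ... but shift right by one so the token that crosses the threshold is kept
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False
    # Map the removal flags from sorted order back to vocabulary order
    indices_to_remove = sorted_indices_to_remove.scatter(
        -1, sorted_indices, sorted_indices_to_remove)
    # masked_fill returns a new tensor, leaving the caller's logits untouched
    filtered = logits.masked_fill(indices_to_remove, float('-inf'))
    return torch.multinomial(torch.softmax(filtered, dim=-1), num_samples=1)
```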
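The effect of the temperature parameter mentioned in the list above can be checked with a few numbers. This is a minimal plain-Python sketch; the logits are arbitrary and `softmax_with_temperature` is an illustrative helper, not a PyTorch function. In a PyTorch decoding loop the equivalent step is dividing the logits by the temperature before the softmax.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    scaled = [x / temperature for x in logits]
    m = max(scaled)   # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [4.0, 2.0, 1.0]
sharp = softmax_with_temperature(logits, temperature=0.5)
base = softmax_with_temperature(logits, temperature=1.0)
flat = softmax_with_temperature(logits, temperature=2.0)

for name, probs in (("T=0.5", sharp), ("T=1.0", base), ("T=2.0", flat)):
    print(name, [round(p, 3) for p in probs])
```

A higher temperature moves probability from the top token to the alternatives, which makes sampled output more varied and less likely to loop. A temperature below 1 does the opposite.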
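10. Deployment and Productionization
10.1 Exporting to TorchScript
Converting the model to TorchScript lets it be loaded and run without the Python source, which simplifies deployment and can improve inference speed. Note that torch.jit.trace records only the operations executed for the example inputs, so the traced model here is specialized to calls without masks:
```python
model.eval()
# Dummy (src, tgt) token ids with batch size 1 and lengths 10 and 5
example_input = (torch.randint(0, 100, (1, 10)), torch.randint(0, 100, (1, 5)))
traced_model = torch.jit.trace(model, example_input)
traced_model.save("transformer.pt")
```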
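10.2 Converting to ONNX
```python
# example_input is the (src, tgt) tuple defined in 10.1
torch.onnx.export(
    model,
    example_input,
    "transformer.onnx",
    input_names=["src", "tgt"],
    output_names=["output"],
    # Mark batch size and sequence lengths as dynamic so other shapes are accepted
    dynamic_axes={
        "src": {0: "batch", 1: "src_seq"},
        "tgt": {0: "batch", 1: "tgt_seq"},
        "output": {0: "batch", 1: "tgt_seq"}
    }
)
```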
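10.3 Quantization
```python
# Dynamic quantization: Linear weights are stored as int8, activations are quantized on the fly
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
```
In our deployments, quantized Transformer models ran roughly 2-3x faster at inference with an accuracy loss typically within 1%. The exact numbers depend on the model and hardware, so measure on your own workload. Dynamic quantization as shown here targets CPU inference. For production environments, especially latency-sensitive ones, consider optimizing further with TensorRT.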
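To give some intuition for why the accuracy loss stays small, the sketch below quantizes a handful of weights to int8 and back in plain Python. It uses a simplified symmetric per-tensor scheme chosen for illustration, which is not necessarily what PyTorch does internally; the weight values and function names are made up.

```python
def quantize_int8(values):
    """Symmetric per-tensor quantization to integers in [-127, 127]."""
    # Fall back to a scale of 1.0 if every value is zero
    scale = (max(abs(v) for v in values) / 127.0) or 1.0
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return q, scale


def dequantize(q, scale):
    return [qi * scale for qi in q]


weights = [0.82, -0.41, 0.05, -1.27, 0.33]
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)
max_error = max(abs(w - r) for w, r in zip(weights, restored))

print(q)
print(max_error <= scale / 2 + 1e-12)   # True: rounding error is at most half a step
```

Each weight is stored in one byte instead of four, and the reconstruction error per weight is bounded by half the quantization step, which is small relative to the largest weight in the tensor.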