1. Transformer模型实现详解:从理论到代码实践
最近在复现Transformer模型时,发现很多教程要么过于理论化,要么代码实现不够完整。今天我想分享一个完整的Transformer实现过程,从数据准备到模型训练,再到预测评估,手把手带你理解这个革命性的模型架构。
Transformer最初由Google在2017年提出,彻底改变了自然语言处理领域的格局。与传统的RNN和CNN不同,它完全基于注意力机制,能够并行处理序列数据,同时捕捉长距离依赖关系。下面我们就从最基础的部分开始,逐步构建一个完整的Transformer模型。
2. 数据准备与预处理
2.1 数据集生成
我们首先构建一个简单的音标到字母的映射任务。这个任务虽然简单,但足以展示Transformer的核心机制:
python复制# 定义音标和字母列表
soundmark = ['ei', 'bi:', 'si:', 'di:', 'i:', 'ef', 'dʒi:', 'eit∫', 'ai', 'dʒei',
'kei', 'el', 'em', 'en', 'əu', 'pi:', 'kju:', 'ɑ:', 'es', 'ti:',
'ju:', 'vi:', 'd∧blju:', 'eks', 'wai', 'zi:']
alphabet = ['a','b','c','d','e','f','g','h','i','j','k','l','m',
'n','o','p','q','r','s','t','u','v','w','x','y','z']
# 生成带噪声的训练数据
t = 1000 # 样本总数
r = 0.9 # 正确映射概率
seq_len = 6 # 序列长度
src_tokens, tgt_tokens = [],[]
for i in range(t):
src, tgt = [],[]
for j in range(seq_len):
ind = random.randint(0,25)
src.append(soundmark[ind])
# 90%概率正确映射,10%概率随机噪声
tgt.append(alphabet[ind] if random.random() < r else alphabet[random.randint(0,25)])
src_tokens.append(src)
tgt_tokens.append(tgt)
提示:在实际NLP任务中,这种噪声模拟了真实场景中的拼写错误或翻译偏差,有助于提高模型的鲁棒性。
2.2 词表构建
Transformer需要将文本转换为数字表示,因此我们需要构建词表:
python复制class Vocab:
def __init__(self, tokens):
self.tokens = tokens
# 特殊token
self.token2index = {'<pad>': 0, '<bos>': 1, '<eos>': 2, '<unk>': 3}
# 按词频排序添加普通token
self.token2index.update({
token: index+4 for index, (token, freq) in enumerate(
sorted(Counter(flatten(self.tokens)).items(),
key=lambda x: x[1], reverse=True))
})
self.index2token = {v:k for k,v in self.token2index.items()}
def __getitem__(self, query):
if isinstance(query, str):
return self.token2index.get(query, 3) # 未知词返回<unk>
elif isinstance(query, int):
return self.index2token.get(query, '<unk>')
elif isinstance(query, (list, tuple)):
return [self.__getitem__(item) for item in query]
def __len__(self):
return len(self.index2token)
特殊token的作用:
<pad>:用于填充不等长序列<bos>:标记序列开始<eos>:标记序列结束<unk>:处理未见过的词
2.3 数据加载器
将数据转换为PyTorch的DataLoader格式:
python复制# 添加特殊token并转换为tensor
encoder_input = torch.tensor([src_vocab[line + ['<pad>']] for line in src_tokens])
decoder_input = torch.tensor([tgt_vocab[['<bos>'] + line] for line in tgt_tokens])
decoder_output = torch.tensor([tgt_vocab[line + ['<eos>']] for line in tgt_tokens])
# 自定义Dataset类
class MyDataSet(Data.Dataset):
def __init__(self, enc_inputs, dec_inputs, dec_outputs):
self.enc_inputs = enc_inputs
self.dec_inputs = dec_inputs
self.dec_outputs = dec_outputs
def __getitem__(self, idx):
return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]
def __len__(self):
return len(self.enc_inputs)
# 创建数据加载器
train_loader = DataLoader(MyDataSet(encoder_input[:800], decoder_input[:800], decoder_output[:800]),
batch_size=16, shuffle=True)
test_loader = DataLoader(MyDataSet(encoder_input[800:], decoder_input[800:], decoder_output[800:]),
batch_size=1)
3. Transformer核心组件实现
3.1 位置编码
由于Transformer没有递归或卷积结构,需要显式地注入位置信息:
python复制def get_sinusoid_encoding_table(n_position, d_model):
def cal_angle(position, hid_idx):
return position / (10000 ** (2 * (hid_idx//2) / d_model))
def get_posi_angle_vec(position):
return [cal_angle(position, hid_j) for hid_j in range(d_model)]
sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # 偶数维用sin
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # 奇数维用cos
return torch.FloatTensor(sinusoid_table)
位置编码使用不同频率的正弦和余弦函数,使得模型能够学习到相对位置信息。这种编码方式可以处理比训练时更长的序列,具有良好的外推性。
3.2 注意力机制
3.2.1 缩放点积注意力
python复制class ScaledDotProductAttention(nn.Module):
def __init__(self):
super().__init__()
def forward(self, Q, K, V, attn_mask):
# 计算注意力分数
scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)
# 应用掩码
scores.masked_fill_(attn_mask, -1e9)
# softmax归一化
attn = nn.Softmax(dim=-1)(scores)
# 加权求和
context = torch.matmul(attn, V)
return context, attn
缩放因子√d_k的作用是防止点积结果过大导致softmax梯度消失。
3.2.2 多头注意力
python复制class MultiHeadAttention(nn.Module):
def __init__(self):
super().__init__()
self.W_Q = nn.Linear(d_model, d_k * n_heads)
self.W_K = nn.Linear(d_model, d_k * n_heads)
self.W_V = nn.Linear(d_model, d_v * n_heads)
self.fc = nn.Linear(n_heads * d_v, d_model)
def forward(self, input_Q, input_K, input_V, attn_mask):
residual, batch_size = input_Q, input_Q.size(0)
# 线性变换并分头
Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)
K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1,2)
V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1,2)
# 扩展掩码维度
attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
# 计算注意力
context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)
# 拼接多头结果
context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v)
# 输出线性变换
output = self.fc(context)
# 残差连接和层归一化
return nn.LayerNorm(d_model)(output + residual), attn
多头注意力的优势在于:
- 允许模型在不同位置共同关注来自不同表示子空间的信息
- 提高了模型的表达能力
- 并行计算效率高
3.3 前馈网络
python复制class PoswiseFeedForwardNet(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
def forward(self, inputs):
residual = inputs
output = self.fc(inputs)
return nn.LayerNorm(d_model)(output + residual)
前馈网络由两个线性变换和一个ReLU激活组成,中间层的维度通常比输入大(d_ff=2048 vs d_model=512),这种"瓶颈"结构有助于捕捉更复杂的特征。
4. 编码器和解码器实现
4.1 编码器层
python复制class EncoderLayer(nn.Module):
def __init__(self):
super().__init__()
self.enc_self_attn = MultiHeadAttention()
self.pos_ffn = PoswiseFeedForwardNet()
def forward(self, enc_inputs, enc_self_attn_mask):
# 自注意力
enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
# 前馈网络
enc_outputs = self.pos_ffn(enc_outputs)
return enc_outputs, attn
4.2 解码器层
python复制class DecoderLayer(nn.Module):
def __init__(self):
super().__init__()
self.dec_self_attn = MultiHeadAttention()
self.dec_enc_attn = MultiHeadAttention()
self.pos_ffn = PoswiseFeedForwardNet()
def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
# 带掩码的自注意力
dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
# 编码器-解码器注意力
dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
# 前馈网络
dec_outputs = self.pos_ffn(dec_outputs)
return dec_outputs, dec_self_attn, dec_enc_attn
解码器与编码器的关键区别:
- 解码器自注意力使用带掩码的多头注意力,防止看到未来信息
- 增加了编码器-解码器注意力层,让解码器可以关注编码器的输出
5. 完整Transformer模型
python复制class Transformer(nn.Module):
def __init__(self):
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
self.projection = nn.Linear(d_model, tgt_vocab_size)
def forward(self, enc_inputs, dec_inputs):
enc_outputs, enc_self_attns = self.encoder(enc_inputs)
dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
dec_logits = self.projection(dec_outputs)
return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns
6. 模型训练与评估
6.1 训练配置
python复制# 超参数
d_model = 512
d_ff = 2048
d_k = d_v = 64
n_layers = 6
n_heads = 8
num_epochs = 50
# 初始化模型
model = Transformer()
criterion = nn.CrossEntropyLoss(ignore_index=0) # 忽略pad的损失
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.99)
6.2 训练循环
python复制loss_history = []
for epoch in range(num_epochs):
model.train()
total_loss = 0
for enc_inputs, dec_inputs, dec_outputs in train_loader:
# 前向传播
outputs, _, _, _ = model(enc_inputs, dec_inputs)
loss = criterion(outputs, dec_outputs.view(-1))
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(train_loader)
loss_history.append(avg_loss)
print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}')
6.3 模型评估
python复制model.eval()
correct = error = 0
translation_results = []
for enc_inputs, dec_inputs, dec_outputs in test_loader:
outputs, _, _, _ = model(enc_inputs, dec_inputs)
outputs = outputs.squeeze()
# 生成预测序列
pred_seq = []
for output in outputs:
next_token = output.argmax().item()
if next_token == tgt_vocab['<eos>']:
break
pred_seq.append(next_token)
# 处理真实序列
tgt_seq = dec_outputs.squeeze().tolist()
if tgt_vocab['<eos>'] in tgt_seq:
tgt_seq = tgt_seq[:tgt_seq.index(tgt_vocab['<eos>'])]
# 计算准确率
for i in range(len(tgt_seq)):
if i >= len(pred_seq) or pred_seq[i] != tgt_seq[i]:
error += 1
else:
correct += 1
translation_results.append((
' '.join(tgt_vocab[tgt_seq]),
' '.join(tgt_vocab[pred_seq])
))
print(f'Character Accuracy: {correct/(correct+error):.2%}')
7. 关键技术与经验分享
7.1 注意力机制实现细节
- 掩码处理:在实现注意力机制时,正确处理掩码至关重要。我们实现了两种掩码:
- Pad掩码:屏蔽填充位置,防止模型关注无意义的填充token
- 未来掩码:确保解码器在预测时只能看到当前位置及之前的信息
python复制def get_attn_pad_mask(seq_q, seq_k):
pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)
return pad_attn_mask.expand(seq_q.size(0), seq_q.size(1), seq_k.size(1))
def get_attn_subsequence_mask(seq):
attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
subsequence_mask = np.triu(np.ones(attn_shape), k=1)
return torch.from_numpy(subsequence_mask).byte()
- 多头注意力的维度变换:在多头注意力中,我们需要将d_model维度的输入拆分为n_heads个头,每个头有d_k维度:
python复制Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)
7.2 训练技巧
-
学习率调度:虽然我们使用了固定学习率,但在实际应用中,可以考虑使用学习率预热(warmup)策略,这在Transformer论文中被证明有效。
-
标签平滑:对于分类任务,可以使用标签平滑(label smoothing)来防止模型对预测结果过于自信,提高泛化能力。
-
梯度裁剪:当使用较大的学习率时,梯度裁剪可以防止梯度爆炸问题。
7.3 常见问题排查
-
模型不收敛:
- 检查数据预处理是否正确,特别是特殊token的处理
- 验证注意力掩码是否正确应用
- 尝试降低学习率或使用学习率预热
-
过拟合:
- 增加dropout层
- 使用更大的训练数据集
- 尝试权重衰减(weight decay)
-
训练速度慢:
- 检查是否使用了GPU加速
- 增大batch size(在显存允许范围内)
- 使用混合精度训练
8. 模型优化与扩展
8.1 性能优化
-
内存优化:对于长序列,注意力计算的内存消耗是O(n²)。可以考虑使用稀疏注意力或内存高效的注意力实现。
-
计算优化:使用PyTorch的torch.jit.script或ONNX导出模型,可以获得更好的推理性能。
8.2 功能扩展
-
多语言支持:可以通过共享词表或添加语言嵌入来实现多语言翻译。
-
预训练微调:可以基于大规模语料预训练Transformer,然后在特定任务上微调。
-
模型压缩:通过知识蒸馏或量化,可以减小模型大小,提高推理速度。
9. 实际应用中的考量
在实际项目中应用Transformer时,还需要考虑以下方面:
-
数据流水线优化:使用PyTorch的DataLoader的num_workers参数实现并行数据加载,提高GPU利用率。
-
混合精度训练:使用torch.cuda.amp模块可以显著减少显存占用并加速训练。
-
模型部署:考虑使用TorchScript或ONNX格式导出模型,便于生产环境部署。
-
监控与可视化:使用TensorBoard或WandB等工具监控训练过程,可视化注意力权重。
10. 总结与展望
通过这个完整的Transformer实现,我们深入理解了自注意力机制的工作原理以及Transformer架构的精妙设计。虽然我们实现的只是一个简单的音标转换任务,但同样的架构经过适当扩展,可以应用于机器翻译、文本摘要、问答系统等各种NLP任务。
Transformer的成功不仅在于其出色的性能,更在于其通用性和可扩展性。从最初的Transformer到后来的BERT、GPT等模型,这一架构已经彻底改变了自然语言处理领域。理解Transformer的实现细节,是掌握现代NLP技术的重要基础。