注意力机制在序列到序列(Seq2Seq)模型中的应用,是深度学习领域近年来最具突破性的技术之一。我在自然语言处理项目中首次尝试注意力机制是在2017年,当时为了提升机器翻译质量,传统Seq2Seq模型在长句子翻译上的表现令人沮丧。引入注意力机制后,BLEU评分直接提升了15个百分点,这种改进幅度让我至今记忆犹新。
传统的Seq2Seq模型使用固定长度的上下文向量来编码整个输入序列,这就像要求你在阅读完一本小说后,只用一句话来概括全部情节细节——信息丢失在所难免。而注意力机制相当于给模型装上了"可调节的聚光灯",让它能够动态地关注输入序列中最相关的部分,这种机制特别适合处理长度差异大的序列任务。
编码器通常采用双向LSTM或GRU,我在实际项目中发现GRU的计算效率更高,特别是在批量处理短文本时(如聊天机器人场景)。以下是一个典型的编码器实现:
python复制class Encoder(nn.Module):
def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
super().__init__()
self.embedding = nn.Embedding(input_dim, emb_dim)
self.rnn = nn.GRU(emb_dim, hid_dim, n_layers,
dropout=dropout, bidirectional=True)
self.fc = nn.Linear(hid_dim*2, hid_dim)
def forward(self, src):
embedded = self.embedding(src)
outputs, hidden = self.rnn(embedded)
# 合并双向输出
hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)))
return outputs, hidden
关键细节:双向RNN的最终隐藏状态需要特殊处理。我通常使用一个全连接层将双向的最终状态合并为单向量,并通过tanh激活确保数值稳定性。
Bahdanau注意力是最常用的加性注意力,但在实际部署中我发现它的计算开销较大。对于实时性要求高的应用(如语音识别),可以改用Luong的乘性注意力:
python复制class Attention(nn.Module):
def __init__(self, hid_dim):
super().__init__()
self.attn = nn.Linear(hid_dim * 2, hid_dim)
self.v = nn.Linear(hid_dim, 1, bias=False)
def forward(self, hidden, encoder_outputs):
# hidden: [batch_size, hid_dim]
# encoder_outputs: [src_len, batch_size, hid_dim*2]
src_len = encoder_outputs.shape[0]
hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs.permute(1,0,2)), dim=2)))
attention = self.v(energy).squeeze(2)
return F.softmax(attention, dim=1)
性能优化:在批量处理时,使用矩阵运算代替循环可以提升5-8倍速度。我曾在一个电商评论分类项目中将推理时间从120ms降到23ms,仅通过优化注意力得分计算部分。
解码器需要同时处理三个输入:上一个时间步的输出、当前的隐藏状态,以及注意力加权的上下文向量。这里有个容易踩坑的地方——教师强制(teacher forcing)的比例设置:
python复制class Decoder(nn.Module):
def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout, attention):
super().__init__()
self.output_dim = output_dim
self.attention = attention
self.embedding = nn.Embedding(output_dim, emb_dim)
self.rnn = nn.GRU(emb_dim + hid_dim, hid_dim, n_layers, dropout=dropout)
self.fc_out = nn.Linear(emb_dim + hid_dim*2, output_dim)
def forward(self, input, hidden, encoder_outputs):
input = input.unsqueeze(0)
embedded = self.embedding(input)
a = self.attention(hidden, encoder_outputs)
weighted = torch.bmm(a.unsqueeze(1), encoder_outputs.permute(1,0,2))
rnn_input = torch.cat((embedded, weighted.permute(1,0,2)), dim=2)
output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
prediction = self.fc_out(torch.cat((output.squeeze(0),
weighted.squeeze(1),
embedded.squeeze(0)), dim=1))
return prediction, hidden.squeeze(0), a
训练技巧:教师强制比例应采用计划采样(scheduled sampling),从100%开始线性衰减到30%左右。在对话生成任务中,这能有效缓解暴露偏差问题。
对于Seq2Seq任务,数据预处理有几个关键点常被忽视:
我常用的数据迭代器实现:
python复制def create_iterator(dataset, batch_size, device):
dataset.sort(key=lambda x: len(x[0]), reverse=True)
batches = []
for i in range(0, len(dataset), batch_size):
src = [x[0] for x in dataset[i:i+batch_size]]
trg = [x[1] for x in dataset[i:i+batch_size]]
src_pad = pad_sequence(src, padding_value=PAD_IDX)
trg_pad = pad_sequence(trg, padding_value=PAD_IDX)
batches.append((src_pad.to(device), trg_pad.to(device)))
return batches
标准的训练循环需要添加几个关键改进:
python复制def train(model, iterator, optimizer, criterion, clip):
model.train()
epoch_loss = 0
for i, (src, trg) in enumerate(iterator):
optimizer.zero_grad()
output, _ = model(src, trg)
loss = criterion(output[1:].view(-1, output.shape[2]),
trg[1:].view(-1))
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
optimizer.step()
epoch_loss += loss.item()
return epoch_loss / len(iterator)
实际经验:在8GB显存的GPU上,当批量大小为128时,如果序列长度超过50词就可能导致OOM。这时可以采用动态批量(dynamic batching)策略,根据序列长度调整批量大小。
| 问题现象 | 可能原因 | 解决方案 |
|---|---|---|
| 损失值NaN | 学习率过高/梯度爆炸 | 添加梯度裁剪,降低学习率10倍 |
| 输出重复词 | 注意力分布过于分散 | 增加注意力温度参数 |
| 长序列质量差 | 编码器信息丢失 | 改用Transformer或CNN编码器 |
| 推理结果乱码 | 解码器未重置隐藏状态 | 确保每个样本推理时初始化隐藏状态 |
理解模型关注点的最佳方式是可视化注意力权重。这个简单的可视化函数曾帮我发现模型错误关注标点符号的问题:
python复制def plot_attention(src, trg, attention):
fig = plt.figure(figsize=(12,8))
ax = fig.add_subplot(111)
cax = ax.matshow(attention.numpy(), cmap='bone')
ax.set_xticklabels([''] + src, rotation=90)
ax.set_yticklabels([''] + trg)
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
plt.show()
单头注意力在处理复杂关系时可能不足。多头注意力将查询、键和值线性投影到多个子空间,可以并行捕捉不同方面的关系:
python复制class MultiHeadAttention(nn.Module):
def __init__(self, hid_dim, n_heads, dropout):
super().__init__()
self.hid_dim = hid_dim
self.n_heads = n_heads
self.head_dim = hid_dim // n_heads
self.fc_q = nn.Linear(hid_dim, hid_dim)
self.fc_k = nn.Linear(hid_dim, hid_dim)
self.fc_v = nn.Linear(hid_dim, hid_dim)
self.fc_o = nn.Linear(hid_dim, hid_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
batch_size = query.shape[0]
Q = self.fc_q(query).view(batch_size, -1, self.n_heads, self.head_dim).transpose(1,2)
K = self.fc_k(key).view(batch_size, -1, self.n_heads, self.head_dim).transpose(1,2)
V = self.fc_v(value).view(batch_size, -1, self.n_heads, self.head_dim).transpose(1,2)
energy = torch.matmul(Q, K.transpose(-2,-1)) / math.sqrt(self.head_dim)
if mask is not None:
energy = energy.masked_fill(mask==0, -1e10)
attention = torch.softmax(energy, dim=-1)
x = torch.matmul(self.dropout(attention), V)
x = x.transpose(1,2).contiguous().view(batch_size, -1, self.hid_dim)
x = self.fc_o(x)
return x, attention
对于需要从输入直接复制内容的任务(如对话系统中的实体识别),可以结合指针网络:
python复制class PointerGenerator(nn.Module):
def __init__(self, hid_dim):
super().__init__()
self.pointer = nn.Linear(hid_dim*3, 1)
def forward(self, decoder_hidden, encoder_outputs, decoder_input):
p_gen = torch.sigmoid(self.pointer(torch.cat(
[decoder_hidden, encoder_outputs, decoder_input], dim=-1)))
return p_gen
在最后的项目部署中,我将模型转换为TorchScript时遇到了注意力缓存的问题。解决方案是重写注意力得分的计算方式,避免使用动态形状的操作。这个教训让我明白:生产环境中的模型实现往往需要与实验版本有所区别。