Self-Attention(自注意力)是Transformer架构的核心组件,它彻底改变了传统序列建模的方式。我第一次接触这个概念是在2017年那篇著名的《Attention Is All You Need》论文中,当时就被它优雅的设计所震撼。与RNN需要逐步处理序列不同,Self-Attention允许模型直接计算序列中任意两个元素之间的关系权重,这种全局视角带来了显著的性能提升。
举个生活中的例子:当你阅读这篇文章时,眼睛会不自觉地关注当前正在阅读的词语(高注意力),同时也会根据上下文关系偶尔回看前文某些关键词(中等注意力),而对已经理解过的辅助词则几乎不再关注(低注意力)。Self-Attention正是模拟了这种动态的注意力分配机制。
假设我们有一个包含n个token的输入序列X ∈ ℝ^(n×d_model),其中d_model是嵌入维度(通常为512或768)。首先通过三个不同的线性变换得到:
这里d_k和d_v通常是相同的维度(如64)。这三个矩阵分别代表:
实际实现时,这三个线性变换可以通过一个大的矩阵乘法和分割操作高效完成,这是工程实现的重要优化点。
注意力分数的核心计算公式为:
Attention(Q,K,V) = softmax(QK^T/√d_k)V
分步解析:
python复制# 简化版的PyTorch实现
def scaled_dot_product_attention(Q, K, V):
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
attn_weights = F.softmax(scores, dim=-1)
return torch.matmul(attn_weights, V)
原始论文进一步提出了Multi-Head Attention,将上述过程并行执行h次(通常h=8):
数学表达:
MultiHead(Q,K,V) = Concat(head_1,...,head_h)W_O
其中 head_i = Attention(QW_Q^i, KW_K^i, VW_V^i)
这种设计允许模型在不同表示子空间中学习不同的注意力模式,比如一个头关注局部语法关系,另一个头关注长距离语义依赖。
在实际部署中,我们需要考虑几个关键优化点:
python复制# 带掩码的多头注意力实现示例
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
self.W_O = nn.Linear(d_model, d_model)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 线性变换并分头
Q = self.W_Q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)
K = self.W_K(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)
V = self.W_V(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)
# 计算缩放点积注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, V)
# 合并多头输出
output = output.transpose(1,2).contiguous().view(batch_size, -1, self.d_model)
return self.W_O(output)
Self-Attention的计算复杂度为O(n²·d),主要来自:
相比之下,RNN的复杂度是O(n·d²)。因此对于长序列(如n>d),Self-Attention的计算开销会显著增加。这也是后续研究提出稀疏注意力、线性注意力等变体的主要原因。
原始Transformer使用绝对位置编码,后续工作如《Self-Attention with Relative Position Representations》提出了相对位置编码:
e_ij = (x_i + p_i)W_Q((x_j + p_j)W_K)^T
= x_iW_QW_K^Tx_j^T + x_iW_QW_K^Tp_j^T + p_iW_QW_K^Tx_j^T + p_iW_QW_K^Tp_j^T
其中只有第二、三项与相对位置有关。改进方法直接建模相对位置:
a_ij = x_iW_Q(x_jW_K + r_ij)^T
其中r_ij是学习到的相对位置嵌入。
为了降低O(n²)复杂度,研究者提出了多种稀疏注意力:
理解模型学到的注意力模式非常重要:
python复制def plot_attention(attention_weights, source, target):
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(111)
cax = ax.matshow(attention_weights, cmap='bone')
ax.set_xticklabels([''] + source, rotation=90)
ax.set_yticklabels([''] + target)
plt.show()
# 示例:可视化编码器最后一层的自注意力
layer = 5
head = 3
attention_weights = model.encoder.layers[layer].self_attn.attn[0,head].data
plot_attention(attention_weights, tokens, tokens)
注意力权重过于均匀:
训练不稳定:
长序列性能下降:
不同深度学习框架在实现细节上可能有差异:
| 特性 | PyTorch实现 | TensorFlow实现 |
|---|---|---|
| 多头处理 | view + transpose | reshape + transpose |
| 掩码处理 | masked_fill(-1e9) | tf.where + large negative |
| 梯度计算 | 默认开启 | 需要明确控制梯度流 |
如《Longformer》提出的混合注意力模式:
《Memformer》等工作通过压缩记忆来扩展上下文长度:
在多模态任务中,Self-Attention可以自然扩展为交叉注意力:
我在实际项目中发现,理解Self-Attention的核心机制后,可以根据具体任务需求灵活调整注意力计算方式。比如在处理长文档时,采用分块稀疏注意力可以显著提升效率;而在需要精细对齐的任务(如机器翻译)中,标准的全局注意力仍然表现最佳。