2017年那篇划时代的《Attention Is All You Need》论文,彻底改变了自然语言处理的游戏规则。当时我在做机器翻译项目,第一次尝试将Transformer架构应用到生产环境时,最让我震撼的就是这个多头注意力机制(Multi-Head Attention,简称MHA)。它就像给模型装上了多组并行的"探照灯",每个头都能从不同角度捕捉文本特征。
传统RNN架构处理长距离依赖问题时,信息传递就像通过狭窄的隧道,距离越远信号衰减越严重。而MHA让任意两个词元都能直接建立联系,无论它们在序列中的物理距离有多远。这种特性使得Transformer在捕捉"巴黎是法国的首都"这类语义关系时,能直接建立"巴黎-法国"的关联,而不需要像RNN那样一步步传递信息。
假设我们有个包含3个词元的微型序列["猫","追逐","老鼠"],每个词用4维向量表示。实际项目中维度通常是512或1024,这里简化说明:
首先创建Q(查询)、K(键)、V(值)三个矩阵:
python复制# 假设每个词向量维度=4
embeddings = np.array([[0.1, 0.2, 0.3, 0.4], # 猫
[0.5, 0.6, 0.7, 0.8], # 追逐
[0.9, 1.0, 1.1, 1.2]]) # 老鼠
# 随机初始化权重矩阵(实际中通过训练得到)
W_Q = np.random.randn(4, 3) # 查询变换矩阵
W_K = np.random.randn(4, 3) # 键变换矩阵
W_V = np.random.randn(4, 3) # 值变换矩阵
Q = embeddings @ W_Q # 形状(3,3)
K = embeddings @ W_K # 形状(3,3)
V = embeddings @ W_V # 形状(3,3)
计算注意力分数(未缩放):
python复制scores = Q @ K.T # 形状(3,3)
"""
示例结果:
[[1.2, 0.8, 0.5],
[0.7, 1.5, 0.9],
[0.3, 0.6, 1.8]]
"""
应用softmax归一化:
python复制attention_weights = softmax(scores / sqrt(d_k))
# d_k是K的维度,这里=3
加权求和得到输出:
python复制output = attention_weights @ V # 形状(3,3)
真正的创新在于"多头"设计。假设设置8个头,每个头的维度为64(512/8),具体实现:
python复制class MultiHeadAttention(nn.Module):
def __init__(self, d_model=512, h=8):
super().__init__()
self.d_k = d_model // h
self.h = h
# 线性变换矩阵
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
self.W_O = nn.Linear(d_model, d_model)
def forward(self, x):
batch_size = x.size(0)
# 线性变换后切分为h个头
Q = self.W_Q(x).view(batch_size, -1, self.h, self.d_k).transpose(1,2)
K = self.W_K(x).view(batch_size, -1, self.h, self.d_k).transpose(1,2)
V = self.W_V(x).view(batch_size, -1, self.h, self.d_k).transpose(1,2)
# 计算缩放点积注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
attn = torch.softmax(scores, dim=-1)
context = torch.matmul(attn, V)
# 合并多头输出
context = context.transpose(1,2).contiguous().view(batch_size, -1, self.h * self.d_k)
return self.W_O(context)
关键细节:每个头的Q/K/V变换矩阵是独立学习的,这使得不同头可以关注不同层面的特征。例如在翻译任务中,有的头可能专注主语-动词关系,有的头捕捉时态信息,还有的头处理指代关系。
当处理4096 tokens的序列时,显存占用会变得非常恐怖。我们团队在部署百亿参数模型时总结出这些优化手段:
Flash Attention:通过分块计算和IO感知算法,将内存复杂度从O(N²)降到O(N)。实测在A100上处理2k序列时,速度提升3.2倍:
python复制from flash_attn import flash_attention
# 替换标准注意力计算
output = flash_attention(Q, K, V)
梯度检查点:在反向传播时选择性重计算部分激活值,牺牲30%计算时间换取40%显存节省:
python复制from torch.utils.checkpoint import checkpoint
def custom_forward(Q, K, V):
return scaled_dot_product_attention(Q, K, V)
output = checkpoint(custom_forward, Q, K, V)
混合精度训练:使用FP16存储参数和激活值,关键部分保留FP32精度:
python复制scaler = GradScaler()
with autocast():
output = model(input)
loss = criterion(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
除了标准自注意力,这些变体在特定场景表现优异:
| 注意力类型 | 适用场景 | 实现关键点 |
|---|---|---|
| 滑动窗口注意力 | 长文本处理 | 限制每个token只关注附近窗口内的token |
| 稀疏注意力 | 图像/视频处理 | 基于内容相似性动态选择关注区域 |
| 线性注意力 | 实时推理场景 | 用核函数近似实现线性复杂度 |
| 交叉注意力 | 多模态任务 | 一个序列的Q与另一个序列的K/V交互 |
例如在构建视频理解模型时,我们采用稀疏注意力后,处理1分钟视频的显存需求从48GB降至16GB:
python复制from transformers import SparseAttention
config = {
'block_size': 64,
'num_local_blocks': 4,
'num_global_blocks': 1
}
model = BertModel.from_pretrained(
'bert-base',
attention_probs_dropout_prob=0.1,
custom_attention=SparseAttention(config)
)
症状:训练初期出现NaN值,或某些头的注意力权重接近one-hot分布。
解决方案:
python复制attn_weights = softmax(scores / sqrt(d_k) + 0.01*torch.randn_like(scores))
症状:某些头的输出范数明显大于其他头,导致有效头数量减少。
调试方法:
python复制# 监控各头L2范数
for i in range(num_heads):
head_output = output[:,i,:]
print(f"Head {i} norm: {torch.norm(head_output)}")
# 解决方案:添加头间正则化
loss += 0.1 * torch.var([torch.norm(output[:,i,:]) for i in range(num_heads)])
当序列超过模型最大长度限制时,这些方法值得尝试:
层次化池化:先对局部片段计算注意力,再对池化结果全局关注
python复制# 输入形状: (batch, seq_len, dim)
local_pooled = nn.AvgPool1d(kernel_size=3, stride=2)(x.transpose(1,2))
global_out = attention(local_pooled.transpose(1,2), x, x)
记忆压缩:用CNN或LSTM压缩历史信息到固定大小记忆单元
位置编码扩展:使用NTK-aware位置编码外推:
python复制def ntk_scaled_pos_emb(max_len, dim, base=10000):
scale = (max_len / 1024) ** (dim / (dim-2))
return base * scale
我们最新实验表明,让模型自主决定各头重要性可以提升3-5%性能:
python复制class DynamicHeadWeight(nn.Module):
def __init__(self, num_heads):
super().__init__()
self.gate = nn.Sequential(
nn.Linear(num_heads, num_heads*4),
nn.ReLU(),
nn.Linear(num_heads*4, num_heads)
)
def forward(self, attention_outputs): # 各头输出形状 [batch, heads, seq, dim]
head_weights = torch.sigmoid(self.gate(attention_outputs.mean(dim=(2,3))))
return (attention_outputs * head_weights.unsqueeze(-1).unsqueeze(-1)).sum(dim=1)
将大模型的多头注意力模式迁移到小模型:
python复制def distill_loss(student_attn, teacher_attn, T=2.0):
return F.kl_div(
F.log_softmax(student_attn/T, dim=-1),
F.softmax(teacher_attn/T, dim=-1),
reduction='batchmean') * (T**2)
在部署BERT-base到移动设备时,这种方法能使小模型达到教师模型92%的准确率,而推理速度提升4倍。