注意力机制最初来源于人类视觉系统的工作方式——我们不会同时处理视野中的所有信息,而是有选择地聚焦于关键区域。2014年,Bahdanau等人首次将这种思想引入机器翻译领域,通过动态计算源语言句子中各词对当前翻译目标词的重要性权重,显著提升了长句翻译效果。
传统RNN的序列建模存在两个根本缺陷:一是必须严格按时间步顺序计算,无法并行;二是长距离依赖容易丢失。注意力机制通过建立任意位置间的直接连接完美解决了这些问题。假设输入序列为X=(x₁,...,xₙ),计算目标位置i的表示时,注意力机制会:
关键理解:注意力权重的计算过程实际上构建了一个动态的内容寻址系统,类似于字典查询机制。查询向量q相当于检索关键词,键向量k相当于索引项,最终的输出是值向量v的加权组合。
自注意力是注意力机制的特例,其查询、键、值均来自同一输入序列。以Transformer中的实现为例,具体计算步骤如下:
python复制def scaled_dot_product_attention(Q, K, V, mask=None):
d_k = K.size(-1)
scores = Q @ K.transpose(-2,-1) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask==0, -1e9)
attn_weights = F.softmax(scores, dim=-1)
return attn_weights @ V
点积结果的方差会随着维度dₖ增大而增长。假设q和k的分量是独立随机变量,均值为0,方差为1,则qᵀk的方差就是dₖ。缩放因子1/√dₖ确保softmax输入保持适度范围,避免梯度消失。
单头注意力的问题在于:
多头注意力的解决方案是:
数学表达:
[
\text{MultiHead}(Q,K,V) = \text{Concat}(head_1,...,head_h)W^O
]
[
\text{where } head_i = \text{Attention}(QW_i^Q,KW_i^K,VW_i^V)
]
典型配置:
PyTorch实现示例:
python复制class MultiHeadAttention(nn.Module):
def __init__(self, d_model=512, h=8):
super().__init__()
assert d_model % h == 0
self.d_k = d_model // h
self.h = h
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x, mask=None):
batch_size = x.size(0)
# 线性变换并分头 [B,L,d_model] -> [B,L,h,d_k]
Q = self.W_q(x).view(batch_size,-1,self.h,self.d_k).transpose(1,2)
K = self.W_k(x).view(batch_size,-1,self.h,self.d_k).transpose(1,2)
V = self.W_v(x).view(batch_size,-1,self.h,self.d_k).transpose(1,2)
# 计算注意力 [B,h,L,d_k]
attn_output = scaled_dot_product_attention(Q, K, V, mask)
# 拼接并输出 [B,L,d_model]
output = attn_output.transpose(1,2).contiguous()\
.view(batch_size,-1,self.h*self.d_k)
return self.W_o(output)
| 类型 | 计算公式 | 特点 | 适用场景 |
|---|---|---|---|
| 加法注意力 | score=vᵀtanh(W₁q+W₂k) | 计算成本高但更灵活 | 早期机器翻译 |
| 点积注意力 | score=qᵀk | 计算高效但需缩放 | Transformer默认 |
| 相对位置注意力 | score=qᵀk + qᵀr_ | 显式编码相对位置 | 音乐生成等序列任务 |
| 稀疏注意力 | 只计算局部或特定位置的得分 | 降低O(n²)复杂度 | 超长序列处理 |
内存优化:
计算加速:
python复制# 使用FlashAttention (Dao et al., 2022)
from flash_attn import flash_attention
output = flash_attention(q, k, v, causal=True)
稳定训练:
注意力权重过于均匀:
某些头完全不活跃:
长序列效果差:
ViT将图像分块为16×16的patch序列,通过多头注意力实现全局建模。关键改进:
使用自注意力建模氨基酸残基间的相互作用:
结合CNN与自注意力的混合架构:
高效注意力机制:
注意力可解释性:
与其他机制的融合:
实践建议:当首次实现自注意力时,建议可视化注意力权重矩阵,观察模型实际学习到的关注模式。例如在机器翻译中,理想的对角线模式表示对齐关系,而分散的注意力可能捕捉到语法结构。