注意力机制(Attention Mechanism)最初源于人类视觉系统的启发。当我们观察一幅画时,并不会均匀处理所有视觉信息,而是会聚焦于特定区域。这种生物特性被抽象为计算模型中的权重分配机制。
核心数学表达可以简化为:
code复制Attention(Q,K,V) = softmax(QK^T/√d_k)V
其中Q(Query)代表当前需要计算的特征,K(Key)是待比较的特征集合,V(Value)是实际的特征值。这个公式实现了三个关键功能:
实际实现时需要注意:当d_k(特征维度)较大时,点积结果会变得极大,导致softmax梯度消失,因此需要√d_k进行缩放。
标准的自注意力实现包含以下步骤:
python复制class SelfAttention(nn.Module):
def __init__(self, embed_size):
super().__init__()
self.embed_size = embed_size
# 初始化Q,K,V的线性变换矩阵
self.values = nn.Linear(embed_size, embed_size, bias=False)
self.keys = nn.Linear(embed_size, embed_size, bias=False)
self.queries = nn.Linear(embed_size, embed_size, bias=False)
def forward(self, x):
# 获取batch大小
N = x.shape[0]
# 生成Q,K,V
Q = self.queries(x) # (N, seq_len, embed_size)
K = self.keys(x) # (N, seq_len, embed_size)
V = self.values(x) # (N, seq_len, embed_size)
# 计算注意力分数
energy = torch.matmul(Q, K.permute(0,2,1)) # (N, seq_len, seq_len)
energy = energy / (self.embed_size ** 0.5)
# 应用softmax
attention = torch.softmax(energy, dim=2)
# 加权求和
out = torch.matmul(attention, V)
return out
多头注意力的关键实现细节:
python复制class MultiHeadAttention(nn.Module):
def __init__(self, embed_size=512, heads=8):
super().__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
assert self.head_dim * heads == embed_size, "Embed size needs to be divisible by heads"
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.fc_out = nn.Linear(heads*self.head_dim, embed_size)
def forward(self, x):
N = x.shape[0]
seq_len = x.shape[1]
# 拆分输入到多个头
x = x.reshape(N, seq_len, self.heads, self.head_dim)
Q = self.queries(x)
K = self.keys(x)
V = self.values(x)
energy = torch.einsum("nqhd,nkhd->nhqk", [Q, K])
energy = energy / (self.embed_size ** 0.5)
attention = torch.softmax(energy, dim=3)
out = torch.einsum("nhql,nlhd->nqhd", [attention, V])
out = out.reshape(N, seq_len, self.heads*self.head_dim)
out = self.fc_out(out)
return out
使用einsum而非matmul可以更清晰地表达高维张量运算。实际部署时,当序列长度超过512时,建议采用内存优化的注意力实现。
| 类型 | 计算复杂度 | 适用场景 | 主要特点 |
|---|---|---|---|
| 全连接注意力 | O(n²) | 短文本处理 | 标准实现,计算所有位置关系 |
| 局部注意力 | O(n*w) | 长序列处理 | 只计算窗口w内的位置关系 |
| 稀疏注意力 | O(n√n) | 超长序列 | 按规则跳过部分位置计算 |
| 轴向注意力 | O(n) | 图像处理 | 分别处理高度和宽度维度 |
| 线性注意力 | O(n) | 实时系统 | 用核函数近似softmax |
处理长序列时的实用技巧:
python复制from torch.utils.checkpoint import checkpoint
output = checkpoint(self.attention, x)
python复制scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
output = model(x)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
python复制def plot_attention(attention_weights, src_text, tgt_text):
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(111)
cax = ax.matshow(attention_weights, cmap='bone')
fig.colorbar(cax)
ax.set_xticklabels([''] + src_text, rotation=90)
ax.set_yticklabels([''] + tgt_text)
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
plt.show()
python复制# 原始实现
attention = torch.softmax(Q @ K.T / sqrt(d_k), dim=-1) @ V
# 优化实现(减少中间变量)
attention = torch.matmul(
torch.softmax(
torch.matmul(Q, K.transpose(-2,-1)) / math.sqrt(d_k),
dim=-1
),
V
)
使用TVM进行自定义算子融合:
python复制@tvm.register_func("attention_fused")
def attention_fused(q, k, v):
# 融合softmax和矩阵乘
...
python复制model = torch.quantization.quantize_dynamic(
model, {nn.Linear}, dtype=torch.qint8
)
python复制model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)
可能原因:
解决方案:
python复制# 修改初始化
nn.init.xavier_uniform_(self.query.weight)
nn.init.xavier_uniform_(self.key.weight)
# 添加层归一化
self.layer_norm = nn.LayerNorm(embed_size)
优化方案对比:
| 方法 | 速度提升 | 精度损失 | 实现难度 |
|---|---|---|---|
| 局部注意力 | 3-5x | <1% | 低 |
| 稀疏注意力 | 5-8x | 1-3% | 中 |
| 线性注意力 | 10x+ | 3-5% | 高 |
实际测试表明,对于512-1024长度的序列,采用局部注意力配合梯度检查点是最佳平衡方案。