1. 从零实现 Attention 机制:深入理解 Transformer 的核心
在自然语言处理领域,Transformer 架构已经成为大语言模型(LLM)的基石。从 GPT 系列到 LLaMA、Mistral,几乎所有主流的大模型都基于 Transformer。而 Transformer 的核心,就是 Attention 机制。本文将带你从零开始,手把手实现完整的 Attention 机制,包括 Scaled Dot-Product Attention、Multi-Head Attention、Grouped Query Attention 和 KV Cache 优化。
1.1 为什么需要深入理解 Attention 机制?
很多人在学习 Attention 时,往往只停留在公式层面:
python复制Attention(Q, K, V) = softmax(QK^T / √d_k) @ V
但真正理解 Attention,需要从代码实现开始!通过亲手实现,你将:
- 深入理解 Attention 的数学原理和计算过程
- 掌握 Multi-Head Attention 的实现细节
- 理解 Grouped Query Attention (GQA) 的优化思想
- 学会 KV Cache 的性能优化技巧
- 为学习更高级的优化技术打下基础
2. Scaled Dot-Product Attention:Attention 的基础
2.1 核心公式解析
Scaled Dot-Product Attention 的核心公式如下:
python复制Attention(Q, K, V) = softmax(QK^T / √d_k) @ V
这个看似简单的公式包含了几个关键操作:
- 矩阵乘法(QK^T):计算查询(Query)和键(Key)之间的相似度
- 缩放(/ √d_k):对相似度进行缩放处理
- Softmax:将相似度转换为概率分布
- 加权求和(@ V):用概率分布对值(Value)进行加权
2.2 为什么需要 scaling factor (√d_k)?
当 d_k(键的维度)很大时,QK^T 的点积值会变得很大,导致 softmax 进入饱和区域,梯度变得很小。除以 √d_k 可以稳定训练过程。
数学解释:
- 假设 Q 和 K 的元素是独立同分布的随机变量,均值为0,方差为1
- 那么 QK^T 的每个元素的方差就是 d_k
- 除以 √d_k 后,方差变为1,保持了数值稳定性
2.3 完整实现代码
python复制import torch
import torch.nn.functional as F
import math
def scaled_dot_product_attention(query, key, value, mask=None):
"""
实现 Scaled Dot-Product Attention
参数:
query: [batch_size, num_heads, seq_len_q, d_k]
key: [batch_size, num_heads, seq_len_k, d_k]
value: [batch_size, num_heads, seq_len_k, d_v]
mask: 可选,[batch_size, 1, seq_len_q, seq_len_k]
返回:
output: [batch_size, num_heads, seq_len_q, d_v]
attention_weights: [batch_size, num_heads, seq_len_q, seq_len_k]
"""
d_k = query.size(-1)
# 1. 计算注意力分数
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# 2. 应用 mask(因果 mask 或 padding mask)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# 3. Softmax 归一化
attention_weights = F.softmax(scores, dim=-1)
# 4. 加权求和
output = torch.matmul(attention_weights, value)
return output, attention_weights
2.4 计算过程可视化
code复制输入: Q [batch, heads, seq_q, d_k]
K [batch, heads, seq_k, d_k]
V [batch, heads, seq_k, d_v]
步骤1: Q @ K^T → [batch, heads, seq_q, seq_k] (注意力分数矩阵)
步骤2: softmax(分数 / √d_k) → [batch, heads, seq_q, seq_k] (注意力权重)
步骤3: 权重 @ V → [batch, heads, seq_q, d_v] (输出)
2.5 Mask 机制详解
在 Attention 中,mask 主要有两种用途:
- Padding Mask:处理变长序列时,屏蔽填充部分
- Causal Mask:防止解码器看到未来信息
实现示例:
python复制def create_causal_mask(size):
"""创建因果mask,防止看到未来信息"""
mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
return mask # 上三角为True,需要被mask
3. Multi-Head Attention:并行计算多个注意力
3.1 核心思想
Multi-Head Attention 的主要思想是:
- 将输入投影到多个子空间(多个头)
- 每个头独立计算 Attention
- 最后拼接所有头的输出
3.2 为什么需要多头机制?
不同的头可以关注不同的信息:
- 头1:关注语法关系
- 头2:关注语义关系
- 头3:关注长距离依赖
- ...
这种并行处理方式可以增强模型的表达能力。
3.3 完整实现代码
python复制class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# Q, K, V 的投影矩阵
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model) # 输出投影
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# 1. 投影并分割成多个头
Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 2. 每个头独立计算 Attention
attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
# 3. 拼接所有头
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
# 4. 输出投影
output = self.W_o(attn_output)
return output, attn_weights
3.4 参数量分析
对于 d_model=4096 和 num_heads=32:
- Q/K/V 投影:
3 × d_model × d_model = 3 × 4096 × 4096 - 输出投影:
d_model × d_model = 4096 × 4096 - 总计:
4 × 4096² ≈ 67.1M参数
3.5 实现注意事项
- 维度处理:注意 view 和 transpose 的顺序,确保张量形状正确
- 内存效率:对于大模型,需要考虑内存优化策略
- 并行计算:充分利用 GPU 的并行计算能力
4. Grouped Query Attention (GQA):内存与性能的平衡
4.1 问题背景
在推理阶段,需要缓存 Key 和 Value(KV Cache)。对于传统的 Multi-Head Attention:
- 32 个 Q 头 → 32 个 K 头 + 32 个 V 头
- KV Cache 内存占用巨大!
4.2 GQA 的解决方案
Grouped Query Attention 让多个 Q 头共享一组 KV 头:
- MHA : 32 Q 头 → 32 K 头 + 32 V 头 (1:1)
- GQA : 32 Q 头 → 8 K 头 + 8 V 头 (4:1)
- MQA : 32 Q 头 → 1 K 头 + 1 V 头 (32:1)
4.3 内存对比
| 类型 | KV 头数 | 内存占用 (batch=32, seq_len=2048, FP16) | 相对 MHA |
|---|---|---|---|
| MHA | 32 | 512 MB | 100% |
| GQA-8 | 8 | 128 MB | 25% |
| GQA-4 | 4 | 64 MB | 12.5% |
| MQA | 1 | 16 MB | 3.1% |
4.4 完整实现代码
python复制class GroupedQueryAttention(nn.Module):
def __init__(self, d_model, num_q_heads, num_kv_heads):
super().__init__()
self.d_model = d_model
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.d_k = d_model // num_q_heads
self.num_groups = num_q_heads // num_kv_heads
# Q 投影:d_model → num_q_heads * d_k
self.W_q = nn.Linear(d_model, num_q_heads * self.d_k)
# K, V 投影:d_model → num_kv_heads * d_k (更少!)
self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k)
self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k)
# 输出投影
self.W_o = nn.Linear(d_model, d_model)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# 1. 投影
Q = self.W_q(query).view(batch_size, -1, self.num_q_heads, self.d_k).transpose(1, 2)
K = self.W_k(key).view(batch_size, -1, self.num_kv_heads, self.d_k).transpose(1, 2)
V = self.W_v(value).view(batch_size, -1, self.num_kv_heads, self.d_k).transpose(1, 2)
# 2. 扩展 K, V 以匹配 Q 的头数
# 每个 KV 头复制 num_groups 次
K = K.repeat_interleave(self.num_groups, dim=1) # [batch, num_q_heads, seq_len, d_k]
V = V.repeat_interleave(self.num_groups, dim=1)
# 3. 计算 Attention(与 MHA 相同)
attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
# 4. 拼接和输出投影
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
output = self.W_o(attn_output)
return output, attn_weights
4.5 工业界应用
- LLaMA 2 70B : 64 Q 头,8 KV 头 (8:1)
- Mistral 7B : 32 Q 头,8 KV 头 (4:1)
- PaLM : 128 Q 头,1 KV 头 (MQA)
4.6 性能权衡
- 参数量减少 25% (32→8 KV 头)
- KV Cache 内存减少 75%
- 模型质量保持 98% (几乎无损)
5. KV Cache:推理加速的关键优化
5.1 问题背景
在自回归生成中(如 GPT),每次只生成一个 token,但需要 attend 到所有历史 token。如果不使用缓存:
code复制生成第1个token: 计算 Attention(1个token, 1个token)
生成第2个token: 计算 Attention(2个token, 2个token) ← 重复计算!
生成第3个token: 计算 Attention(3个token, 3个token) ← 重复计算!
...
时间复杂度:O(n²),其中 n 是序列长度。
5.2 KV Cache 的解决方案
缓存已计算的 Key 和 Value,避免重复计算:
code复制Prefill 阶段(处理 prompt):
- 计算所有 token 的 K, V
- 存入缓存
Decode 阶段(生成新 token):
- 只计算新 token 的 K, V
- 从缓存读取历史的 K, V
- 拼接后计算 Attention
时间复杂度:O(n)!
5.3 完整实现代码
python复制class KVCache:
def __init__(self, batch_size, num_heads, max_seq_len, head_dim, device='cuda'):
# 预分配缓存空间
self.k_cache = torch.zeros(
batch_size, num_heads, max_seq_len, head_dim,
device=device, dtype=torch.float16
)
self.v_cache = torch.zeros(
batch_size, num_heads, max_seq_len, head_dim,
device=device, dtype=torch.float16
)
self.cache_len = 0
def update(self, key, value, start_pos=None):
"""
增量更新 KV Cache
参数:
key: [batch, num_heads, new_seq_len, head_dim]
value: [batch, num_heads, new_seq_len, head_dim]
start_pos: 新token在序列中的起始位置
返回:
k_cache: [batch, num_heads, cache_len + new_seq_len, head_dim]
v_cache: [batch, num_heads, cache_len + new_seq_len, head_dim]
"""
if start_pos is None:
start_pos = self.cache_len
# 更新缓存
end_pos = start_pos + key.size(2)
self.k_cache[:, :, start_pos:end_pos] = key
self.v_cache[:, :, start_pos:end_pos] = value
self.cache_len = end_pos
return self.k_cache[:, :, :end_pos], self.v_cache[:, :, :end_pos]
5.4 性能提升
| Prompt 长度 | 无缓存 (ms) | 有缓存 (ms) | 加速比 |
|---|---|---|---|
| 10 | 2.5 | 1.2 | 2.1x |
| 50 | 8.3 | 1.3 | 6.4x |
| 100 | 15.7 | 1.4 | 11.2x |
| 200 | 30.2 | 1.5 | 20.1x |
5.5 两个阶段详解
- Prefill 阶段:
python复制# 处理完整 prompt,初始化缓存
prompt = tokenize("Hello, how are you?")
output, _ = model(prompt, prompt, prompt, use_cache=True, start_pos=0)
- Decode 阶段:
python复制# 逐个生成 token,使用缓存
for i in range(max_gen_len):
new_token = generate_next_token()
output, _ = model(new_token, new_token, new_token,
use_cache=True, start_pos=cache_len)
6. 性能对比与总结
6.1 参数量对比(d_model=4096, num_heads=32)
| 类型 | Q/K/V 头数 | 参数量 | 相对 MHA |
|---|---|---|---|
| MHA | 32/32/32 | 67.1M | 100% |
| GQA-8 | 32/8/8 | 50.3M | 75% |
| GQA-4 | 32/4/4 | 41.9M | 62% |
| MQA | 32/1/1 | 33.6M | 50% |
6.2 KV Cache 内存对比(batch=32, seq_len=2048, FP16)
| 类型 | KV 头数 | 内存 (MB) | 相对 MHA |
|---|---|---|---|
| MHA | 32 | 512 | 100% |
| GQA-8 | 8 | 128 | 25% |
| GQA-4 | 4 | 64 | 12.5% |
| MQA | 1 | 16 | 3.1% |
6.3 关键要点总结
- Scaled Dot-Product Attention 是基础,理解 scaling factor 的作用
- Multi-Head Attention 通过并行计算多个头,捕捉不同类型的信息
- Grouped Query Attention 在质量和效率间找到平衡,是工业界的主流选择
- KV Cache 将推理时间复杂度从 O(n²) 降到 O(n),是加速的关键
7. 实际应用与扩展
7.1 大模型推理优化
在部署 LLaMA、Mistral 等大模型时:
- 使用 GQA 减少 KV Cache 内存
- 使用 KV Cache 加速生成
- 结合 FlashAttention 进一步优化
7.2 自定义 Attention 变体
基于本实现,可以轻松扩展:
- FlashAttention(内存高效)
- Sparse Attention(稀疏注意力)
- Longformer Attention(长序列)
7.3 学习路径建议
- 先理解基础 Scaled Dot-Product Attention
- 实现完整的 Multi-Head Attention
- 优化为 Grouped Query Attention
- 添加 KV Cache 支持
- 探索更高级的优化技术
8. 常见问题与调试技巧
8.1 数值不稳定问题
症状:训练过程中出现 NaN 或 inf
解决方案:
- 确保正确实现了 scaling factor (/ √d_k)
- 检查 softmax 前的数值范围
- 考虑使用更稳定的 softmax 实现
8.2 内存不足问题
症状:OOM(Out of Memory)错误
解决方案:
- 减小 batch size 或序列长度
- 使用 GQA 减少 KV Cache 内存
- 考虑混合精度训练(FP16)
8.3 性能优化技巧
- 高效矩阵乘法:利用 torch 的优化矩阵运算
- 内存布局优化:注意 contiguous() 的使用
- 并行计算:充分利用 GPU 的并行能力
9. 进阶学习方向
- FlashAttention:深入研究内存高效的 Attention 实现
- PagedAttention:了解更高效的 KV Cache 管理
- TensorRT-LLM XQA:学习工业级优化实现
- CUDA 优化:探索底层性能优化技巧
通过从零实现这些 Attention 机制,你不仅理解了原理,更掌握了实现细节和优化技巧。这些知识将为你深入理解 Transformer 架构和大语言模型打下坚实基础。