在大型语言模型的实际应用中,理解推理过程的核心机制至关重要。今天我将结合工程实践经验,深入剖析自回归推理中的两个关键阶段:Prefill(预填充)和Decode(解码)。这两个阶段不仅仅是概念上的划分,它们直接影响着模型的推理效率、资源消耗和实际部署策略。
当我们使用语言模型生成文本时,整个过程可以分为两个本质不同的计算阶段。这种划分源于KV缓存(Key-Value缓存)的构建和使用方式:
这种划分不是人为的,而是由自回归生成的基本特性决定的。在prefill阶段,我们可以并行处理所有输入token,因为它们的计算互不依赖;而在decode阶段,每个新token的生成都依赖于前一个token的输出,形成了严格的顺序依赖。
假设用户输入提示:"解释量子计算的简单概念"(假设被token化为500个token)。Prefill阶段会:
python复制def prefill(prompt_tokens):
"""处理整个提示的单一前向传播"""
N = len(prompt_tokens)
hidden_states = embed(prompt_tokens) # 形状: [N, hidden_dim]
kv_cache = {}
for layer_idx, layer in enumerate(transformer_layers):
# 为所有N个token并行计算Q,K,V
Q = hidden_states @ W_Q # 形状: [N, num_heads, head_dim]
K = hidden_states @ W_K # 形状: [N, num_heads, head_dim]
V = hidden_states @ W_V # 形状: [N, num_heads, head_dim]
# 将K,V存入缓存
kv_cache[layer_idx] = {'K': K, 'V': V}
# 计算[N, N]注意力矩阵(带因果掩码)
attention_scores = Q @ K.transpose(-1, -2) / sqrt(d)
attention_scores = apply_causal_mask(attention_scores)
attention_weights = softmax(attention_scores)
attention_output = attention_weights @ V
hidden_states = layer.ffn(layer.norm(attention_output + hidden_states))
# 只获取最后一个位置的logits(用于生成第一个token)
next_token_logits = hidden_states[-1] @ W_output
return next_token_logits, kv_cache
提示:在实际工程实现中,prefill阶段通常会利用GPU的并行计算能力,将大批量矩阵运算合并执行,这是其高效的主要原因。
以输入"The cat sat on the mat"(6个token)为例:
code复制PREFILL阶段
输入: "The cat sat on the mat" (6 tokens)
┌─────────────────────────────────────────────────────────────────┐
│ │
│ Token embeddings (并行处理) │
│ ┌─────┬─────┬─────┬─────┬─────┬─────┐ │
│ │ The │ cat │ sat │ on │ the │ mat │ │
│ └──┬──┴──┬──┴──┬──┴──┬──┴──┬──┴──┬──┘ │
│ │ │ │ │ │ │ │
│ ▼ ▼ ▼ ▼ ▼ ▼ │
│ ┌─────────────────────────────────────┐ │
│ │ Transformer Layers (×32) │ │
│ │ │ │
│ │ 对每一层: │ │
│ │ • 为所有6个token计算Q,K,V │ │
│ │ • 将K,V存入缓存 │ │
│ │ • 计算[6×6]注意力矩阵 │ │
│ │ • 应用FFN │ │
│ └─────────────────────────────────────┘ │
│ │
│ ▼ ▼ ▼ ▼ ▼ ▼ │
│ ┌─────┬─────┬─────┬─────┬─────┬─────┐ │
│ │ h₀ │ h₁ │ h₂ │ h₃ │ h₄ │ h₅ │ 最终隐藏状态 │
│ └─────┴─────┴─────┴─────┴──┬──┴─────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Logits (h₅→词表)│ │
│ │ 采样: "." │ ← 第一个生成token │
│ └─────────────────┘ │
│ │
│ KV缓存现在包含: │
│ ┌────────────────────────────────────────┐ │
│ │ Layer 0: K₀,K₁,K₂,K₃,K₄,K₅ │ V₀...V₅ │ │
│ │ Layer 1: K₀,K₁,K₂,K₃,K₄,K₅ │ V₀...V₅ │ │
│ │ ... │ │
│ │ Layer 31: K₀...K₅ │ V₀...V₅ │ │
│ └────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
Prefill完成后,我们获得第一个生成token和初始化的KV缓存,随后进入循环:
python复制def decode_one_token(new_token, kv_cache):
"""处理单个新token,使用并扩展KV缓存"""
hidden_states = embed([new_token]) # 形状: [1, hidden_dim]
for layer_idx, layer in enumerate(transformer_layers):
# 仅为新token计算Q,K,V
Q_new = hidden_states @ W_Q # 形状: [1, num_heads, head_dim]
K_new = hidden_states @ W_K # 形状: [1, num_heads, head_dim]
V_new = hidden_states @ W_V # 形状: [1, num_heads, head_dim]
# 读取缓存的K和V
K_cached = kv_cache[layer_idx]['K'] # 形状: [seq_len, num_heads, head_dim]
V_cached = kv_cache[layer_idx]['V'] # 形状: [seq_len, num_heads, head_dim]
# 将新K,V追加到缓存
K_full = concat([K_cached, K_new], dim=0)
V_full = concat([V_cached, V_new], dim=0)
kv_cache[layer_idx] = {'K': K_full, 'V': V_full}
# 注意力: Q_new关注所有key(完整序列)
attention_scores = Q_new @ K_full.transpose(-1, -2) / sqrt(d) # [1, seq_len+1]
attention_weights = softmax(attention_scores)
attention_output = attention_weights @ V_full
hidden_states = layer.ffn(layer.norm(attention_output + hidden_states))
next_token_logits = hidden_states[0] @ W_output
return next_token_logits, kv_cache
假设KV缓存已包含位置0-5的K,V(来自prefill),现在处理新token"."(位置6):
code复制DECODE阶段(单一步骤)
KV缓存状态: 包含位置0-5的K,V
新token处理: "." (位置6)
┌─────────────────────────────────────────────────────────────────┐
│ │
│ 输入: 单个token "." │
│ ┌─────┐ │
│ │ . │ │
│ └──┬──┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ Transformer Layers (×32) │ │
│ │ │ │
│ │ 对每一层: │ │
│ │ ┌─────────────────────────────────────────────────┐ │ │
│ │ │ 1. 计算Q₆,K₆,V₆(仅对新token) │ │ │
│ │ │ │ │ │
│ │ │ 2. 从缓存读取: K₀...K₅, V₀...V₅ │ │ │
│ │ │ ┌─────────────────────────────┐ │ │ │
│ │ │ │ K_cache: [K₀,K₁,K₂,K₃,K₄,K₅] │ │ │ │
│ │ │ │ V_cache: [V₀,V₁,V₂,V₃,V₄,V₅] │ │ │ │
│ │ │ └─────────────────────────────┘ │ │ │
│ │ │ │ │ │
│ │ │ 3. 注意力: Q₆ @ [K₀...K₆]ᵀ → [1×7] scores │ │ │
│ │ │ │ │ │
│ │ │ 4. 追加到缓存: K₆, V₆ │ │ │
│ │ │ ┌────────────────────────────────┐ │ │ │
│ │ │ │ K_cache: [K₀,K₁,K₂,K₃,K₄,K₅,K₆] │ │ │ │
│ │ │ │ V_cache: [V₀,V₁,V₂,V₃,V₄,V₅,V₆] │ │ │ │
│ │ │ └────────────────────────────────┘ │ │ │
│ │ └─────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
│ ▼ │
│ ┌─────┐ │
│ │ h₆ │ 位置6的隐藏状态 │
│ └──┬──┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Logits (h₆→词表)│ │
│ │ 采样: "The" │ ← 下一个生成token │
│ └─────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
| 特性 | Prefill | Decode |
|---|---|---|
| 每次前向传播处理的token | 所有提示token(N) | 每个步骤一个token |
| 计算的Q向量数量 | N个向量 | 1个向量 |
| 计算的K,V向量数量 | 各N个向量 | 各1个向量 |
| 注意力矩阵形状 | [N, N] | [1, seq_len] |
| KV缓存操作 | 写入(初始化) | 读取+追加 |
| 并行度 | 高(所有token一起处理) | 低(顺序依赖) |
| 运行次数 | 每个请求一次 | 每个输出token一次 |
| 跨token并行化 | 可以(在传播内) | 不可以(token i需要token i-1) |
Prefill和Decode阶段最本质的区别在于并行性:
Prefill:高度并行
Decode:本质顺序
这种并行性差异是prefill和decode具有完全不同性能特征的根本原因。
假设一个典型请求:500 token的提示,生成200个token
code复制时间 ──────────────────────────────────────────────────────────────────────►
│◄─── Prefill ───►│◄──────────────── Decode ─────────────────────────────►│
│ │ │
│ 处理500个token │ 生成 生成 生成 生成 ... 生成 生成 │
│ 单次前向传播 │ token1 token2 token3 token4 token199 token200 │
│ │ ◄──► ◄──► ◄──► ◄──► ◄──► ◄──► │
│ │ 每个decode步骤都是独立的前向传播 │
│ │ │
│ ~50ms │ ~2000ms │
│ (示例) │ (示例: 每个token 10ms × 200 tokens) │
│ │ │
总时间分解:
├─ Prefill: ~50ms (总时间2.4%)
├─ Decode: ~2000ms (总时间97.6%)
└─ 总计: ~2050ms
尽管prefill处理了500个token而decode只处理了200个,decode却耗时约40倍,因为它需要200次顺序前向传播,而prefill只需单次传播。这是关键洞察:decode主导实际耗时,尽管它处理的token更少,但无法跨token并行化。
基于两阶段特性,实践中我们采用不同优化策略:
Prefill优化:
Decode优化:
问题1:长提示导致prefill耗时过长
问题2:decode阶段吞吐量低
问题3:KV缓存内存爆炸
根据两阶段特性选择合适硬件:
| 考虑因素 | Prefill侧重 | Decode侧重 |
|---|---|---|
| 关键硬件指标 | 计算能力(TFLOPS) | 内存带宽(GB/s) |
| 推荐GPU特性 | 高FP16/FP32算力 | 高内存带宽和缓存 |
| 典型优势GPU | NVIDIA A100(矩阵计算强) | NVIDIA H100(高带宽) |
| 优化方向 | 大核心数量 | 高内存子系统效率 |
分块注意力:将长序列分块处理,减少内存压力
混合精度解码:
硬件感知架构:
打破顺序依赖:
动态稀疏注意力:
KV缓存压缩:
在实际部署中,理解prefill和decode的差异帮助我们做出更明智的决策。例如,在实时对话场景中,我们可能接受较长的prefill时间以换取更流畅的decode;而在批量处理场景中,则可能优先优化prefill的吞吐量。