在Transformer架构中,注意力机制的计算复杂度一直是制约模型规模扩展的关键瓶颈。传统注意力计算需要将整个N×N的注意力矩阵存储在显存中,当序列长度N增大时,显存消耗呈平方级增长。FlashAttention通过创新的分块计算策略,从根本上解决了这一问题。
关键突破:FlashAttention的核心思想是避免在显存中完整存储N×N注意力矩阵,而是将计算分解为适合SRAM的小块,通过增量计算直接得到最终输出。
标准注意力计算流程:
这种方法的显存消耗主要来自:
FlashAttention的分块计算策略:
python复制# 伪代码示例
for i in range(0, N, B_q): # 按查询块遍历
Q_block = Q[i:i+B_q] # 加载查询块到SRAM
O_partial = zeros(B_q, d) # 初始化部分输出
for j in range(0, N, B_k): # 按键值块遍历
K_block = K[j:j+B_k] # 加载键块
V_block = V[j:j+B_k] # 加载值块
# 在SRAM中计算小块注意力
S_block = Q_block @ K_block.T / sqrt(d)
P_block = softmax(S_block)
O_partial += P_block @ V_block # 累加部分结果
O[i:i+B_q] = O_partial # 写回最终输出
选择合适的分块大小(B_q, B_k)需要考虑SRAM的容量限制。典型配置下(d=128,SRAM=192KB),各组件在FP16下的内存占用:
| 组件 | 大小计算 | 示例(B=128) |
|---|---|---|
| Q_block | B_q×d×2 | 32KB |
| K_block | B_k×d×2 | 32KB |
| V_block | B_k×d×2 | 32KB |
| S_block | B_q×B_k×2 | 32KB |
| O_block | B_q×d×2 | 32KB |
| 统计量 | B_q×8 | ~1KB |
| 总计 | 2×(2Bd+B²) | ~161KB |
实际应用中,分块大小通常选择64-256之间的2的幂次方,以平衡计算效率和内存使用。当B=128时,各组件能很好地适配192KB的SRAM容量。
标准注意力与FlashAttention在N=4096,B=128,d=128时的HBM访问量对比:
| 操作 | 标准注意力 | FlashAttention |
|---|---|---|
| Q矩阵 | 1MB(读) | 1MB(读) |
| K矩阵 | 1MB(读) | 32MB(32次读) |
| V矩阵 | 1MB(读) | 32MB(32次读) |
| S矩阵 | 32MB(读写) | 0 |
| P矩阵 | 32MB(读写) | 0 |
| O矩阵 | 1MB(写) | 1MB(写) |
| 总计 | 67MB(读)+65MB(写) | 65MB(读)+1MB(写) |
虽然FlashAttention需要多次读取K/V矩阵,但由于:
因此实际带宽消耗显著降低,且随着N增大,优势更加明显。
在自回归模型中,因果掩码使得注意力矩阵呈下三角形状。FlashAttention对此进行了专门优化:
因果注意力下的分块策略示例(N=8,B=2):
code复制K₀K₁K₂K₃K₄K₅K₆K₇
Q₀Q₁ ■ ■ ■ ■ ■ ■ ■ ■
Q₂Q3 □ □ ■ ■ ■ ■ ■ ■
Q₄Q5 □ □ □ □ ■ ■ ■ ■
Q₆Q7 □ □ □ □ □ □ ■ ■
■ 需要计算的块 □ 可跳过的块
传统softmax需要全局归一化:
code复制softmax(x)_i = exp(x_i) / sum(exp(x_j) for j in 1..N)
在分块计算中,每个块只能看到部分数据,无法直接计算完整softmax。FlashAttention采用在线softmax算法解决这一问题。
核心思想:维护运行最大值和求和,逐步修正计算结果:
python复制def online_softmax(Q_block, K_block, V_block, prev_max, prev_sum):
# 计算当前块的注意力分数
S_block = Q_block @ K_block.T / sqrt(d)
# 更新运行最大值
current_max = max(prev_max, rowmax(S_block))
# 修正之前的累加结果
correction = exp(prev_max - current_max)
running_sum = prev_sum * correction
running_out = prev_out * correction
# 计算当前块的贡献
P_block = exp(S_block - current_max)
running_sum += rowsum(P_block)
running_out += P_block @ V_block
return running_out, current_max, running_sum
这种方法的优势:
| 数据类型 | 是否必须存储 | FlashAttention作用 |
|---|---|---|
| 注意力矩阵(S,P) | 可避免 | 通过分块计算消除存储 |
| KV缓存 | 必须存储 | 仅优化读取方式 |
虽然FlashAttention显著减少了注意力计算的内存需求,但KV缓存仍需要O(N)的存储空间。对于大模型长上下文场景,KV缓存可能成为新的瓶颈。
以LLaMA-2 70B模型为例:
KV缓存大小 = 2 × 80 × 8 × 100,000 × 128 × 2 = ~32GB
结构优化:
量化压缩:
内存管理:
计算优化:
理想的分块大小应平衡:
经验法则:
在自回归解码时:
此时FlashAttention的优势有限,因为:
批量解码时FlashAttention仍能带来收益:
| 序列长度(N) | 标准注意力 | FlashAttention | 节省倍数 |
|---|---|---|---|
| 4,096 | 132MB | 66MB | 2× |
| 8,192 | 528MB | 130MB | 4× |
| 16,384 | 2,112MB | 258MB | 8× |
趋势:随着N增大,节省效果更加显著,因O(N²)项主导标准注意力的成本。
长文本处理:
批量推理:
训练加速:
现象:在线softmax可能引入数值误差
解决方案:
现象:不同GPU的SRAM大小不同
解决方案:
挑战:需要修改底层注意力实现
最佳实践:
策略:
优势:
结合:
技术:
实现要点:
在实际项目中采用FlashAttention时,建议从中等分块大小(如128)开始,逐步调整以获得最佳性能。同时要监控不同序列长度下的实际内存使用,确保达到预期优化效果。对于超长序列场景,还需要结合KV缓存优化策略,才能实现端到端的高效处理。