目录
背景和提出动机
Transformer Attention的显存瓶颈
FlashAttention数学本质与核心思想
FlashAttention算法详解与分步推导
FlashAttention底层实现原理
PyTorch中的FlashAttention实践
与常规Softmax Attention的比较
工程部署中的注意事项与参数调优
FlashAttention2及近期演进方向
总结与未来展望
参考资料与延伸阅读
1. 背景和提出动机
1.1 大模型、长序列与注意力瓶颈
在Transformer大模型迅速普及的当下,注意力机制(Attention)成为深度学习的核心操作之一。但其原始实现的空间和时间复杂度为 ,当输入序列 很长时(如千/万Token),计算和显存消耗将成为实际训练和推理的巨大障碍:
计算瓶颈:每一步都要与所有Token交互。
显存瓶颈:常规Attention实现需完整存储大规模的注意力得分矩阵与中间激活。
尤其在大模型场景下,GPU显存常常成为模型扩容和推理速率的核心瓶颈。
1.2 FlashAttention的提出
2022年,HazyResearch(斯坦福)提出FlashAttention算法,通过对Attention算子的计算与存储进行硬件友好优化,大幅降低了显存占用和延迟,甚至让许多“内存受限”的模型和任务成为可能。它已成为Llama-2、Qwen、Gemma等大模型训练/推理的事实标准,并已集成入PyTorch 2.x主分支。
2. Transformer Attention的显存瓶颈
2.1 标准Attention回顾
对输入序列 ,先分别做线性变换得到Q、K、V:
然后点积注意力为:
即:对每个Query,计算所有Key的相似度,经过Softmax归一化,再与Value线性组合。
2.2 显存消耗的来源
1. 存储所有 得分矩阵:
以典型的batch=1, seq=2048为例,单个Head即4M元素,16/32位浮点数就是8~16MB。
2. Softmax和dropout的中间激活需反向传播时保存
3. PyTorch默认实现在前向/反向传播时会复制中间结果,显存开销翻倍。
2.3 传统优化瓶颈
不能分块处理(blockwise):否则Softmax全局归一性无法保证。
不能边算边释放:因为反向传播需用到中间激活。
3. FlashAttention数学本质与核心思想
FlashAttention的核心突破在于:
用 分块(blockwise)算法处理Attention,每次只加载Q、K、V的一部分到高速缓存(寄存器/SMEM),极大降低全局显存需求;
不显式保存完整 得分矩阵,而是边计算边归一化;
利用数值稳定的Softmax递推公式,保证分块归一的正确性;
反向传播阶段重计算部分中间结果(recompute trick),进一步降低显存。
3.1 分块Softmax的数学技巧
普通Softmax写法:
数值稳定写法:
其中
分块累加技巧:
Softmax可以分步累积:
当前块最大值 ,和
下一个块最大值 ,
最终全局最大值 ,全局和 。
这样分块归一化,最终等价于全局Softmax。
4. FlashAttention算法详解与分步推导
4.1 前向传播算法步骤
将序列分为小块(Block),每次只计算和保存一小部分(如32或64 tokens)的 。
逐块累积最大值和和(数值稳定Softmax),每块单独处理Softmax,然后全局拼接,保证归一性。
每步只在高速缓存(SMEM)中操作,大幅减少对GPU主显存的访问。
伪代码(参考论文/官方实现):
for each query block:
for each key block:
# 取出当前Q, K, V小块
Q_blk, K_blk, V_blk = Q[q_start:q_end], K[k_start:k_end], V[k_start:k_end]
# 计算得分
score_blk = Q_blk @ K_blk.T / sqrt(d_k)
# 累积当前块最大值与和
max_score_blk = max(score_blk, dim=-1)
sum_exp_blk = sum(exp(score_blk - max_score_blk))
# 更新全局最大值与和
global_max = max(global_max, max_score_blk)
global_sum = global_sum * exp(prev_max - global_max) + sum_exp_blk * exp(max_score_blk - global_max)
# 得到全局Softmax归一化结果
softmax_blk = exp(score_blk - global_max) / global_sum
# 得到最终Attention输出
output_blk = softmax_blk @ V_blk
# 拼接输出
4.2 反向传播显存再优化
传统做法需保存所有
FlashAttention采用激活重计算(recompute),即正向阶段不保存全部激活,反向再用同样的分块方法重新计算。
4.3 多头并行、mask支持
分头并行自然并行于GPU不同SM
支持Causal Mask、Padding Mask(适用于自回归和BERT)
5. FlashAttention底层实现原理
5.1 CUDA优化与硬件亲和性
全部数据优先加载到shared memory(L1 cache),最大化带宽利用率;
寄存器复用,充分利用Tensor Core计算能力;
高效的blockwise循环展开,每次只处理一小段,避免大规模内存拷贝;
支持多batch、多head并行计算。
5.2 显存开销对比
常规实现:需要存储 的完整得分矩阵,反向还需保存全部Softmax。
FlashAttention:只需存储少量块的激活,显存下降至 ,理论上可支持序列长度成千上万。
5.3 兼容性
支持多种数据类型(float32/16/bfloat16)
支持Dropout(新版FlashAttention2),支持Padding、Mask
支持Causal(自回归)、非Causal(编码器)
6. PyTorch中的FlashAttention实践
6.1 PyTorch 2.0+内置的FlashAttention
PyTorch 2.0 开始,官方 torch.nn.functional.scaled_dot_product_attention 就已集成FlashAttention算法。调用方式非常简洁:
import torch
# Q, K, V shape: (batch, num_heads, seq_len, head_dim)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query=Q, key=K, value=V, attn_mask=mask, is_causal=True # 是否自回归mask
)
说明:
若设备和dtype支持,将自动启用FlashAttention内核,否则退化为标准实现。
6.2 与传统实现的对比
传统实现需如下操作:
attn_weights = (Q @ K.transpose(-2, -1)) / sqrt(d_k)
attn_weights = attn_weights.masked_fill(mask == 0, -1e9)
attn_probs = torch.softmax(attn_weights, dim=-1)
output = attn_probs @ V
而FlashAttention一行即可:
output = torch.nn.functional.scaled_dot_product_attention(Q, K, V, mask)
6.3 集成到自定义模块
例如自定义多头注意力(简化):
class FlashMultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.qkv = nn.Linear(d_model, d_model * 3)
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, x, mask=None):
B, N, _ = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2,0,3,1,4)
q, k, v = qkv[0], qkv[1], qkv[2] # (B, heads, seq, head_dim)
attn_out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
attn_out = attn_out.transpose(1,2).reshape(B, N, self.d_model)
return self.out_proj(attn_out)
6.4 Huggingface Transformers集成
目前主流大模型(Llama、Qwen、Gemma等)均已集成FlashAttention作为核心注意力算子。可通过设置use_flash_attention_2=True自动启用。
例如:
from transformers import LlamaForCausalLM
model = LlamaForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf', torch_dtype=torch.bfloat16, device_map="auto")
model.config.use_flash_attention_2 = True
7. 与常规Softmax Attention的比较
7.1 时间复杂度
两者均为 ,但FlashAttention利用了blockwise,极大提升了数据带宽利用率。
7.2 空间复杂度
常规Attention:
FlashAttention:
7.3 性能对比(官方数据)
训练速度提升:最多3倍以上
显存降低:最高降低10倍
长序列推理速度提升:Llama等大模型在长序列推理场景能显著提升吞吐
7.4 可解释性与数值精度
完全等价于标准Attention(通过分块softmax递推证明)
支持所有原生Attention特性(Dropout, Mask, Padding等)
8. 工程部署中的注意事项与参数调优
8.1 部署环境要求
PyTorch >=2.0
Nvidia Ampere(A100)或更新架构GPU最佳,V100/3090支持但速度有限
推荐bfloat16/fp16加速(但float32也支持)
8.2 参数选择
分块大小:默认32/64,通常无需手动设置
batch/seq头数:多头并行会获得更大速度收益
mask选择:训练时可用padding mask;自回归推理用causal mask
8.3 与其他高效Attention算法的对比
算法优势劣势适用场景FlashAttention完全等价全局Softmax,硬件友好仅加速/显存优化,不改变O(n^2)长序列/大batchPerformer显式稀疏近似精度略降超长序列Longformer/Sparse结构稀疏,O(n)局部上下文超长序列局部交互Linformer低秩近似有信息损失信息可压缩任务
9. FlashAttention2及近期演进方向
9.1 FlashAttention2新特性
更快的内核,完全重写kernel,进一步优化访存和流水线
更强的mask支持,更灵活
支持dropout,原版FlashAttention不支持
多流并行,支持多任务/多样本场景
9.2 最新进展
Huggingface、PyTorch主干均集成FlashAttention2
支持分布式训练(DeepSpeed、FSDP等)
大模型社区将其视为默认Attention算子
9.3 未来趋势
进一步结合低秩/稀疏/线性Attention
融合FlashAttention+Rope/动态Mask等创新
持续优化推理和多模态支持
10. 总结与未来展望
FlashAttention极大改变了Transformer大模型训练/推理的性能边界,推动了长文本、超大Batch等任务的落地。它没有改变Attention的表达能力,仅仅是极致优化了存储和带宽效率,是工程和算法的双重创新。
未来,大模型推理和训练会越来越依赖类似FlashAttention这样的硬件友好、近似零损失的优化算子。随着FlashAttention3、内存亲和Attention等新算法问世,我们很快将能在桌面级甚至边缘设备上运行原本需要几百GB显存的大模型。
11. 参考资料与延伸阅读
FlashAttention论文:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
官方PyTorch文档:torch.nn.functional.scaled_dot_product_attention
论文代码实现:FlashAttention Github
PyTorch源码剖析:pytorch/aten/src/ATen/native/transformers/attention.cpp
Huggingface集成示例:Huggingface Llama2代码
高性能显存分析:NVIDIA Ampere Tensor Core
知乎深度解析:FlashAttention原理和工程实现详解
B站讲解视频:B站:FlashAttention深度剖析
FlashAttention 实践:全流程代码演示
最后,以代码实战收尾,帮助你将理论与工程落地结合:
import torch
import torch.nn as nn
class FlashAttentionBlock(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.qkv_proj = nn.Linear(d_model, d_model * 3)
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, x, attn_mask=None, is_causal=True):
B, N, _ = x.shape
qkv = self.qkv_proj(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2,0,3,1,4)
q, k, v = qkv[0], qkv[1], qkv[2] # (B, heads, seq, head_dim)
attn_out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=is_causal
)
attn_out = attn_out.transpose(1,2).reshape(B, N, self.d_model)
return self.out_proj(attn_out)
# 测试
batch_size = 4
seq_len = 1024
d_model = 512
num_heads = 8
x = torch.randn(batch_size, seq_len, d_model).cuda()
block = FlashAttentionBlock(d_model, num_heads).cuda()
out = block(x) # 超长序列依然流畅
print(out.shape) # [4, 1024, 512]