学习文件attention/base.py (66 行), attention/fa.py (187 行), attention/fi.py (279 行)


1. Attention Backend 的核心职责

Attention Backend 是整个推理系统的计算核心,负责执行 Attention 计算:

1.1 职责定位

  • 输入:Q (Query)、K (Key)、V (Value) 三个矩阵
  • 输出:Attention 输出(加权求和后的结果)
  • 核心:实现高效的 Attention 计算

1.2 为什么需要抽象基类?

Mini-SGLang 支持多种 Attention 实现:

  • FlashAttention:适合 Prefill 阶段(计算密集)
  • FlashInfer:适合 Decode 阶段(访存密集)

但 Engine 只需要调用 attn_backend.forward(),不关心具体实现。

1
2
3
4
5
6
7
8
9
# engine.py
class Engine:
def __init__(self, config):
self.attn_backend: BaseAttnBackend = create_attention_backend(...)
# 可能是 FlashAttentionBackend、FlashInferBackend 或 HybridBackend

def forward_batch(self, batch):
# 统一接口,不关心具体实现
output = self.attn_backend.forward(q, k, v, layer_id, batch)

核心:抽象基类 = 统一接口 + 多态。


2. BaseAttnBackend 的接口设计

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class BaseAttnBackend(ABC):
@abstractmethod
def forward(
self,
q: torch.Tensor, # Query
k: torch.Tensor, # Key
v: torch.Tensor, # Value
layer_id: int, # 第几层
batch: Batch # 批次信息
) -> torch.Tensor: # 返回 Attention 输出
...

@abstractmethod
def prepare_metadata(self, batch: Batch) -> None:
...

@abstractmethod
def init_capture_graph(self, max_seq_len: int, bs_list: List[int]) -> None:
...

@abstractmethod
def prepare_for_capture(self, batch: Batch) -> None:
...

@abstractmethod
def prepare_for_replay(self, batch: Batch) -> None:
...

2.1 forward 方法

Attention 机制的核心公式

1
Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d)) @ V

三个矩阵的作用

  1. Q (Query):当前 token 的查询向量,“我想找什么”
  2. K (Key):所有 token 的键向量,“我是什么”
  3. V (Value):所有 token 的值向量,“我的内容是什么”

计算流程

1
2
3
1. Q @ K^T:计算相似度(当前 token 和所有 token 的相关性)
2. softmax:归一化为概率分布
3. @ V:加权求和,得到 Attention 输出

例子(简化):

1
2
3
4
5
6
7
8
9
输入句子:"The cat sat on the mat"
当前 token:"sat"

Q (sat 的查询):我想找"谁在做这个动作"
K (所有 token 的键):[The, cat, sat, on, the, mat]
相似度:[0.1, 0.8, 0.05, 0.02, 0.01, 0.02] # cat 最相关
softmax:[0.1, 0.8, 0.05, 0.02, 0.01, 0.02]
V (所有 token 的值):[The的内容, cat的内容, ...]
输出:0.8 * cat的内容 + 0.1 * The的内容 + ... # 主要关注 cat

layer_id 的作用

  • 每一层都有独立的 KV Cache
  • layer_id 用于索引对应层的 KV Cache
1
2
3
4
5
# Llama 有 32 层
for layer_id in range(32):
# 每层的 KV Cache 是独立的
k_cache = self.kvcache.k_cache(layer_id) # 第 layer_id 层的 K Cache
v_cache = self.kvcache.v_cache(layer_id) # 第 layer_id 层的 V Cache

2.2 prepare_metadata 方法

什么是 metadata?

Attention 计算需要很多辅助信息:

信息 作用 例子
positions 每个 token 的位置 [0, 1, 2, 3, ...]
cu_seqlens_q 累积序列长度(Query) [0, 3, 7, 10](3个请求,长度3,4,3)
cu_seqlens_k 累积序列长度(Key) [0, 5, 12, 18]
page_table KV Cache 的页表 每个请求的 KV 存在哪些页
cache_seqlens 每个请求的缓存长度 [5, 7, 6]

为什么需要提前准备?

Attention kernel 需要这些信息才能正确计算:

1
2
3
4
5
6
7
8
9
10
11
# 1. Scheduler 准备 batch
batch = Batch(reqs=[req1, req2, req3])

# 2. 准备 metadata
attn_backend.prepare_metadata(batch)
# 计算 positions、cu_seqlens、page_table 等
# 拷贝到 GPU

# 3. 执行 Attention
output = attn_backend.forward(q, k, v, layer_id, batch)
# 使用 batch.attn_metadata 中的信息

为什么不能在 forward 里准备?

1
2
3
4
5
6
7
8
9
# 如果在 forward 里准备 metadata
def forward(self, q, k, v, layer_id, batch):
# 1. 准备 metadata(包含内存分配、拷贝等)
positions = make_positions(device, batch.reqs) # 可能分配新内存
cu_seqlens = torch.tensor([...]).to(device) # 可能分配新内存

# 2. 执行 Attention
output = attention_kernel(q, k, v, positions, cu_seqlens)
return output

问题

  1. CUDA Graph 不允许内存分配make_positionstorch.tensor 可能分配新内存
  2. 地址不固定:每次 forward 分配的地址可能不同
  3. 无法 Capture:CUDA Graph 会报错

分离后

1
2
3
4
5
6
7
8
9
10
# prepare_metadata:在 CUDA Graph 外部调用
def prepare_metadata(self, batch):
positions = make_positions(device, batch.reqs) # 允许分配内存
batch.attn_metadata = Metadata(positions=positions)

# forward:在 CUDA Graph 内部调用
def forward(self, q, k, v, layer_id, batch):
metadata = batch.attn_metadata # 直接使用,不分配内存
output = attention_kernel(q, k, v, metadata.positions)
return output

好处

  1. prepare_metadata 在 CUDA Graph 外部调用,可以分配内存
  2. forward 在 CUDA Gra内部调用,只读取数据,不分配内存
  3. CUDA Graph 可以成功 Capture

核心:分离 = 内存分配在外部 + 计算在内部 = 支持 CUDA Graph。


2.3 CUDA Graph 相关的三个方法

init_capture_graph:初始化 CUDA Graph

1
2
3
def init_capture_graph(self, max_seq_len: int, bs_list: List[int]) -> None:
# 预分配 CUDA Graph 需要的 buffer
self.capture = FICaptureData.create(max_bs, max_seq_len, device)

作用

  • 在 Engine 初始化时调用一次
  • 预分配固定大小的 buffer(input_idspage_table 等)
  • 为后续 Capture 做准备

prepare_for_capture:准备 Capture

1
2
3
4
def prepare_for_capture(self, batch: Batch) -> None:
# 使用预分配的 buffer
batch.input_ids = self.capture.input_ids[:bs]
batch.attn_metadata = ... # 使用固定的 metadata

作用

  • 在 CUDA Graph Capture 时调用
  • 使用预分配的 buffer,确保地址固定
  • 准备固定的 metadata

prepare_for_replay:准备 Replay

1
2
3
4
def prepare_for_replay(self, batch: Batch) -> None:
# 更新 buffer 的内容(地址不变)
self.capture.input_ids[:bs].copy_(batch.input_ids)
self.capture.page_table[:bs].copy_(metadata.page_table)

作用

  • 在 CUDA Graph Replay 时调用
  • 更新 buffer 的内容(地址不变)
  • CUDA Graph 会读取更新后的内容

三个方法的关系

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
Engine 初始化:

init_capture_graph() # 预分配 buffer

CUDA Graph Capture:

prepare_for_capture() # 使用固定 buffer

Capture 完成

每次 Replay:

prepare_for_replay() # 更新 buffer 内容

Replay CUDA Graph

核心:三个方法配合 CUDA Graph,实现固定地址 + 动态内容。


3. HybridBackend 的设计

1
2
3
4
5
6
7
8
9
10
11
12
class HybridBackend(BaseAttnBackend):
def __init__(self, prefill_backend: BaseAttnBackend, decode_backend: BaseAttnBackend):
self.prefill_backend = prefill_backend
self.decode_backend = decode_backend

def forward(self, q, k, v, layer_id, batch):
backend = self.prefill_backend if batch.is_prefill else self.decode_backend
return backend.forward(q, k, v, layer_id, batch)

def prepare_metadata(self, batch):
backend = self.prefill_backend if batch.is_prefill else self.decode_backend
return backend.prepare_metadata(batch)

3.1 为什么需要两个 backend?

Prefill 和 Decode 的计算特征不同

阶段 输入长度 计算特征 最佳实现
Prefill 长(10-1000+ tokens) 计算密集 FlashAttention
Decode 短(1 token) 访存密集 FlashInfer

FlashAttention

  • 优化了长序列的 Attention 计算
  • 减少 HBM 访问次数
  • 适合 Prefill(大量计算)

FlashInfer

  • 优化了短序列的 KV Cache 访问
  • 高效的 Paged Attention
  • 适合 Decode(频繁访存)

HybridBackend 的作用

  • 根据 batch.is_prefill 自动选择最优的 backend
  • Prefill 用 FlashAttention,Decode 用 FlashInfer
  • 对外提供统一接口

核心:HybridBackend = 自动选择最优实现 + 统一接口。


3.2 完整流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Scheduler 的 overlap_loop
def overlap_loop(self):
# 1. 调度下一个批次
batch = self._schedule_next_batch()

# 2. 准备 metadata(在 CUDA Graph 外部)
self.attn_backend.prepare_metadata(batch)
# HybridBackend 根据 batch.is_prefill 选择 backend
# 可能分配内存、拷贝数据

# 3. 执行前向传播(在 CUDA Graph 内部)
with self.engine_stream_ctx:
if self.graph_runner.can_use_cuda_graph(batch):
logits = self.graph_runner.replay(batch)
# 内部调用 model.forward() → attn_backend.forward()
# 只读取 metadata,不分配内存
else:
logits = self.model.forward()

时序

1
2
3
4
5
6
7
prepare_metadata (外部,可分配内存)

CUDA Graph Capture/Replay (内部,不可分配内存)

model.forward()

attn_backend.forward() (只读取 metadata)

4. 总结

4.1 Attention Backend 的核心设计

设计点 目的 实现
BaseAttnBackend 统一接口 抽象基类,支持多种实现
HybridBackend 自动选择最优实现 Prefill 用 FA,Decode 用 FI
prepare_metadata 分离内存分配和计算 外部分配,内部只读
CUDA Graph 三方法 支持 CUDA Graph 固定地址 + 动态内容

4.2 Prefill vs Decode

特征 Prefill Decode
输入长度 长(10-1000+ tokens) 短(1 token)
计算特征 计算密集 访存密集
最佳实现 FlashAttention FlashInfer
优化重点 减少 HBM 访问 高效 KV Cache 访问

4.3 关键设计原则

  1. 分离内存分配和计算

    • prepare_metadata 在 CUDA Graph 外部,可以分配内存
    • forward 在 CUDA Graph 内部,只读取数据
  2. 固定地址 + 动态内容

    • init_capture_graph 预分配固定 buffer
    • prepare_for_replay 更新 buffer 内容
  3. 自动选择最优实现

    • HybridBackend 根据阶段自动选择
    • Prefill 用 FlashAttention,Decode 用 FlashInfer

5. 费曼挑战

问题 1:用简单的话解释"为什么 Prefill 和 Decode 需要不同的 Attention backend?"

答案
Prefill 阶段输入长(10-1000+ tokens),计算量大,是计算密集型,FlashAttention 优化了长序列的 Attention 计算,减少 HBM 访问次数。Decode 阶段输入短(1 token),计算量小,但需要频繁访问 KV Cache,是访存密集型,FlashInfer 优化了 KV Cache 访问效率。使用不同的实现可以针对性优化,提升性能。

问题 2:用简单的话解释"为什么 prepare_metadata 要在 forward 之前调用?"

答案
prepare_metadata 可能需要分配内存(如创建 Tensor、拷贝数据),但 CUDA Graph 不允许在 Capture 时分配内存。将内存分配操作放在 prepare_metadata(CUDA Graph 外部),forward 只读取数据(CUDA Graph 内部),这样可以支持 CUDA Graph 加速。分离 = 内存分配在外部 + 计算在内部 = 支持 CUDA Graph。

问题 3:用简单的话解释"prepare_for_replay 为什么只更新内容,不更新地址?"

答案
CUDA Graph 在 Capture 时记录了所有 Tensor 的 GPU 地址。Replay 时,GPU 会直接访问这些记录的地址。如果地址变了,GPU 还是访问旧地址,会读到错误数据。所以必须保持地址固定,只更新内容。prepare_for_replay 使用 .copy_() 方法,在原地更新数据,地址不变。


7. FlashAttention 实现详解

7.1 FAMetadata 的字段含义

1
2
3
4
5
6
7
8
@dataclass
class FAMetadata(BaseAttnMetadata):
cu_seqlens_k: torch.Tensor # 累积序列长度(Key)
cu_seqlens_q: torch.Tensor # 累积序列长度(Query)
cache_seqlens: torch.Tensor # 每个请求的缓存长度
max_seqlen_k: int # 最大 Key 序列长度
max_seqlen_q: int # 最大 Query 序列长度
page_table: torch.Tensor # 页表

什么是 cu_seqlens(累积序列长度)?

例子:假设有 3 个请求,Query 长度分别是 [3, 4, 2]

1
2
seqlens_q = [3, 4, 2]
cu_seqlens_q = [0, 3, 7, 9] # 累积和:[0, 0+3, 3+4, 7+2]

作用:告诉 FlashAttention kernel 每个请求的 Query 在哪里。

1
2
3
请求 0:cu_seqlens_q[0:1] =   → Query 在位置 0-2
请求 1:cu_seqlens_q[1:2] = [3, 7] → Query 在位置 3-6
请求 2:cu_seqlens_q[2:3] = [7, 9] → Query 在位置 7-8

为什么需要累积?

FlashAttention 使用 变长序列批处理(Variable-Length Batching):

  • 所有请求的 Query 拼接成一个大 Tensor
  • 使用 cu_seqlens 标记每个请求的边界
1
2
3
4
5
6
7
8
9
10
# 3 个请求的 Query
q1 = [q1_0, q1_1, q1_2] # 长度 3
q2 = [q2_0, q2_1, q2_2, q2_3] # 长度 4
q3 = [q3_0, q3_1] # 长度 2

# 拼接成一个大 Tensor
q_all = [q1_0, q1_1, q1_2, q2_0, q2_1, q2_2, q2_3, q3_0, q3_1]
# ↑ ↑ ↑ ↑
# 0 3 7 9
# cu_seqlens_q = [ 3, 7, 9]

核心cu_seqlens 用于标记变长序列的边界。


cache_seqlens vs cu_seqlens_k

字段 含义 例子
cache_seqlens 每个请求的 KV Cache 长度 [5, 7, 6]
cu_seqlens_k 累积的 Key 长度 [0, 5, 12, 18]

关系

1
2
cache_seqlens = [5, 7, 6]
cu_seqlens_k = [0] + cumsum(cache_seqlens) = [0, 5, 12, 18]

为什么需要两个?

  • cache_seqlens:传给 FlashAttention kernel,告诉它每个请求的 KV 长度
  • cu_seqlens_k:用于索引,标记每个请求的 Key 边界

7.2 forward 方法的执行顺序

1
2
3
4
5
6
7
8
9
10
11
12
13
def forward(self, q, k, v, layer_id, batch):
metadata = batch.attn_metadata

# 1. 先存储 KV
self.kvcache.store_kv(k, v, batch.out_loc, layer_id)

# 2. 再执行 Attention
return _fa_sgl_impl(
q=q,
k_cache=self.kvcache.k_cache(layer_id),
v_cache=self.kvcache.v_cache(layer_id),
...
)

为什么要先 store_kv

Attention 需要访问完整的 KV Cache

1
2
3
4
5
6
7
当前批次:
- 请求 1:生成第 5 个 token
- 请求 2:生成第 7 个 token

Attention 计算:
- 请求 1 需要访问前 4 个 token 的 KV(已缓存)+ 第 5 个 token 的 KV(当前生成)
- 请求 2 需要访问前 6 个 token 的 KV(已缓存)+ 第 7 个 token 的 KV(当前生成)

流程

1
2
3
4
5
6
7
8
9
10
# 1. 存储当前 token 的 KV.kvcache.store_kv(k, v, batch.out_loc, layer_id)
# 现在 KV Cache 包含:前 N-1 个 token + 当前 token

# 2. 执行 Attention(访问完整的 KV Cache)
output = _fa_sgl_impl(
q=q,
k_cache=self.kvcache.k_cache(layer_id), # 包含所有 token 的 K
v_cache=self.kvcache.v_cache(layer_id), # 包含所有 token 的 V
...
)

如果顺序反了

1
2
3
4
5
6
7
8
9
10
# 错误:先执行 Attention
output = _fa_sgl_impl(
q=q,
k_cache=self.kvcache.k_cache(layer_id), # 只有前 N-1 个 token
...
)

# 再存储 KV
self.kvcache.store_kv(k, v, batch.out_loc, layer_id)
# 当前 token 的 KV 没有参与 Attention 计算!

核心:先存储 KV,确保 Attention 访问到完整的 KV Cache。


7.3 prepare_metadata 的三种计算方式

1
2
3
4
5
6
7
8
9
if max_seqlen_q == 1:
# 情况 1:Decode 阶段
cu_seqlens_q = torch.arange(0, padded_size + 1, device=device, dtype=torch.int32)
elif all(l == 0 for l in cached_lens):
# 情况 2:Prefill 阶段,没有 cache hit
cu_seqlens_q = cu_seqlens_k
else:
# 情况 3:Prefill 阶段,有 cache hit(Chunked Prefill)
cu_seqlens_q = torch.tensor([0] + seqlens_q, **cpu_kwargs).cumsum_(dim=0)

情况 1:Decode 阶段(max_seqlen_q == 1

特征:每个请求只生成 1 个 token。

1
2
3
4
5
# 假设 3 个请求,每个生成 1 个 token
seqlens_q = [1, 1, 1]
padded_size = 3

cu_seqlens_q = torch.arange(0, 3 + 1) = [0, 1, 2, 3]

为什么用 arange

  • Decode 时,每个请求的 Query 长度都是 1
  • cu_seqlens_q = [0, 1, 2, 3] 表示:请求 0 在位置 0,请求 1 在位置 1,请求 2 在位置 2
  • arange 直接生成,比 cumsum 更快

情况 2:Prefill 阶段,没有 cache hit

特征:所有请求都是第一次处理,没有缓存。

1
2
3
4
5
6
7
8
9
10
# 假设 3 个请求,长度分别是 [5, 7, 6]
seqlens_q = [5, 7, 6]
cached_lens = [0, 0, 0] # 都没有缓存

# Query 长度 = Key 长度(因为没有缓存)
seqlens_k = [5, 7, 6]
cu_seqlens_k = [0, 5, 12, 18]

# 直接复用 cu_seqlens_k
cu_seqlens_q = cu_seqlens_k = [0, 5, 12, 18]

为什么 cu_seqlens_q = cu_seqlens_k

  • 没有缓存时,Query 长度 = Key 长度
  • 直接复用,避免重复计算

情况 3:Prefill 阶段,有 cache hit(Chunked Prefill)

特征:部分请求有缓存(Prefix Caching 或 Chunked Prefill)。

1
2
3
4
5
6
7
8
9
10
11
# 假设 3 个请求
# 请求 0:前 3 个 token 已缓存,现在处理后 2 个
# 请求 1:前 4 个 token 已缓存,现在处理后 3 个
# 请求 2:没有缓存,处理全部 6 个

cached_lens = [3, 4, 0]
seqlens_q = [2, 3, 6] # 当前处理的长度
seqlens_k = [5, 7, 6] # 总长度 = cached_lens + seqlens_q

cu_seqlens_q = [0] + cumsum([2, 3, 6]) = [0, 2, 5, 11]
cu_seqlens_k = [0] + cumsum([5, 7, 6]) = [0, 5, 12, 18]

为什么 cu_seqlens_q != cu_seqlens_k

  • Query 长度(当前处理)!= Key 长度(总长度)
  • 需要单独计算 cu_seqlens_q

三种情况总结

情况 特征 cu_seqlens_q 计算方式 原因
Decode 每个请求生成 1 个 token arange(0, bs+1) 长度都是 1,用 arange 更快
Prefill 无缓存 所有请求都是第一次处理 复用 cu_seqlens_k Query 长度 = Key 长度
Prefill 有缓存 部分请求有缓存 cumsum(seqlens_q) Query 长度 != Key 长度

核心:根据不同场景选择最高效的计算方式。


7.4 prepare_for_replay 的优化

1
2
3
4
5
6
7
8
9
10
11
def prepare_for_replay(self, batch: Batch) -> None:
metadata, bs = batch.attn_metadata, batch.padded_size

# 更新 6 个数据
self.capture.input_ids[:bs].copy_(batch.input_ids)
self.capture.out_loc[:bs].copy_(batch.out_loc)
self.capture.cu_seqlens_k[: bs + 1].copy_(metadata.cu_seqlens_k)
self.capture.positions[:bs].copy_(metadata.positions)
self.capture.seq_lens[:bs].copy_(metadata.cache_seqlens)
self.capture.page_table[:bs, : metadata.max_seqlen_k].copy_(metadata.page_table)
# 注意:没有更新 cu_seqlens_q

为什么要更新这些数据?

CUDA Graph 记录的是固定地址,但内容需要动态更新

数据 作用 为什么需要更新
input_ids 当前 token 的 ID 每次生成的 token 不同
out_loc KV 存储位置 每个请求的 KV 位置不同
cu_seqlens_k Key 的累积长度 每次生成后,序列长度+1
positions Token 位置 每次生成后,位置+1
seq_lens 缓存长度 每次生成后,长度+1
page_table 页表 KV 可能分配到新页

为什么不更新 cu_seqlens_q

看代码第 135 行的注释:

1
# cu_seqlens_q is always [0, 1, 2, ..., bs] for decode (i.e. no-op)

原因

Decode 阶段,每个请求都只生成 1 个 token:

1
2
3
4
5
6
7
8
9
10
11
# Capture 时
cu_seqlens_q = [0, 1, 2, 3, 4, 5, 6, 7, 8] # 假设 bs=8

# 第 1 次 Replay
cu_seqlens_q = [0, 1, 2, 3, 4, 5, 6, 7, 8] # 不变

# 第 2 次 Replay
cu_seqlens_q = [0, 1, 2, 3, 4, 5, 6, 7, 8] # 还是不变

# 第 N 次 Replay
cu_seqlens_q = [0, 1, 2, 3, 4, 5, 6, 7, 8] # 永远不变

关键

  • Decode 阶段,Query 长度永远是 1
  • cu_seqlens_q 永远是 [0, 1, 2, ..., bs]
  • 不需要更新,节省拷贝开销

对比 cu_seqlens_k

1
2
3
4
5
6
7
8
# 第 1 次 Replay
cu_seqlens_k = [0, 5, 12, 18, ...] # 每个请求的 KV 长度

# 第 2 次 Replay(每个请求生成了 1 个 token,长度+1)
cu_seqlens_k = [0, 6, 13, 19, ...] # 变化了!

# 需要更新
self.capture.cu_seqlens_k[:bs+1].copy_(metadata.cu_seqlens_k)

核心:Decode 阶段 cu_seqlens_q 是常量,不需要更新;cu_seqlens_k 会变化,需要更新。


8. 总结

8.1 Attention Backend 的核心设计

设计点 目的 实现
BaseAttnBackend 统一接口 抽象基类,支持多种实现
HybridBackend 自动选择最优实现 Prefill 用 FA,Decode 用 FI
prepare_metadata 分离内存分配和计算 外部分配,内部只读
CUDA Graph 三方法 支持 CUDA Graph 固定地址 + 动态内容

8.2 FlashAttention 的核心设计

设计点 目的 实现
变长序列批处理 高效处理不同长度的请求 使用 cu_seqlens 标记边界
三种 cu_seqlens_q 计算 针对不同场景优化 Decode 用 arange,Prefill 分两种情况
先存储 KV 确保 Attention 访问完整 KV store_kvforward 之前
prepare_for_replay 优化 减少拷贝开销 只更新会变化的数据

8.3 Prefill vs Decode

特征 Prefill Decode
输入长度 长(10-1000+ tokens) 短(1 token)
计算特征 计算密集 访存密集
最佳实现 FlashAttention FlashInfer
cu_seqlens_q 可能变化 永远是 [0,1,2,...]

9. 费曼挑战

问题 1:用简单的话解释"为什么 Prefill 和 Decode 需要不同的 Attention backend?"

答案
Prefill 阶段输入长(10-1000+ tokens),计算量大,是计算密集型,FlashAttention 优化了长序列的 Attention 计算,减少 HBM 访问次数。Decode 阶段输入短(1 token),计算量小,但需要频繁访问 KV Cache,是访存密集型,FlashInfer 优化了 KV Cache 访问效率。使用不同的实现可以针对性优化,提升性能。

问题 2:用简单的话解释"为什么 prepare_metadata 要在 forward 之前调用?"

答案
prepare_metadata 可能需要分配内存(如创建 Tensor、拷贝数据),但 CUDA Graph 不允许在 Capture 时分配内存。将内存分配操作放在 prepare_metadata(CUDA Graph 外部),forward 只读取数据(CUDA Graph 内部),这样可以支持 CUDA Graph 加速。分离 = 内存分配在外部 + 计算在内部 = 支持 CUDA Graph。

问题 3:用简单的话解释"prepare_for_replay 为什么只更新内容,不更新地址?"

答案
CUDA Graph 在 Capture 时记录了所有 Tensor 的 GPU 地址。Replay 时,GPU 会直接访问这些记录的地址。如果地址变了,GPU 还是访问旧地址,会读到错误数据。所以必须保持地址固定,只更新内容。prepare_for_replay 使用 .copy_() 方法,在原地更新数据,地址不变。

问题 4:用简单的话解释"为什么 Decode 阶段的 cu_seqlens_q 不需要更新?"

答案
Decode 阶段每个请求都只生成 1 个 token,Query 长度永远是 1。所以 cu_seqlens_q 永远是 [0, 1, 2, ..., bs],是一个常量,不会变化。不需要每次 Replay 都拷贝,节省开销。而 cu_seqlens_k 会随着生成的 token 增加而变化,需要更新。

问题 5:用简单的话解释"为什么要先 store_kv,再执行 Attention?"

答案
Attention 需要访问完整的 KV Cache,包括前面已缓存的 token 和当前生成的 token。如果先执行 Attention,当前 token 的 KV 还没存储,Attention 只能访问到前面的 token,会导致计算错误。先 store_kv 确保当前 token 的 KV 也被存储,Attention 可以访问到完整的 KV Cache。


10. FlashInfer 实现详解

10.1 FIMetadata 的设计

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
@dataclass
class FIMetadata(BaseAttnMetadata):
cu_seqlens_q_cpu: torch.Tensor # on cpu
cu_seqlens_k_cpu: torch.Tensor # on cpu
cu_seqlens_q_gpu: torch.Tensor # on gpu
indices: torch.Tensor # on gpu
last_page_len_cpu: torch.Tensor # on cpu
num_qo_heads: int
num_kv_heads: int
head_dim: int
page_size: Literal[1]
pos_encoding_mode: str
seq_lens_cpu: torch.Tensor # on cpu
dtype: torch.dtype
wrapper: BatchPrefillWithPagedKVCacheWrapper | BatchDecodeWithPagedKVCacheWrapper
initialized: bool = False

关键区别:CPU/GPU 数据分离

数据 位置 作用
cu_seqlens_k_cpu CPU Plan 阶段使用
cu_seqlens_q_cpu CPU Plan 阶段使用
seq_lens_cpu CPU Plan 阶段使用
last_page_len_cpu CPU Plan 阶段使用
cu_seqlens_q_gpu GPU Run 阶段使用
indices GPU Run 阶段使用(页表索引)
positions GPU Run 阶段使用

为什么需要 CPU/GPU 分离?

FlashInfer 使用 两阶段初始化

  1. Plan 阶段(在 CPU 上):分析 batch 的结构,计算内存布局
  2. Run 阶段(在 GPU 上):执行实际的 Attention 计算
1
2
3
4
5
6
7
8
9
10
11
# Plan 阶段(CPU)
wrapper.plan(
indptr=cu_seqlens_k_cpu, # CPU 上的数据
indices=indices, # GPU 上的数据
last_page_len=last_page_len_cpu, # CPU 上的数据
seq_lens=seq_lens_cpu, # CPU 上的数据
...
)

# Run 阶段(GPU)
output = wrapper.run(q=q, paged_kv_cache=kv_cache)

好处

  1. Plan 在 CPU 上执行:不需要 GPU 同步,避免性能损失
  2. 避免重复拷贝:GPU 数据预先拷贝,Run 时直接使用
  3. Pin Memory 优化:CPU 数据使用 pin_memory=True,拷贝更快

10.2 wrapper 的设计

1
wrapper: BatchPrefillWithPagedKVCacheWrapper | BatchDecodeWithPagedKVCacheWrapper

两种 Wrapper

Wrapper 用途 特点
BatchPrefillWithPagedKVCacheWrapper Prefill 阶段 处理长序列,支持变长
BatchDecodeWithPagedKVCacheWrapper Decode 阶段 处理短序列(1 token),优化访存

Wrapper 的核心方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class BatchDecodeWithPagedKVCacheWrapper:
def plan(
self,
indptr, # cu_seqlens_k_cpu
indices, # 页表索引
last_page_len, # 最后一页长度
seq_lens, # 序列长度
...
):
# 1. 分析 batch 结构
# 2. 计算内存访问模式
# 3. 分配 workspace buffer
# 4. 生成 kernel 配置
...

def run(self, q, paged_kv_cache):
# 执行 Attention kernel
return flashinfer_decode_kernel(q, paged_kv_cache, ...)

10.3 _initialize_metadata_once 的作用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def _initialize_metadata_once(metadata: FIMetadata) -> None:
if metadata.initialized:
return # 已经初始化,跳过

metadata.initialized = True

# 调用 wrapper.plan()
metadata.wrapper.plan(
indptr=metadata.cu_seqlens_k_cpu,
indices=metadata.indices,
last_page_len=metadata.last_page_len_cpu,
seq_lens=metadata.seq_lens_cpu,
...
)

plan 的作用

Plan 阶段做了什么?

  1. 分析 batch 结构

    • 每个请求的序列长度(seq_lens
    • 累积序列长度(cu_seqlens_k_cpu
    • 页表索引(indices
  2. 计算内存访问模式

    • 哪些 KV Cache 页需要访问
    • 访问顺序是什么
    • 如何合并访问以提高效率
  3. 分配 workspace buffer

    • FlashInfer 需要临时 buffer 存储中间结果
    • 根据 batch 大小动态分配
  4. 生成 kernel 配置

    • 选择最优的 kernel(不同的 batch size 用不同的 kernel)
    • 配置 thread block 大小
    • 配置 shared memory 大小

为什么只初始化一次?

1
2
if metadata.initialized:
return # 已经初始化,跳过

原因

  1. Plan 开销大

    • 需要分析 batch 结构
    • 需要分配 workspace buffer
    • 需要选择 kernel
  2. CUDA Graph 中 metadata 不变

    • Capture 时,metadata 的结构是固定的
    • Replay 时,只有数据内容变化,结构不变
    • 不需要重复 Plan
  3. 性能优化

    • 第一次调用时 Plan
    • 后续调用直接 Run
    • 避免重复开销

完整流程

1
2
3
4
5
6
7
Capture 时:
第 1 次 forward → Plan + Run

Replay 时:
第 2 次 forward → 跳过 Plan,只 Run
第 3 次 forward → 跳过 Plan,只 Run
第 N 次 forward → 跳过 Plan,只 Run

核心:Plan 一次,Run 多次 = 降低 CPU 开销 = 提高 Decode 吞吐。


10.4 use_tensor_cores 的判断逻辑

1
2
3
4
5
6
7
@cached_property
def use_tensor_cores(self) -> bool:
if (overriden_value := .FLASHINFER_USE_TENSOR_CORES.value) is not None:
return overriden_value

GQA = self.config.num_qo_heads // self.config.num_kv_heads
return GQA >= 4

什么是 Tensor Cores?

Tensor Cores 是 NVIDIA GPU 的专用硬件,用于加速矩阵乘法:

  • 普通 CUDA Cores:通用计算
  • Tensor Cores:专门用于矩阵乘法(如 Attention 的 Q @ K^T)

性能对比

  • Tensor Cores:~10x 吞吐量
  • 但有限制:矩阵大小必须是 8 或 16 的倍数

什么是 GQA(Grouped Query Attention)?

传统 Multi-Head Attention

1
2
3
num_qo_heads = 32  # Query/Output 头数
num_kv_heads = 32 # Key/Value 头数
GQA = 32 / 32 = 1 # 每个 KV 头对应 1 个 QO 头

Grouped Query Attention

1
2
3
num_qo_heads = 32  # Query/Output 头数
num_kv_heads = 8 # Key/Value 头数
GQA = 32 / 8 = 4 # 每个 KV 头对应 4 个 QO 头

好处

  • 减少 KV Cache 大小(8 个头 vs 32 个头)
  • 减少内存带宽需求
  • 性能损失很小

为什么 GQA >= 4 才使用 Tensor Cores?

原因

  1. Tensor Cores 的限制

    • 需要足够大的矩阵才能发挥优势
    • 小矩阵用 Tensor Cores 反而慢(启动开销大)
  2. GQA 的影响

    • GQA = 1:每个 KV 头对应 1 个 QO 头,矩阵小
    • GQA = 4:每个 KV 头对应 4 个 QO 头,矩阵大
  3. 经验阈值

    • GQA >= 4:矩阵足够大,Tensor Cores 有优势
    • GQA < 4:矩阵太小,用普通 CUDA Cores 更快

例子

1
2
3
4
5
6
7
8
9
10
11
Decode 阶段,batch_size = 8

GQA = 1:
Q: (8, 1, 4096) # 8 个请求,1 个 token,4096 维
K: (seq_len, 4096)
Q @ K^T: (8, 1, seq_len) # 小矩阵,Tensor Cores 不划算

GQA = 4:
Q: (8, 4, 4096) # 每个 KV 头对应 4 个 QO 头
K: (8, seq_len, 4096)
Q @ K^T: (8, 4, seq_len) # 大矩阵,Tensor Cores 有优势

核心GQA >= 4 时,矩阵足够大,Tensor Cores 才有性能优势。


10.5 FlashInfer vs FlashAttention

特征 FlashAttention FlashInfer
适用阶段 Prefill Decode
数据位置 全部在 GPU CPU + GPU 分离
初始化 每次都计算 metadata Plan 一次,Run 多次
CPU 开销 较高 较低(Plan 只一次)
优化重点 减少 HBM 访问 高效 KV Cache 访问
Tensor Cores 不使用 根据 GQA 选择

核心区别

  • FlashAttention:每次 forward 都重新计算 metadata,适合 Prefill(计算密集)
  • FlashInfer:Plan 一次,Run 多次,适合 Decode(访存密集,CPU 开销敏感)

11. 总结

11.1 Attention Backend 的核心设计

设计点 目的 实现
BaseAttnBackend 统一接口 抽象基类,支持多种实现
HybridBackend 自动选择最优实现 Prefill 用 FA,Decode 用 FI
prepare_metadata 分离内存分配和计算 外部分配,内部只读
CUDA Graph 三方法 支持 CUDA Graph 固定地址 + 动态内容

11.2 FlashAttention 的核心设计

设计点 目的 实现
变长序列批处理 高效处理不同长度的请求 使用 cu_seqlens 标记边界
三种 cu_seqlens_q 计算 针对不同场景优化 Decode 用 arange,Prefill 分两种情况
先存储 KV 确保 Attention 访问完整 KV store_kvforward 之前
prepare_for_replay 优化 减少拷贝开销 只更新会变化的数据

11.3 FlashInfer 的核心设计

设计点 目的 实现
CPU/GPU 数据分离 避免 GPU 同步 CPU 数据用于 Plan,GPU 数据用于 Run
Plan/Run 两阶段 降低 CPU 开销 Plan 一次,Run 多次
Wrapper 封装 针对不同阶段优化 Prefill 和 Decode 用不同 Wrapper
Tensor Cores 优化 大矩阵加速 GQA >= 4 时使用

11.4 Prefill vs Decode

特征 Prefill Decode
输入长度 长(10-1000+ tokens) 短(1 token)
计算特征 计算密集 访存密集
最佳实现 FlashAttention FlashInfer
cu_seqlens_q 可能变化 永远是 [0,1,2,...]
CPU 开销 可接受 敏感(需要优化)

12. 费曼挑战

问题 1:用简单的话解释"为什么 Prefill 和 Decode 需要不同的 Attention backend?"

答案
Prefill 阶段输入长(10-1000+ tokens),计算量大,是计算密集型,FlashAttention 优化了长序列的 Attention 计算,减少 HBM 访问次数。Decode 阶段输入短(1 token),计算量小,但需要频繁访问 KV Cache,是访存密集型,FlashInfer 优化了 KV Cache 访问效率。使用不同的实现可以针对性优化,提升性能。

问题 2:用简单的话解释"为什么 prepare_metadata 要在 forward 之前调用?"

答案
prepare_metadata 可能需要分配内存(如创建 Tensor、拷贝数据),但 CUDA Graph 不允许在 Capture 时分配内存。将内存分配操作放在 prepare_metadata(CUDA Graph 外部),forward 只读取数据(CUDA Graph 内部),这样可以支持 CUDA Graph 加速。分离 = 内存分配在外部 + 计算在内部 = 支持 CUDA Graph。

问题 3:用简单的话解释"prepare_for_replay 为什么只更新内容,不更新地址?"

答案
CUDA Graph 在 Capture 时记录了所有 Tensor 的 GPU 地址。Replay 时,GPU 会直接访问这些记录的地址。如果地址变了,GPU 还是访问旧地址,会读到错误数据。所以必须保持地址固定,只更新内容。prepare_for_replay 使用 .copy_() 方法,在原地更新数据,地址不变。

问题 4:用简单的话解释"为什么 Decode 阶段的 cu_seqlens_q 不需要更新?"

答案
Decode 阶段每个请求都只生成 1 个 token,Query 长度永远是 1。所以 cu_seqlens_q 永远是 [0, 1, 2, ..., bs],是一个常量,不会变化。不需要每次 Replay 都拷贝,节省开销。而 cu_seqlens_k 会随着生成的 token 增加而变化,需要更新。

问题 5:用简单的话解释"为什么要先 store_kv,再执行 Attention?"

答案
Attention 需要访问完整的 KV Cache,包括前面已缓存的 token 和当前生成的 token。如果先执行 Attention,当前 token 的 KV 还没存储,Attention 只能访问到前面的 token,会导致计算错误。先 store_kv 确保当前 token 的 KV 也被存储,Attention 可以访问到完整的 KV Cache。

问题 6:用简单的话解释"FlashInfer 为什么需要 CPU 和 GPU 两份数据?"

答案
FlashInfer 使用 Plan/Run 两阶段。Plan 阶段在 CPU 上分析 batch 结构、分配 buffer、生成 kernel 配置,需要 CPU 数据。Run 阶段在 GPU 上执行 Attention 计算,需要 GPU 数据。分离后,Plan 不需要 GPU 同步,Run 不需要重复拷贝,提高性能。

问题 7:用简单的话解释"为什么 FlashInfer 的 Plan 只调用一次?"

答案
Plan 开销大(分析结构、分配 buffer、选择 kernel),但 CUDA Graph 中 batch 结构固定,不需要重复 Plan。第一次调用时 Plan,后续 Replay 直接 Run,大幅降低 CPU 开销,提高 Decode 吞吐。Plan 一次,Run 多次 = 降低 CPU 开销。

问题 8:用简单的话解释"为什么 GQA >= 4 才使用 Tensor Cores?"

答案
Tensor Cores 专门用于加速矩阵乘法,但需要足够大的矩阵才能发挥优势,小矩阵反而慢。GQA = 1 时,每个 KV 头对应 1 个 QO 头,矩阵小。GQA >= 4 时,每个 KV 头对应 4+ 个 QO 头,矩阵大,Tensor Cores 有优势。