学习目标

深入理解 Mini-SGLang 的调度系统,包括:

  • Scheduler 主调度器的核心逻辑
  • PrefillManager 的 Chunked Prefill 实现
  • DecodeManager 的请求管理
  • 重叠调度 (Overlap Scheduling) 的原理

1. 调度系统整体架构

1.1 核心组件

调度系统由 5 个核心组件组成:

1
2
3
4
5
6
Scheduler (主调度器)
├─→ PrefillManager (Prefill 阶段调度)
├─→ DecodeManager (Decode 阶段调度)
├─→ CacheManager (缓存管理)
├─→ TableManager (页表管理)
└─→ Engine (GPU 计算引擎)

1.2 数据流概览

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
用户请求 (UserMsg)

PrefillManager.pending_list (待处理队列)

Scheduler._schedule_next_batch() (调度批次)

Scheduler._prepare_batch() (准备数据)

Engine.forward_batch() (GPU 计算)

DecodeManager.running_reqs (正在解码)

Scheduler._process_last_data() (处理结果)

DetokenizeMsg (发送给 Detokenizer)

2. Scheduler 主调度器

2.1 核心职责

Scheduler 是整个调度系统的大脑,负责协调所有组件:

  1. 接收消息: 从 Tokenizer 接收用户请求 (UserMsg)
  2. 调度批次: 决定下一个批次执行 Prefill 还是 Decode
  3. 准备数据: 分配缓存页面、准备 2D 索引、准备注意力元数据
  4. 执行推理: 调用 Engine 进行前向传播和采样
  5. 处理结果: 将采样结果发送给 Detokenizer
  6. 资源管理: 释放已完成请求的页表和缓存

2.2 初始化流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# scheduler.py: 45-73
def __init__(self, config: SchedulerConfig):
# 1. 创建 Engine
self.engine = Engine(config)

# 2. 创建两个 CUDA Stream (重叠调度)
self.device = self.engine.device
self.stream = torch.cuda.Stream(device=self.device) # Scheduler stream
self.engine_stream_ctx = torch.cuda.stream(self.engine.stream) # Engine stream

# 3. 初始化各个 Manager
self.table_manager = TableManager(config.max_running_req, self.engine.page_table)
self.cache_manager = CacheManager(self.device, self.engine.num_pages, config.cache_type)
self.decode_manager = DecodeManager()
self.prefill_manager = PrefillManager(
self.cache_manager, self.table_manager, self.decode_manager
)

# 4. 其他配置
self.tokenizer = AutoTokenizer.from_pretrained(config.model_path)
self.eos_token_id = self.tokenizer.eos_token_id
self.prefill_budget = config.max_extend_tokens # 每批次最多处理的 token 数

关键点:

  • 两个 CUDA Stream: Scheduler stream 处理元数据 (CPU 密集),Engine stream 执行 GPU 计算
  • prefill_budget: 限制每批次 Prefill 的 token 数,避免单个批次过大导致延迟

2.3 调度策略

1
2
3
4
5
6
7
8
# scheduler.py: 172-178
def _schedule_next_batch(self) -> ForwardInput | None:
# Prefill 优先于 Decode
batch = (
self.prefill_manager.schedule_next_batch(self.prefill_budget)
or self.decode_manager.schedule_next_batch()
)
return self._prepare_batch(batch) if batch else None

调度优先级: Prefill > Decode

原因:

  • 用户体验: 新请求的首 token 延迟直接影响用户感知
  • 公平性: 避免新请求长时间等待
  • 吞吐量: Prefill 可以批量处理多个请求,提高 GPU 利用率

例子:

1
2
3
4
5
时间轴:
T0: 收到请求 A (Prefill)
T1: 收到请求 B (Prefill), 请求 C 正在 Decode
T2: 调度器选择: Prefill(A, B) 而不是 Decode(C)
T3: Prefill 完成后,下一批次才调度 Decode(A, B, C)

2.4 批次准备 (_prepare_batch)

批次准备是调度系统最复杂的部分,包含 5 个关键步骤:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# scheduler.py: 141-170
def _prepare_batch(self, batch: Batch) -> ForwardInput:
# 步骤 1: 分配 KV Cache 页面
needed_size = sum(r.extend_len for r in batch.reqs)
batch.out_loc = self.cache_manager.allocate(needed_size)

# 步骤 2: Pad batch (CUDA Graph 优化)
if padding_size := self.engine.graph_runner.pad_batch(batch):
batch.out_loc = F.pad(batch.out_loc, (0, padding_size), value=self.engine.dummy_page)

# 步骤 3: 准备 2D 索引
load_indices = self._make_2d_indices(
[(r.table_idx, r.cached_len, r.device_len) for r in batch.padded_reqs]
)
write_indices = self._make_2d_indices([
(r.table_idx, r.device_len, r.device_len + 1) if r.can_decode()
else self.dummy_write_2d_pos
for r in batch.reqs
])

# 步骤 4: 写入 page_table
self.page_table.view(-1)[load_indices] = batch.out_loc

# 步骤 5: 准备注意力元数据
self.engine.attn_backend.prepare_metadata(batch)

return ForwardInput(batch, sample_args, load_indices, write_indices)

步骤 1: 分配 KV Cache 页面

out_loc新分配的页面索引,用于存储这个批次计算出的 KV Cache:

  • Prefill 阶段: 计算所有 token 的 KV,需要 input_len 个页面
  • Decode 阶段: 只计算新 token 的 KV,需要 1 个页面

例子:

1
2
3
4
5
6
7
# 批次包含 3 个请求:
# - 请求 A: extend_len=100 (Prefill)
# - 请求 B: extend_len=50 (Prefill)
# - 请求 C: extend_len=1 (Decode)

needed_size = 100 + 50 + 1 = 151
out_loc = [5, 6, 7, ..., 155] # 151 个页面索引

步骤 2: Pad batch (CUDA Graph)

为了使用预先录制的 CUDA Graph,需要将 batch 填充到固定大小:

1
2
3
# 例子: batch.size=5, 最近的 graph batch size=8
padding_size = 8 - 5 = 3
batch.out_loc = [5, 6, 7, 8, 9, dummy_page, dummy_page, dummy_page]

步骤 3: 准备 2D 索引

token_pool 是 2D 表 [max_running_req, max_seq_len],存储所有请求的 token IDs。

load_indices: 加载输入 token IDs

1
2
3
# 请求 0: cached_len=50, device_len=100 → 加载 [50:100]
# 请求 1: cached_len=0, device_len=20 → 加载 [0:20]
# 1D 索引: [50, 51, ..., 99, max_seq_len*1+0, ..., max_seq_len*1+19]

write_indices: 写入采样结果

1
2
3
# 请求 0: device_len=100 → 写到位置 100
# 请求 1: device_len=20 → 写到位置 20
# 1D 索引: [100, max_seq_len*1+20]

为什么需要 1D 索引?

因为 PyTorch 的高级索引需要从不同请求的不同位置批量访问:

1
2
3
4
5
# 使用 1D 索引批量加载
batch.input_ids = self.token_pool.view(-1)[load_indices]

# 使用 1D 索引批量写入
self.token_pool.view(-1)[write_indices] = next_tokens_gpu

步骤 4: 写入 page_table

将新分配的页面索引写入 page_table,供注意力计算使用:

1
2
3
# page_table[请求0][50:100] = out_loc[0:50]
# page_table[请求1][0:20] = out_loc[50:70]
self.page_table.view(-1)[load_indices] = batch.out_loc

为什么要先写 page_table?

因为 prepare_metadata 需要读取 page_table 来构建 FlashInfer 的元数据 (cu_seqlens, indices 等)。

步骤 5: 准备注意力元数据

调用 FlashInfer backend 准备注意力计算所需的元数据 (详见 Day 5)。

2.5 重叠调度 (Overlap Scheduling)

重叠调度是 Mini-SGLang 的核心优化,通过两个 CUDA Stream 隐藏 CPU 延迟:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# scheduler.py: 233-256
def overlap_loop(self, last_data: Forwarde) -> ForwardData | None:
# 1. 接收新消息 (非阻塞)
blocking = not (last_data or self.prefill_manager.runnable or self.decode_manager.runnable)
for msg in self.receive_msg(blocking=blocking):
self._process_one_msg(msg)

# 2. 调度下一个批次 (在 Scheduler stream)
forward_input = self._schedule_next_batch()

# 3. 执行当前批次 (在 Engine stream)
ongoing_data = None
if forward_input is not None:
with self.engine_stream_ctx:
self.engine.stream.wait_stream(self.stream) # 等待 Scheduler stream
ongoing_data = (forward_input, self._forward(forward_input))

# 4. 处理上一批次的结果 (在 Scheduler stream)
self._process_last_data(last_data, ongoing_data)

return ongoing_data

两个 Stream 的分工

  • Scheduler Stream (self.stream): 处理元数据 (CPU 密集)

    • 接收消息
    • 调度批次
    • 准备 2D 索引
    • 分配缓存
    • 处理结果
  • Engine Stream (engine.stream): GPU 计算 (GPU 密集)

    • 模型前向传播
    • 注意力计算
    • 采样

重叠原理

1
2
3
4
5
6
时间轴:
Batch 0: [Scheduler准备]────→[Engine计算]────→[Scheduler处理结果]

Batch 1: [Scheduler准备]────→[Engine计算]────→[Scheduler处理结果]

重叠!隐藏CPU延迟

关键点:

  1. 当 Batch 0 在 Engine Stream 计算时,Batch 1 可以在 Scheduler Stream 准备元数据
  2. engine.stream.wait_stream(self.stream) 确保 Engine 等待 Scheduler 准备完成
  3. last_dataongoing_data 实现流水线: 处理上一批结果的同时执行当前批次

性能提升: 隐藏 10-20% 的 CPU 延迟,提高 GPU 利用率。

2.6 结果处理 (_process_last_data)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# scheduler.py: 75-114
def _process_last_data(self, last_data: ForwardData | None, ongoing_data: ForwardData | None):
if last_data is None:
return

batch, (_, next_tokens_cpu, copy_done) = last_data[0].batch, last_data[1]

# 步骤 1: 等待 GPU→CPU 拷贝完成
copy_done.synchronize()

# 步骤 2: 处理每个请求的采样结果
reply = []
for i, req in enumerate(batch.reqs):
if req in self.finished_reqs or isinstance(req, ChunkedReq):
continue # 跳过已完成和 ChunkedReq

# 更新请求状态
next_token_id = next_tokens_cpu[i]
req.append_host(next_token_id.unsqueeze(0))
next_token = int(next_token_id.item())

# 检查是否完成
finished = not req.can_decode()
if not req.sampling_params.ignore_eos:
finished |= next_token == self.eos_token_id

# 发送给 Detokenizer
reply.append(DetokenizeMsg(uid=req.uid, next_token=next_token, finished=finished))

if finished:
self.finished_reqs.add(req)
self.decode_manager.remove_req(req)

# 步骤 3: 释放已完成但不在 ongoing_data 中的请求
ongoing_reqs = ongoing_data[0].batch.reqs if ongoing_data else []
for req in self.finished_reqs.difference(ongoing_reqs):
self.table_manager.free(req.table_idx)
self.cache_manager.free_and_cache_finished_req(
req.cache_handle,
req.input_ids[:req.cached_len],
self.page_table[req.table_idx, :req.cached_len],
)

# 步骤 4: 只保留 ongoing 的已完成请求
self.finished_reqs.intersection_update(ongoing_reqs)

self.send_result(reply)

关键点:

  1. copy_done.synchronize(): 必须等待 GPU→CPU 拷贝完成,否则访问无效数据
  2. 跳过 ChunkedReq: ChunkedReq 还在 Prefill 阶段,没有生成 token
  3. 延迟释放资源: 如果请求在 ongoing_data 中,延迟释放 (避免重叠调度时的资源冲突)

3. PrefillManager - Chunked Prefill 实现

3.1 核心职责

PrefillManager 负责管理待处理请求队列 (pending_list),并实现 Chunked Prefill (分块预填充):

  1. 维护 pending_list: 所有等待 Prefill 的请求
  2. 前缀匹配: 使用 Radix Tree 匹配已缓存的前缀
  3. 资源分配: 分配 table_idx 和缓存页面
  4. Chunked Prefill: 当 token 预算不足时,分块处理大请求

3.2 Chunked Prefill 的原理

问题: 如果一个请求有 10,000 个 input tokens,但 prefill_budget=2048,怎么办?

解决方案: 分多次处理,每次处理 2048 tokens,创建 ChunkedReq 保存状态。

例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# 请求 A: input_len=10000, prefill_budget=2048

# 第一批次: Prefill 前 2048 tokens
req = ChunkedReq(
input_ids=input_ids[:2048],
cached_len=0,
output_len=0,
...
)
# cached_len 更新为 2048

# 第二批次: Prefill 2048-4096 tokens
req = ChunkedReq(
input_ids=input_ids[:4096],
cached_len=2048,
output_len=0,
...
)
# cached_len 更新为 4096

# ... 继续直到处理完所有 10000 tokens

# 最后一批次: Prefill 8192-10000 tokens
req = Req( # 不再是 ChunkedReq
input_ids=input_ids[:10000],
cached_len=8192,
output_len=0,
...
)
# Prefill 完成,下一批次开始 Decode

3.3 调度流程 (schedule_next_batch)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# prefill.py: 124-149
def schedule_next_batch(self, prefill_budget: int) -> Batch | None:
if len(self.pending_list) == 0:
return None

# 步骤 1: 创建 PrefillAdder (带 token 预算和保留空间)
adder = PrefillAdder(
token_budget=prefill_budget,
reserved_size=self.decode_manager.inflight_tokens, # 为 Decode 保留空间
cache_manager=self.cache_manager,
table_manager=self.table_manager,
)

# 步骤 2: 遍历 pending_list,尝试添加请求
reqs = []
chunked_list = []
for pending_req in self.pending_list:
if req := adder.try_add_one(pending_req):
pending_req.chunked_req = None
if isinstance(req, ChunkedReq):
pending_req.chunked_req = req # 保存 ChunkedReq 状态
chunked_list.append(pending_req) # 需要继续处理
reqs.append(req)
else:
break # 预算不足,停止添加

if len(reqs) == 0:
return None

# 步骤 3: 更新 pending_list (ChunkedReq 优先)
self.pending_list = chunked_list + self.pending_list[len(reqs):]

return Batch(reqs=reqs, phase="prefill")

关键点:

  1. reserved_size: 为正在 Decode 的请求预留空间,避免 Prefill 占用过多内存导致 Decode OOM
  2. ChunkedReq 优先: chunked_list 放在队列最前面,优先完成已开始的请求 (避免资源浪费)
  3. 预算控制: token_budget 限制每批次处理的 token 数,避免单个批次过大

3.4 添加请求 (_add_one_req)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# prefill.py: 63-88
def _add_one_req(self, pending_req: PendingReq, cache_handle, table_idx, cached_len) -> Req:
# 步骤 1: 计算 chunk_size
remain_len = pending_req.input_len - cached_len # 还需要处理的 token 数
chunk_size = min(self.token_budget, remain_len) # 取预算和剩余的最小值
is_chunked = chunk_size < remain_len # 如果预算不够,就是 chunked
CLS = ChunkedReq if is_chunked else Req

# 步骤 2: 扣除预算和更新保留空间
self.token_budget -= chunk_size
self.reserved_size += remain_len + pending_req.output_len

# 步骤 3: 更新 token_pool (输入 token IDs)
_slice = slice(cached_len, cached_len + chunk_size)
device_ids = self.table_manager.token_pool[table_idx][_slice]
device_ids.copy_(pending_req.input_ids[_slice].pin_memory(), non_blocking=True)

# 步骤 4: 创建 Req 或 ChunkedReq
return CLS(
input_ids=pending_req.input_ids[:cached_len + chunk_size],
table_idx=table_idx,
cached_len=cached_len,
output_len=pending_req.output_len,
uid=pending_req.uid,
cache_handle=cache_handle,
sampling_params=pending_req.sampling_params,
)

chunk_size 的计算:

1
2
3
4
5
6
7
8
9
# 例子 1: 预算足够
remain_len = 100, token_budget = 512
chunk_size = min(512, 100) = 100 # 一次处理完
is_chunked = False → 创建 Req

# 例子 2: 预算不足
remain_len = 1000, token_budget = 512
chunk_size = min(512, 1000) = 512 # 只能处理 512
is_chunked = True → 创建 ChunkedReq

为什么只更新 token_pool,不分配 page_table?

因为 page_table 的内容 (页面索引) 需要等 Scheduler 统一分配 out_loc 后才能填充:

1
2
3
4
5
6
7
# PrefillManager: 只更新 token_pool
device_ids = token_pool[table_idx][cached_len:cached_len+chunk_size]
device_ids.copy_(input_ids[...])

# Scheduler: 分配页面并填充 page_table
out_loc = cache_manager.allocate(needed_size)
page_table.view(-1)[load_indices] = out_loc

3.5 资源分配 (_try_allocate_one)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# prefill.py: 38-61
def _try_allocate_one(self, req: PendingReq) -> Tuple[BaseCacheHandle, int] | None:
# 步骤 1: 检查 table_idx 是否可用
if self.table_manager.available_size == 0:
return None

# 步骤 2: 前缀匹配
handle, match_indices = self.cache_manager.match_req(req)
cached_len = handle.cached_len
extend_len = req.input_len - cached_len
estimated_len = extend_len + req.output_len # Prefill + Decode 需要的总页面数

# 步骤 3: 第一次检查内存
if estimated_len + self.reserved_size > self.cache_manager.available_size:
return None

# 步骤 4: 锁定缓存
self.cache_manager.lock(handle)

# 步骤 5: 第二次检查内存 (lock 可能触发驱逐)
if estimated_len + self.reserved_size > self.cache_manager.available_size:
return self.cache_manager.unlock(handle) # 失败,解锁

# 步骤 6: 分配 table_idx
table_idx = self.table_manager.allocate()

# 步骤 7: 复制缓存部分的 page_table 和 token_pool
if cached_len > 0:
device_ids = self.table_manager.token_pool[table_idx][:cached_len]
page_entry = self.table_manager.page_table[table_idx][:cached_len]
device_ids.copy_(req.input_ids[:cached_len].pin_memory(), non_blocking=True)
page_entry.copy_(match_indices) # 复制前缀的页面索引

return handle, table_idx

为什么需要两次检查内存?

这是经典的 TOCTOU (Time-of-Check to Time-of-Use) 问题:

  1. 第一次检查: 快速判断,避免不必要的 lock 操作
  2. lock(handle): 锁定缓存页面,可能触发驱逐其他缓存
  3. 第二次检查: 确认 lock 后内存仍然足够

例子:

1
2
3
4
5
6
7
8
# 第一次检查: available_size=1000 (足够)
# lock(handle): 锁定 100 个页面,触发驱逐 50 个页面
# 第二次检查: available_size=950 (仍然足够) → 成功

# 或者:
# 第一次检查: available_size=1000 (足够)
# lock(handle): 锁定 100 个页面,触发驱逐 980 个页面
# 第二次检查: available_size=20 (不够了) → 失败,解锁

reserved_size 的作用:

reserved_size = decode_manager.inflight_tokens 是为正在 Decode 的请求预留的空间:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 场景:
# - 请求 A, B, C 正在 Decode,总共占用 300 个页面
# - 新请求 D 需要 Prefill,需要 500 个页面
# - 可用内存: 700 个页面

# 如果不保留:
# - 分配 500 个页面给 D → 剩余 200 个页面
# - 下一批次 Decode(A, B, C, D) 需要 304 个页面 (每个请求 +1)
# - 内存不足,OOM!

# 如果保留:
# - reserved_size = 300
# - estimated_len + reserved_size = 500 + 300 = 800 > 700
# - 拒绝分配,避免 OOM

4. DecodeManager - 请求管理

4.1 核心职责

DecodeManager 非常简单,只做 3 件事:

  1. 维护 running_reqs: 所有正在 Decode 的请求集合
  2. 调度 Decode 批次: 将所有 running_reqs 打包成一个 Batch
  3. 计算 inflight_tokens: 为 PrefillManager 提供 reserved_size

4.2 核心方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# decode.py: 10-31
@dataclass
class DecodeManager:
running_reqs: Set[Req] = field(default_factory=set)

# 过滤请求,只保留可以继续 decode 的
def filter_reqs(self, reqs: Iterable[Req]) -> None:
self.running_reqs = {req for req in self.running_reqs.union(reqs) if req.can_decode()}

# 移除已完成的请求
def remove_req(self, req: Req) -> None:
self.running_reqs.discard(req)

# 计算所有正在运行的请求的剩余 token 数
@property
def inflight_tokens(self) -> int:
return sum(req.remain_len for req in self.running_reqs)

# 调度下一个 Decode 批次
def schedule_next_batch(self) -> Batch | None:
if not self.runnable:
return None
return Batch(reqs=list(self.running_reqs), phase="decode")

@property
def runnable(self) -> bool:
return len(self.running_reqs) > 0

4.3 filter_reqs 的作用

调用时机: 在 Scheduler._forward() 中,每次前向传播后调用:

1
2
3
4
5
6
7
# scheduler.py: 225
def _forward(self, forward_input: ForwardInput) -> ForwardOutput:
self._load_token_ids(forward_input)
forward_output = self.engine.forward_batch(batch, sample_args)
self._write_token_ids(forward_input, forward_output)
self.decode_manager.filter_reqs(forward_input.batch.reqs) # 这里调用!
return forward_output

作用: 更新 running_reqs,添加新完成 Prefill 的请求,移除已完成的请求:

1
2
3
4
5
6
7
8
# 例子:
# 当前 running_reqs = {A, B, C} # 正在 Decode
# forward_input.batch.reqs = [D, E] # 刚完成 Prefill 的请求

# filter_reqs 执行:
# 1. union(reqs): {A, B, C} ∪ {D, E} = {A, B, C, D, E}
# 2. 过滤 can_decode(): 假设 A 已完成 (can_decode()=False)
# 3. 结果: running_reqs = {B, C, D, E}

为什么要过滤 can_decode()?

can_decode() 返回 False 的两种情况:

  1. ChunkedReq: 还在 Prefill 阶段,不应该加入 running_reqs
  2. **已完成的请put_len >= max_tokens`,不需要继续 Decode
1
2
3
4
5
6
7
8
9
# ChunkedReq 永远返回 False
class ChunkedReq(Req):
def can_decode(self) -> bool:
return False

# 普通 Req 根据 output_len 判断
class Req:
def can_decode(self) -> bool:
return self.output_len < self.sampling_params.max_tokens

例子:

1
2
3
4
5
6
7
8
# Batch 包含:
# - ChunkedReq(A): 还在 Prefill,can_decode()=False → 被过滤
# - Req(B): 刚完成 Prefill,can_decode()=True → 加入 running_reqs
# - Req(C): 正在 Decode,output_len=50/100,can_decode()=True → 保留
# - Req(D): 已完成,output_len=100/100,can_decode()=False → 被移除

# filter_reqs 后:
# running_reqs = {B, C}

4.4 inflight_tokens 的计算

1
return sum(req.remain_len for req in self.running_reqs)

req.remain_len 是请求还需要生成的 token 数:

1
2
3
4
# core.py
@property
def remain_len(self) -> int:
return self.sampling_params.max_tokens - self.output_len

例子:

1
2
3
4
5
# 请求 A: max_tokens=100, output_len=50 → remain_len=50
# 请求 B: max_tokens=200, output_len=10 → remain_len=190
# 请求 C: max_tokens=50, output_len=45 → remain_len=5

# inflight_tokens = 50 + 190 + 5 = 245

4.5 Decode 批次的特点

关键特点: Decode 批次包含所有 running_reqs,没有预算限制。

为什么?

因为 Decode 每次只生成 1 个 token,计算量很小:

1
2
3
4
5
6
# Prefill: 需要计算 input_len 个 token 的 KV (可能很大)
# Decode: 只需要计算 1 个 token 的 KV (固定大小)

# 例子:
# Prefill batch: [A(1000 tokens), B(500 tokens)] → 1500 tokens
# Decode batch: [A(1 token), B(1 token), C(1 token), ..., Z(1 token)] → 26 tokens

所以 Decode 可以将所有请求放在一个批次,最大化 GPU 利用率。

4.6 完整的请求生命周期

一个请求从收到到完成,经历以下状态:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
1. UserMsg (用户请求)

2. PendingReq (pending_list 中等待)

3a. 如果 input_len 很大,token_budget 不足:
→ ChunkedReq (第一次 Prefill,处理部分 tokens)
→ 继续在 pending_list 最前面
→ ChunkedReq (第二次 Prefill,继续处理)
→ ...
→ Req (最后一次 Prefill,处理完所有 tokens)

3b. 如果 input_len 较小,token_budget 足够:
→ Req (一次 Prefill 完成)

4. Req 加入 running_reqs (开始 Decode)

5. Decode 循环 (每次生成 1 token)

6. 完成 (output_len >= max_tokens 或遇到 EOS)

7. 从 running_reqs 移除,释放资源

状态转换详解

转换 1: UserMsg → PendingReq

1
2
3
4
5
6
7
8
# scheduler.py: 122-136
def _process_one_msg(self, msg: BaseBackendMsg):
if isinstance(msg, UserMsg):
self.prefill_manager.add_one_req(msg)

# prefill.py: 121-122
def add_one_req(self, req: UserMsg) -> None:
self.pending_list.append(PendingReq(req.uid, req.input_ids, req.sampling_params))

转换 2: PendingReq → ChunkedReq (第一次)

条件: chunk_size < remain_len (token 预算不足)

1
2
3
4
5
6
7
8
9
10
11
12
# 请求 A: input_len=10000, token_budget=2048
remain_len = 10000 - 0 = 10000
chunk_size = min(2048, 10000) = 2048
is_chunked = 2048 < 10000 = True

# 创建 ChunkedReq
ChunkedReq(
input_ids=input_ids[:2048], # 只处理前 2048 tokens
cached_len=0,
output_len=0,
)
# pending_list = [A, ...] (A 放在最前面,优先完成)

转换 3: ChunkedReq → ChunkedReq (继续 Prefill)

1
2
3
4
5
6
7
8
9
10
11
# 第二次 Prefill:
# cached_len=2048 (上次处理到这里)
# remain_len = 10000 - 2048 = 7952
# chunk_size = min(2048, 7952) = 2048
# is_chunked = True

ChunkedReq(
input_ids=input_ids[:4096], # 处理到 4096
cached_len=2048,
output_len=0,
)

转换 4: ChunkedReq → Req (最后一次 Prefill)

条件: chunk_size >= remain_len (剩余 tokens 可以一次处理完)

1
2
3
4
5
6
7
8
9
10
11
12
# 第五次 Prefill:
# cached_len=8192
# remain_len = 10000 - 8192 = 1808
# chunk_size = min(2048, 1808) = 1808
# is_chunked = 1808 < 1808 = False

Req( # 创建普通 Req
input_ids=input_ids[:10000], # 处理完所有 tokens
cached_len=8192,
output_len=0,
)
# Prefill 完成!

转换 5: Req → running_reqs (开始 Decode)

时机: Scheduler._forfilter_reqs`

1
2
3
4
5
6
# scheduler.py: 225
self.decode_manager.filter_reqs(forward_input.batch.reqs)

# decode.py: 13-14
def filter_reqs(self, reqs: Iterable[Req]) -> None:
self.running_reqs = {req for req in self.running_reqs.union(reqs) if req.can_decode()}

转换 6: Decode 循环

每次 Decode:

  1. 生成 1 个 token
  2. output_len += 1
  3. 检查是否完成: output_len >= max_tokens 或遇到 EOS
1
2
3
4
# scheduler.py: 91-93
finished = not req.can_decode() # output_len >= max_tokens
if not req.sampling_params.ignore_eos:
finished |= next_token == self.eos_token_id

转换 7: 从 running_reqs 移除

1
2
3
4
5
6
7
8
9
# scheduler.py: 97-100
if finished:
self.finished_reqs.add(req)
self.decode_manager.remove_req(req)

# 释放资源
for req in self.finished_reqs.difference(ongoing_reqs):
self.table_manager.free(req.table_idx)
self.cache_manager.free_and_cache_finished_req(...)

完整例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 请求 A: input_len=10000, max_tokens=100, token_budget=2048

# T0: 收到请求
UserMsg(uid=1, input_ids=[...10000 tokens...])

PendingReq(uid=1, input_ids=[...10000 tokens...])

# T1-T4: Chunked Prefill (分 5 次处理)
ChunkedReq(input_ids[:2048], cached_len=0) # 第 1 次
ChunkedReq(input_ids[:4096], cached_len=2048) # 第 2 次
ChunkedReq(input_ids[:6144], cached_len=4096) # 第 3 次
ChunkedReq(input_ids[:8192], cached_len=6144) # 第 4 次

# T5: 最后一次 Prefill
Req(input_ids[:10000], cached_len=8192)
# filter_reqs 后: running_reqs = {A}

# T6-T105: Decode 循环 (生成 100 tokens)
Req(output_len=1) → running_reqs = {A}
Req(output_len=2) → running_reqs = {A}
...
Req(output_len=100) → finished!

# T106: 完成
# running_reqs = {} (A 被移除)
# 释放 table_idx 和缓存

5. CacheManager - 缓存管理

5.1 核心职责

CacheManager 是 Scheduler 和底层 KVCache Manager (Radix Tree) 之间的桥梁:

  1. 管理空闲页面 (_free_slots): 维护可用的页面索引
  2. 前缀匹配 (match_req): 调用 Radix Tree 匹配已缓存的前缀
  3. 页面分配 (allocate): 分配新页面,必要时驱逐缓存
  4. 锁定/解锁 (lock/unlock): 防止缓存被驱逐
  5. 缓存完成请求 (free_and_cache_finished_req): 将完成的请求插入 Radix Tree

5.2 初始化

1
2
3
4
5
6
7
8
9
# cache.py: 13-18
def __init__(self, device: torch.device, num_pages: int, type: str):
# 空闲页面索引 (例如: [0, 1, 2, ..., 40959])
self._free_slots = torch.arange(num_pages, dtype=torch.int32, device=device)

# 底层 KVCache Manager (Radix Tree 或 Naive)
self.manager = create_cache_manager(device=device, type=type)

self.num_pages = num_pages

5.3 前缀匹配: 为什么要 [: input_len - 1]?

1
2
3
4
5
# cache.py: 24-27
def match_req(self, req: PendingReq):
input_len = req.input_len
# 关键: 减 1,不匹配最后一个 token
return self.manager.match_prefix(req.input_ids[: input_len - 1])

真正的原因: Logits 必须重新计算

核心问题: KV Cache 不存储 Logits (最后一层的预测概率分布)。

1
2
3
4
5
6
7
# KV Cache 存储的内容:
# - 每一层的 Key 和 Value 矩阵
# - 例如: layer_0_k, layer_0_v, layer_1_k, layer_1_v, ...

# KV Cache 不存储:
# - 最后一层的 Logits (vocab_size 维度的概率分布)
# - 因为 Logits 太大 (例如: 32000 维),存储浪费显存

为什么必须至少计算 1 个 Token?

因为需要 Logits 来采样下一个 token!

1
2
3
4
5
6
7
8
9
10
11
12
13
# 场景: 完全匹配
# 旧请求: [1, 2, 3, 4, 5]
# 新请求: [1, 2, 3, 4, 5]

# 如果全匹配 (不减 1):
cached_len = 5
extend_len = 5 - 5 = 0 # 没有 token 需要计算!

# 问题:
# - 没有 token 进入模型前向传播
# - 模型不输出 Logits
# - 采样器无法采样下一个 token
# - 系统卡死!

减 1 的工程技巧:

1
2
3
4
5
6
7
8
9
10
11
12
# 如果减 1:
# match_prefix([1, 2, 3, 4]) # 只匹配前 4 个

cached_len = 4
extend_len = 5 - 4 = 1 # 强制至少计算 1 个 token!

# 流程:
# 1. 加载前 4 个 token 的 KV Cache
# 2. 计算 token [5] 的前向传播
# 3. 输出 token [5] 位置的 Logits
# 4. 采样器根据 Logits 采样下一个 token [6]
# 5. 成功!

重要澄清: 因果掩码 (Causal Mask) 确保 Token 5 的 KV Cache 绝对不会看到后面的 6 和 7。Token 5 的 KV 在两个请求中是完全一样的,但必须重新计算以获得 Logits

总结: [: input_len - 1] 是一个工程妥协,确保至少计算 1 个 token 来获得 Logits 用于采样。

5.4 页面状态管理

KV Cache 页面有 3 种状态:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
总页面数 (num_pages) = 40960

┌─────────────────────────────────────────────┐
│ 1. 空闲页面 (_free_slots) │
│ - 从未使用过的页面 │
│ - 可以直接分配 │
│ - 例如: [0, 1, 2, ..., 1000] │
├─────────────────────────────────────────────┤
│ 2. 已缓存但可驱逐 (evictable) │
│ - 已经缓存了 KV,但没有被锁定 │
│ - 可以驱逐后重新分配 │
│ - 例如: 已完成的请求的 KV Cache │
├─────────────────────────────────────────────┤
│ 3. 已缓存且被锁定 (locked) │
│ - 正在使用的 KV Cache │
│ - 不能驱逐 │
│ - 例如: 正在 Decode 的请求的 KV Cache │
└─────────────────────────────────────────────┘

available_size 的计算

1
2
3
4
# cache.py: 29-31
@property
def available_size(self) -> int:
return self.manager.size_info.evictable_size + len(self._free_slots)

可用空间 = 空闲页面 + 可驱逐页面

为什么? 因为这两种页面都可以用来分配新的 KV Cache:

  • 空闲页面: 直接分配
  • 可驱逐页面: 驱逐后分配

例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
# 初始状态:
# _free_slots = [0, 1, ..., 999] # 1000 个空闲页面
# evictable_size = 0
# available_size = 0 + 1000 = 1000

# 分配 500 个页面给请求 A:
# _free_slots = [500, 501, ..., 999] # 500 个空闲页面
# available_size = 0 + 500 = 500

# 请求 A 完成,缓存到 Radix Tree (unlock):
# _free_slots = [500, 501, ..., 999] # 500 个空闲页面
# evictable_size = 500 # 请求 A 的 500 个页面可驱逐
# available_size = 500 + 500 = 1000 # 又恢复到 1000!

什么页面是 evictable 的?

  • 可驱逐: 已完成的请求 (插入 Radix Tree 后 unlock)
  • 不可驱逐: 正在 Prefill/Decode 的请求 (已 lock)

5.5 页面分配 (allocate)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# cache.py: 39-52
def allocate(self, needed_len: int) -> torch.Tensor:
# 情况 1: 空闲页面足够
if needed_len <= (free_len := len(self._free_slots)):
allocated = self._free_slots[:needed_len]
self._free_slots = self._free_slots[needed_len:]
return allocated

# 情况 2: 需要驱逐缓存
evicted = self.manager.evict(needed_len - free_len)
merged = torch.cat([self._free_slots, evicted])
allocated = merged[:needed_len]
self._free_slots = merged[needed_len:]
return allocated

驱逐逻辑

evict 返回什么? 被驱逐的页面索引。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 例子: 需要 800 个页面,只有 500 个空闲

# 步骤 1: 检查空闲页面
needed_len = 800
free_len = 500 # 不够!

# 步骤 2: 驱逐缓存
evicted = self.manager.evict(800 - 500) # 驱逐 300 个
# evicted = tensor([500, 501, ..., 799])

# 步骤 3: 合并
merged = torch.cat([self._free_slots, evicted])
# merged = [0, 1, ..., 499, 500, 501, ..., 799] # 800 个

# 步骤 4: 分配
allocated = merged[:800] # [0, 1, ..., 799]
self._free_slots = merged[800:] # []

return allocated

5.6 缓存完成请求 (free_and_cache_finished_req)

1
2
3
4
5
6
7
8
9
10
# cache.py: 54-62
def free_and_cache_finished_req(self, old_handle, input_ids, indices):
# 步骤 1: 插入 Radix Tree
in_cache_len = self.manager.insert_prefix(input_ids, indices)

# 步骤 2: 释放新插入的页面
self._free(indices[old_handle.cached_len : in_cache_len])

# 步骤 3: 解锁
self.unlock(old_handle)

为什么只释放 [old_handle.cached_len : in_cache_len]?

因为只有新插入的部分需要释放:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 例子:
# old_handle.cached_len = 3 (前缀 [1, 2, 3] 已在 Radix Tree)
# input_ids = [1, 2, 3, 4, 5]
# indices = [10, 11, 12, 13, 14]

# insert_prefix 返回 in_cache_len = 5

# 释放 indices[3:5] = [13, 14]
# 为什么?

# - indices[0:3] = [10, 11, 12] 对应 [1, 2, 3]
# 已经在 Radix Tree 中,不需要释放

# - indices[3:5] = [13, 14] 对应 [4, 5]
# 新插入的,需要释放 (已经复制到 Radix Tree)

完整流程:

1
2
3
4
5
6
7
8
9
10
11
请求完成前:
- indices[0:3] → Radix Tree 管理 (前缀)
- indices[3:5] → 请求 A 独占 (新计算的)

插入 Radix Tree 后:
- indices[0:5] → Radix Tree 管理 (整个序列)
- indices[3:5] → 可以释放了 (已复制到 Radix Tree)

unlock 后:
- 所有页面变成 evictable
- evictable_size 增加

6. TableManager - 页表管理

6.1 核心职责

TableManager 是调度系统中最简单的组件,只做 3 件事:

  1. 管理 table_idx: 分配和释放页表的行索引
  2. 维护 page_table: 引用 Engine 的 page_table
  3. 维护 token_pool: 存储所有请求的 token IDs

6.2 初始化

1
2
3
4
5
6
7
# table.py: 5-11
def __init__(self, max_running_reqs: int, page_table: torch.Tensor):
self._max_running_reqs = max_running_reqs
self._free_slots = list(range(max_running_reqs)) # [0, 1, 2, ..., max_running_reqs-1]
self.page_table = page_table # 引用 Engine 的 page_table
# 初始化为 0,确保 dummy request 读取到有效的 token_id
self.token_pool = torch.zeros_like(page_table, dtype=torch.int32)

关键数据结构:

  • page_table: 2D 表 [max_running_req, max_seq_len],存储页面索引
  • token_pool: 2D 表 [max_running_req, max_seq_len],存储 token IDs
  • _free_slots: 可用的行索引列表

6.3 为什么用 list 而不是 torch.Tensor?

1
self._free_slots = list(range(max_running_reqs))  # 使用 Python list

原因: _free_slots 需要频繁的 pop()append() 操作:

1
2
3
4
5
6
7
# 使用 list (当前实现):
self._free_slots.pop() # O(1) - 从尾部删除
self._free_slots.append(slot) # O(1) - 追加到尾部

# 如果使用 torch.Tensor:
self._free_slots = self._free_slots[:-1] # O(n) - 需要复制整个 tensor
self._free_slots = torch.cat([self._free_slots, torch.tensor([slot])]) # O(n)

关键: Python listpop()append() 是 O(1),而 torch.Tensor 的切片和拼接是 O(n)。

6.4 page_table vs token_pool

两个 2D 表存储不同的数据:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 例子: 请求 A 的数据
table_idx = 0

# token_pool[0] 存储 token IDs (输入数据):
token_pool[0] = [1, 2, 3, 4, 5, 0, 0, 0, ...]

# page_table[0] 存储页面索引 (KV Cache 的位置):
page_table[0] = [10, 11, 12, 13, 14, 0, 0, 0, ...]

# 含义:
# - token 1 的 KV Cache 在页面 10
# - token 2 的 KV Cache 在页面 11
# - token 3 的 KV Cache 在页面 12
# - ...

数据流:

1
2
3
4
5
1. PrefillManager: 将 input_ids 写入 token_pool
2. Scheduler: 分配页面,写入 page_table
3. _load_token_ids: 从 token_pool 读取 input_ids
4. Engine: 根据 page_table 读写 KV Cache
5. _write_token_ids: 将采样结果写回 token_pool

6.5 为什么 token_pool 初始化为 0?

1
2
# NOTE: dummy request also use this pool to get the input ids
self.token_pool = torch.zeros_like(page_table, dtype=torch.int32)

原因: dummy request 需要从 token_pool 读取 input_ids。

1
2
3
4
5
6
7
8
9
10
11
12
13
# dummy_req 的 table_idx = max_running_req (最后一行)
# 在 CUDA Graph replay 时,会从 token_pool 读取:
input_ids = token_pool[max_running_req, 0]

# 如果没有初始化:
# - token_pool[max_running_req, 0] = ??? (未定义)
# - 可能读取到任意值,比如 12345
# - model.forward(input_ids=[12345]) → 错误!超出 vocab_size

# 如果初始化为 0:
# - token_pool[max_running_req, 0] = 0
# - model.forward(input_ids=[0]) → 正确!
# - 0 通常是特殊 token (<pad>, <unk>, <bos>)

为什么 token_id = 0 是有效的?

在大多数 tokenizer 中,token_id = 0 是特殊 token:

  • Llama: 0 = <unk> (unknown token)
  • GPT: 0 = <|endoftext|>
  • BERT: 0 = [PAD]

这些特殊 token 不会导致模型崩溃,可以安全地用于 dummy request。

6.6 allocate() 和 free() 的实现

1
2
3
4
5
6
# table.py: 17-21
def allocate(self) -> int:
return self._free_slots.pop() # 从尾部删除,O(1)

def free(self, slot: int) -> None:
self._free_slots.append(slot) # 追加到尾部,O(1)

为什么从尾部 pop?

因为 list.pop() 从尾部删除是 O(1),而 list.pop(0) 从头部删除是 O(n):

1
2
3
4
5
6
7
# pop() 从尾部: O(1)
_free_slots = [0, 1, 2, 3, 4]
slot = e_slots.pop() # 返回 4,剩余 [0, 1, 2, 3]

# pop(0) 从头部: O(n) - 需要移动所有元素
_free_slots = [0, 1, 2, 3, 4]
slot = _free_slots.pop(0) # 返回 0,剩余 [1, 2, 3, 4]

free() 不检查有效性:

1
2
def free(self, slot: int) -> None:
self._free_slots.append(slot) # 直接追加,不检查

信任调用者,不做额外检查 (性能优化)。

6.7 TableManager 总结

  • 最简单的组件: 只有 22 行代码
  • 核心功能: 分配和释放 table_idx
  • 两个 2D 表: page_table (页面索引) + token_pool (token IDs)
  • O(1) 操作: 使用 Python list 的 pop()append()
  • dummy request 支持: token_pool 初始化为 0

7. 总结

7.1 调度系统的核心设计

  1. Prefill 优先: 降低首 token 延迟,提升用户体验
  2. Chunked Prefill: 分块处理大请求,避免单个批次过大
  3. 重叠调度: 两个 CUDA Stream 隐藏 CPU 延迟,提高 GPU 利用率
  4. 资源预留: reserved_size 确保 Decode 不会 OOM
  5. 前缀匹配: Radix Tree 复用已缓存的 KV,减少计算

7.2 关键数据结构

  • PendingReq: 等待处理的请求
  • ChunkedReq: 已处理一部分但还没完成的请求
  • Req: 正常请求
  • Batch: 一批次要处理的请求
  • ForwardInput: 批次 + 采样参数 + 2D 索引
  • ForwardOutput: 采样结果 + 异步拷贝事件

7.3 资源分配时间线

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
1. PrefillManager.schedule_next_batch():
- 分配 table_idx (页表行索引)
- 前缀匹配,复制缓存部分的 page_table
- 更新 token_pool (输入 token IDs)

2. Scheduler._prepare_batch():
- 分配新页面 (out_loc)
- 写入 page_table (新页面索引)
- 准备注意力元数据

3. Engine.forward_batch():
- 模型前向传播
- 注意力计算
- 采样

4. Scheduler._process_last_data():
- 等待 GPU→CPU 拷贝
- 发送 DetokenizeMsg
- 释放已完成请求的资源

7.4 性能优化总结

优化技术 性能提升 原理
Prefill 优先 降低首 token 延迟 50% 新请求优先处理
Chunked Prefill 支持任意长度输入 分块处理大请求
重叠调度 提高吞吐量 10-20% 隐藏 CPU 延迟
前缀匹配 减少计算 30-70% 复用已缓存的 KV
Continuous Batching 提高 GPU 利用率 2-3x 动态批处理

7.5 各组件核心要点

Scheduler (主调度器)

  • 6 大职责: 接收消息、调度批次、准备数据、执行推理、处理结果、资源管理
  • 重叠调度: 两个 CUDA Stream 隐藏 CPU 延迟
  • 批次准备: 5 个步骤 (分配页面、Pad、准备索引、写 page_table、准备元数据)

PrefillManager

  • Chunked Prefill: 分块处理大请求,支持任意长度输入
  • 资源预留: reserved_size 确保 Decode 不会 OOM
  • ChunkedReq 优先: 优先完成已开始的请求

DecodeManager

  • 维护 running_reqs: 所有正在 Decode 的请求
  • filter_reqs: 添加新请求,移除已完成请求
  • 无批次限制: Decode 可以包含所有 running_reqs

CacheManager

  • Logits 不存储: [: input_len - 1] 强制计算 1 个 token 获得 Logits
  • 三种页面状态: 空闲 + 可驱逐 + 锁定
  • 驱逐策略: 空闲不足时从 Radix Tree 驱逐 LRU 页面

TableManager

  • 最简单组件: 只管理 table_idx 的分配和释放
  • O(1) 操作: 使用 Python list 的 pop()append()
  • 两个 2D 表: page_table (页面索引) + token_pool (token IDs)