学习文件kvcache/base.py, kvcache/mha_pool.py, kvcache/naive_manager.py, kvcache/radix_manager.py


1. 为什么需要 KV Cache 管理?

在阶段一我们学习了 KV Cache 的基本概念:缓存已计算的 Key 和 Value,避免重复计算

但在实际系统中,KV Cache 管理面临更多挑战:

挑战1:内存有限

  • GPU 内存有限(例如 24GB)
  • 每个请求的 KV Cache 可能很大(长文本、长对话)
  • 需要驱逐策略来释放空间

挑战2:前缀复用

  • 多个请求可能有相同的前缀(例如 System Prompt)
  • 重复存储浪费内存
  • 需要前缀共享机制

挑战3:并发安全

  • 多个请求同时访问缓存
  • 驱逐时不能删除正在使用的缓存
  • 需要锁机制

2. KV Cache 系统架构

Mini-SGLang 的 KV Cache 系统分为两层:

1
2
3
4
5
6
7
8
9
┌─────────────────────────────────────────┐
│ BaseCacheManager │ 管理层
│ (前缀匹配、驱逐策略、生命周期管理) │
└─────────────────────────────────────────┘

┌─────────────────────────────────────────┐
│ BaseKVCache │ 存储层
│ (实际的 Tensor 存储和读写) │
└─────────────────────────────────────────┘

类比

  • BaseKVCache = 仓库(存储货物)
  • BaseCacheManager = 仓库管理员(决定存什么、删什么、怎么找)

3. BaseKVCache:存储层

3.1 核心接口

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class BaseKVCache(ABC):
@abstractmethod
def k_cache(self, index: int) -> torch.Tensor:
"""获取第 index 层的 K 缓存"""

@abstractmethod
def v_cache(self, index: int) -> torch.Tensor:
"""获取第 index 层的 V 缓存"""

@abstractmethod
def store_kv(
self, k: torch.Tensor, v: torch.Tensor,
out_loc: torch.Tensor, layer_id: int
) -> None:
"""存储 KV 到指定位置"""

@property
@abstractmethod
def device(self) -> torch.device: ...

@property
@abstractmethod
def dtype(self) -> torch.dtype: ...

@property
@abstractmethod
def num_layers(self) -> int: ...

3.2 职责

  1. 分配物理存储空间(GPU 显存)
  2. 提供读写接口(k_cache, v_cache, store_kv)
  3. 管理多层缓存(Transformer 的每一层都有 KV Cache)

3.3 关键设计

为什么 k_cache 和 v_cache 需要 layer_id?

因为 Transformer 有多层(例如 Llama 有 32 层),每一层都有独立的 KV Cache:

1
2
3
4
Layer 0:  K[0], V[0]
Layer 1: K[1], V[1]
...
Layer 31: K[31], V[31]

4. BaseCacheManager:管理层

4.1 核心接口

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class BaseCacheManager(ABC):
@abstractmethod
def match_prefix(self, input_ids: torch.Tensor) -> Tuple[BaseCacheHandle, torch.Tensor]:
"""匹配最长前缀,返回 handle 和 indices"""

@abstractmethod
def lock_handle(self, handle: BaseCacheHandle, unlock: bool = False) -> None:
"""锁定/解锁 handle,防止被驱逐"""

@abstractmethod
def insert_prefix(self, input_ids: torch.Tensor, indices: torch.Tensor) -> int:
"""插入新前缀,返回已缓存的长度"""

@abstractmethod
def evict(self, size: int) -> torch.Tensor:
"""驱逐缓存释放空间,返回被驱逐的 indices"""

@abstractmethod
def reset(self) -> None:
"""重置缓存"""

@property
@abstractmethod
def size_info(self) -> SizeInfo:
"""获取缓存大小信息"""

@abstractmethod
def check_integrity(self) -> None:
"""检查缓存完整性"""

4.2 职责

  1. 前缀匹配:找到可以复用的缓存
  2. 生命周期管理:插入、驱逐、锁定
  3. 并发安全:防止正在使用的缓存被驱逐
  4. 完整性检查:检测缓存损坏

5. BaseCacheHandle:缓存句柄

1
2
3
@dataclass(frozen=True)
class BaseCacheHandle(ABC):
cached_len: int

作用

  • 代表"缓存中的某个前缀"
  • 类似于"提货单"或"借书卡"
  • 用于 lock/unlock 操作

为什么需要 handle?

因为 match_prefix 返回的 indices 只是物理地址,不包含管理信息。handle 提供了:

  • 匹配的长度(cached_len
  • 抽象的引用(用于锁定)

6. SizeInfo:缓存大小信息

1
2
3
4
5
6
7
class SizeInfo(NamedTuple):
evictable_size: int # 可驱逐的大小
protected_size: int # 受保护的大小(被锁定)

@property
def total_size(self) -> int:
return self.evictable_size + self.protected_size

为什么要区分 evictable 和 protected?

  • protected_size:被 lock_handle 锁定的缓存,正在使用中,不能驱逐
  • evictable_size:未被锁定的缓存,可以驱逐

类比

  • 图书馆的书分为"在架上的"(可借)和"已借出的"(不可借)

7. 完整的缓存使用流程

假设有一个请求 "Hello world how are you",缓存中已有 "Hello world"

步骤1:匹配前缀

1
2
3
4
5
6
input_ids = [Hello, world, how, are, you]
handle, indices = match_prefix(input_ids)

# 返回:
# handle.cached_len = 2 (匹配了 "Hello world")
# indices = [page_5, page_12] (这两个 token 的物理位置)

作用:找到可以复用的部分,避免重复计算。

步骤2:锁定 handle

1
lock_handle(handle, unlock=False)

作用:防止并发驱逐删除正在使用的缓存。

步骤3:分配新 page

1
2
new_pages = allocate_pages(3)  # 为 [how, are, you] 分配
# new_pages = [page_20, page_21, page_22]

步骤4:插入新前缀

1
2
3
4
5
already_cached = insert_prefix(
input_ids=[Hello, world, how, are, you],
indices=[page_5, page_12, page_20, page_21, page_22]
)
# already_cached = 2 (前2个token已在缓存)

作用:更新缓存管理器的数据结构(例如 Radix Tree)。

步骤5:释放重复的 page

1
free_pages([page_5, page_12])

作用:因为这两个 page 已经在缓存中了,不需要重复存储。

步骤6:存储新 token 的 KV

1
store_kv(k, v, out_loc=[page_20, page_21, page_22], layer_id)

作用:实际存储 KV 数据到物理位置。

步骤7:解锁 handle

1
lock_handle(handle, unlock=True)

作用:使用完毕,允许后续驱逐。


8. 关键设计问题

8.1 为什么 match_prefix 不修改缓存?

原因

  • match_prefix 是只读操作(查询)
  • 可以并发调用,不需要加锁
  • 修改操作由 insert_prefix 负责

8.2 为什么 lock_handle 要在 insert_prefix 之前?

并发安全问题

1
2
3
4
5
6
7
8
9
10
# ❌ 错误顺序
handle, indices = match_prefix(...)
# 此时另一个线程调用 evict,可能删除 indices!
insert_prefix(...) # 💥 错误!indices 已被删除

# ✅ 正确顺序
handle, indices = match_prefix(...)
lock_handle(handle) # 立即锁定,防止驱逐
insert_prefix(...)
lock_handle(handle, unlock=True)

8.3 为什么 insert_prefix 返回 already_cached?

避免重复存储

1
2
3
4
5
6
7
8
9
# 缓存中已有: "Hello world"
input_ids = [Hello, world, how, are, you]
indices = [page_5, page_12, page_20, page_21, page_22]

already_cached = insert_prefix(input_ids, indices) # 返回 2

# 调用者应该:
# 1. 释放前2个page(page_5, page_12)
# 2. 只存储后3个token到后3个page

8.4 为什么 evict 可能驱逐更多?

驱逐粒度是"前缀",不是"page"

1
2
3
4
5
6
7
8
# 缓存中有两个前缀:
# Prefix A: "Hello world" (10 tokens, 10 pages)
# Prefix B: "How are you" (3 tokens, 3 pages)

evicted_indices = evict(size=5) # 请求驱逐 5 pages

# 实际可能驱逐 10 pages(整个 Prefix A)
# 因为不能只驱逐半个前缀!

为什么不能只驱逐半个前缀?

  • 前缀匹配需要完整的 token 序列
  • 驱逐 “Hello wor” 保留 “ld” 没有意义
  • 驱逐策略:选择一个或多个完整的前缀驱逐

9. check_integrity:缓存完整性检查

9.1 什么情况下缓存会损坏?

  1. 引用计数错误

    • 某个 handle 被锁定,但 size_info 没有正确更新
    • 导致 evictable_size + protected_size != total_size
  2. Radix Tree 结构错误(后面会学到):

    • 父子节点关系断裂
    • 节点的 token 序列不连续
  3. indices 越界

    • 返回的 indices 超出了 KV Cache 的实际大小
  4. 内存泄漏

    • 某些前缀被删除了,但对应的 page 没有释放

9.2 check_integrity 的作用

定期检查这些不一致,及时发现 bug。


10. KVCacheLayout:缓存布局

1
2
3
class KVCacheLayout(enum.Enum):
LayerFirst = enum.auto() # [num_layers, num_pages, page_size, ...]
PageFirst = enum.auto() # [num_pages, num_layers, page_size, ...]

两种布局的区别

LayerFirst(层优先)

1
2
# Shape: [num_layers, num_pages, page_size, num_heads, head_dim]
k_cache[layer_id][page_id] # 访问某一层的某个page

优点

  • 同一层的数据连续存储
  • 适合按层遍历

PageFirst(页优先)

1
2
# Shape: [num_pages, num_layers, page_size, num_heads, head_dim]
k_cache[page_id][layer_id] # 访问某个page的某一层

优点

  • 同一个 page 的所有层连续存储
  • 适合按 page 分配/释放

费曼挑战

挑战1:BaseKVCache vs BaseCacheManager

问题:用简单的话解释 BaseKVCache 和 BaseCacheManager 的区别。

解答

  • BaseKVCache = 仓库,负责实际存储货物(KV Tensor)
  • BaseCacheManager = 仓库管理员,负责决定存什么、删什么、怎么找货物

关键类比

  • 就像图书馆的"书架"(存储)和"图书管理系统"(管理)

挑战2:为什么需要 lock_handle?

问题:解释 lock_handle 的作用,以及不锁定会发生什么。

解答

  • 作用:防止正在使用的缓存被驱逐
  • 不锁定的后果:并发场景下,线程A正在使用某个前缀,线程B可能因为内存不足驱逐它,导致线程A访问到无效数据

关键类比

  • 就像图书馆借书,借出的书不能被下架

挑战3:handle vs indices

问题:解释 match_prefix 返回的 handle 和 indices 分别是什么。

解答

  • handle:抽象的引用,代表"缓存中的某个前缀",包含 cached_len,用于 lock/unlock
  • indices:物理地址,指向 KV Cache 中的实际存储位置,用于 Attention 计算

关键类比

  • handle = 图书馆借书卡(证明你借了这本书)
  • indices = 书架编号(告诉你书在哪个架子上)

挑战4:insert_prefix 的返回值

问题:解释 insert_prefix 为什么返回 already_cached。

解答

  • 返回值:已经在缓存中的前缀长度
  • 作用:避免重复存储,调用者应该释放这些重复的 page

关键场景

1
2
3
4
# 缓存中已有: "Hello world"
input_ids = [Hello, world, how, are, you]
already_cached = insert_prefix(...) # 返回 2
# 调用者释放前2个page,只存储后3个token

挑战5:为什么 evict 可能驱逐更多?

问题:解释为什么 evict 的实际驱逐大小可能大于请求大小。

解答

  • 原因:驱逐粒度是"前缀",不是"page"
  • 不能只驱逐半个前缀:因为前缀匹配需要完整的 token 序列

关键场景

1
2
3
# 请求驱逐 5 pages
# 但某个前缀有 10 pages
# 实际驱逐 10 pages(整个前缀)

学习总结

核心收获

  1. 两层架构

    • BaseKVCache(存储层)+ BaseCacheManager(管理层)
    • 分离关注点,提高可扩展性
  2. 并发安全

    • lock/unlock 机制防止驱逐正在使用的缓存
    • match → lock → insert → unlock 流程
  3. 前缀复用

    • match_prefix 找到可复用的部分
    • insert_prefix 避免重复存储
  4. 驱逐策略

    • 驱逐粒度是"前缀",不是"page"
    • evictable vs protected 区分
  5. 完整性检查

    • check_integrity 检测缓存损坏
    • 及时发现引用计数、结构错误等问题

11. MHAKVCache:BaseKVCache 的实现

11.1 核心职责

MHAKVCache 是 BaseKVCache 的具体实现,负责:

  1. 在 GPU 上分配物理存储空间
  2. 实现 k_cache, v_cache, store_kv 接口
  3. 支持 Tensor Parallelism(TP)
  4. 支持两种缓存布局(PageFirst / LayerFirst)

11.2 初始化参数

1
2
3
4
5
6
7
8
9
10
11
class MHAKVCache(BaseKVCache):
def __init__(
self,
num_kv_heads: int, # KV 头的数量(GQA)
num_layers: int, # Transformer 层数
head_dim: int, # 每个头的维度
num_pages: int, # 总页数
dtype: torch.dtype, # 数据类型(fp16/bf16)
kv_layout: KVCacheLayout, # 布局方式
device: torch.device, # 设备(GPU)
):

11.3 核心设计

11.3.1 统一的 kv_buffer

1
2
3
4
5
6
7
8
9
# 创建一个统一的 buffer,包含 K 和 V
kv_buffer = torch.empty(
(2, num_layers, num_pages, local_kv_heads, head_dim),
device=device,
dtype=dtype,
)
# 第一维是 2:[0] = K, [1] = V
self._k_buffer = self._kv_buffer[0]
self._v_buffer = self._kv_buffer[1]

为什么用统一的 buffer?

  • ✅ 一次分配,减少内存碎片
  • ✅ K 和 V 的形状完全相同,可以共享布局
  • ✅ 简化内存管理

11.3.2 支持 Tensor Parallelism(TP)

1
2
tp_info = get_tp_info()
local_kv_heads = divide_even(num_kv_heads, tp_info.size)

TP 的作用

  • 将 KV 头均匀分配到多个 GPU 上
  • 解决单 GPU 显存不够的问题

为什么按 KV 头切分?

因为 Attention 计算时,每个 head 独立计算,互不依赖:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 单 GPU(无 TP)
Q.shape = (batch_size, num_qo_heads, head_dim) # (4, 32, 128)
K.shape = (batch_size, num_kv_heads, head_dim) # (4, 8, 128)

# 4 个 GPU(TP = 4)
# GPU 0: heads 0-1
K_0.shape = (4, 2, 128) # KV 头 / 4 = 2

# GPU 1: heads 2-3
K_1.shape = (4, 2, 128)

# ... GPU 2, GPU 3 类似

# 每个 GPU 独立计算,无需通信
output_0 = Attention(Q_0, K_0, V_0)
output_1 = Attention(Q_1, K_1, V_1)
...

# 最后 concat
output = concat([output_0, output_1, output_2, output_3], dim=1)

为什么不按 head_dim 切分?

  • ❌ Attention 计算需要完整的 head_dim(Q @ K^T)
  • ❌ 需要频繁的 GPU 间通信(All-Gather)
  • ❌ 通信开销远大于计算开销

TP 的切分原则:沿着"可并行"的维度切分

11.3.3 支持两种布局

PageFirst(页优先)

1
2
3
4
5
6
7
8
# 原始形状:(2, num_pages, num_layers, local_kv_heads, head_dim)
# 内存布局:同一个 page 的所有层连续存储

# 例如:page_0 的所有层 → page_1 的所有层 → ...
[page_0_layer_0, page_0_layer_1, ..., page_0_layer_31,
page_1_layer_0, page_1_layer_1, ..., page_1_layer_31, ...]

# 优点:分配/释放 page 时,所有层的数据连续

LayerFirst(层优先)

1
2
3
4
5
6
7
8
# 原始形状:(2, num_layers, num_pages, local_kv_heads, head_dim)
# 内存布局:同一层的所有 page 连续存储

# 例如:layer_0 的所有 page → layer_1 的所有 page → ...
[layer_0_page_0, layer_0_page_1, ..., layer_0_page_N,
layer_1_page_0, layer_1_page_1, ..., layer_1_page_N, ...]

# 优点:访问同一层的不同 page 时,数据连续

为什么最后都统一成 LayerFirst?

因为代码的其他部分(Attention、store_kv)都假设是 LayerFirst 布局:

1
self._k_buffer[layer_id]  # 访问第 layer_id 层的所有 page

所以 PageFirst 需要 permute 成 LayerFirst:

1
.permute(0, 2, 1, 3, 4)  # (2, pages, layers, ...) → (2, layers, pages, ...)

11.3.4 多了一个维度 1

1
2
3
self._kv_buffer = kv_buffer.view(2, num_layers, num_pages, 1, local_kv_heads, head_dim)
# ↑
# page_size = 1

为什么需要这个维度?

为了支持 Paged Attention 的批处理。在 Attention 计算时:

1
2
3
4
5
6
7
8
# 单个 token 的 K/V
k.shape = (batch_size, num_kv_heads, head_dim)

# 但 KV Cache 中存储的是多个 token
# 需要一个额外的维度表示"每个 page 有多少个 token"
k_cache.shape = (num_pages, page_size, num_kv_heads, head_dim)
# ↑
# 这里是 page_size

为什么 page_size = 1?

这是一个简化的实现。实际的 Paged Attention(例如 vLLM)中,每个 page 存储多个 token(例如 16 个),可以减少 page 数量。

但 mini-sglang 为了简化,每个 page 只存储 1 个 token。

如果改成 page_size = 16 需要怎么改?

  1. 初始化
1
2
PAGE_SIZE = 16
self._kv_buffer = kv_buffer.view(2, num_layers, num_pages, PAGE_SIZE, local_kv_heads, head_dim)
  1. storage_shape
1
self._storage_shape = (num_pages, PAGE_SIZE, local_kv_heads, head_dim)
  1. out_loc 的含义变了
1
2
3
4
# 当前:out_loc[i] = page_id(直接指向 page)
# 改成 page_size = 16 后:
page_id = out_loc[i] // PAGE_SIZE
offset = out_loc[i] % PAGE_SIZE
  1. store_cache kernel 也需要改
1
2
3
4
5
# 当前:直接写入 k_cache[page_id]
k_cache[page_id] = k[i]

# 改成 page_size = 16 后:
k_cache[page_id, offset] = k[i]

11.3.5 store_kv 实现

1
2
3
4
5
6
7
8
9
10
def store_kv(self, k, v, out_loc, layer_id):
from minisgl.kernel import store_cache

store_cache(
k_cache=self._k_buffer[layer_id].view(self._storage_shape),
v_cache=self._v_buffer[layer_id].view(self._storage_shape),
indices=out_loc,
k=k,
v=v,
)

为什么要 view 成 storage_shape?

因为 store_cache kernel 期望的输入形状是:

1
2
3
k_cache: (num_pages, num_kv_heads, head_dim)
k: (batch_size, num_kv_heads, head_dim)
indices: (batch_size,) # out_loc

self._k_buffer[layer_id] 的形状是:

1
2
3
(num_pages, 1, local_kv_heads, head_dim)
# ↑
# 多了一个维度 1

所以需要 view(self._storage_shape) 去掉这个维度:

1
(num_pages, 1, local_kv_heads, head_dim) → (num_pages, local_kv_heads, head_dim)

11.4 完整示例

假设有以下配置:

  • 4 个 GPU(TP = 4)
  • Llama-7B 模型:32 层,32 个 KV 头,128 维
  • 分配 1000 个 page
  • 使用 bf16(2 字节)
  • LayerFirst 布局

每个 GPU 上的形状

1
2
3
4
5
6
# local_kv_heads = 32 // 4 = 8
_kv_buffer.shape = (2, 32, 1000, 1, 8, 128)

# 显存占用
total_elements = 2 * 32 * 1000 * 1 * 8 * 128 = 65,536,000
memory_bytes = 65,536,000 * 2 = 131,072,000 bytes125 MB

store_kv 的输入形状(单个 GPU)

1
2
3
4
5
6
7
8
9
k.shape = (batch_size, local_kv_heads, head_dim)
= (4, 8, 128)

v.shape = (4, 8, 128)

out_loc.shape = (batch_size,)
= (4,)

layer_id = 0~31

费曼挑战(续)

挑战6:为什么用统一的 kv_buffer?

问题:解释为什么 MHAKVCache 用一个统一的 buffer 存储 K 和 V。

解答

  • 原因1:K 和 V 的形状完全相同,可以共享布局
  • 原因2:一次分配,减少内存碎片
  • 原因3:简化内存管理

关键代码

1
2
3
kv_buffer = torch.empty((2, num_layers, num_pages, local_kv_heads, head_dim), ...)
# ↑
# 第一维是 2:[0] = K, [1] = V

挑战7:为什么 TP 按 KV 头切分?

问题:解释为什么 Tensor Parallelism 按 KV 头切分,而不是按 head_dim 切分。

解答

  • 原因:Attention 计算时,每个 head 独立计算,互不依赖
  • 优点:每个 GPU 独立计算,无需通信
  • 如果按 head_dim 切分:Attention 计算需要完整的 head_dim(Q @ K^T),需要频繁的 GPU 间通信

关键原则

  • TP 沿着"可并行"的维度切分

挑战8:PageFirst vs LayerFirst

问题:解释 PageFirst 和 LayerFirst 两种布局的区别,以及为什么最后都统一成 LayerFirst。

解答

  • PageFirst:同一个 page 的所有层连续存储,适合按 page 分配/释放
  • LayerFirst:同一层的所有 page 连续存储,适合按层访问
  • 为什么统一成 LayerFirst:因为代码的其他部分(Attention、store_kv)都假设是 LayerFirst 布局

关键代码

1
self._k_buffer[layer_id]  # 访问第 layer_id 层的所有 page

挑战9:为什么多了一个维度 1

问题:解释 _kv_buffer.view(2, num_layers, num_pages, 1, local_kv_heads, head_dim) 中间为什么多了一个维度 1

解答

  • 原因:为了支持 Paged Attention 的批处理
  • 含义:这个维度表示"每个 page 有多少个 token"(page_size)
  • 为什么是 1:mini-sglang 简化实现,每个 page 只存储 1 个 token

如果改成 page_size = 16

  • 需要修改初始化、storage_shape、out_loc 的含义、store_cache kernel

挑战10:为什么 store_kv 要 view?

问题:解释 store_kv 为什么要 view 成 storage_shape。

解答

  • 原因:匹配 store_cache kernel 的接口
  • 问题self._k_buffer[layer_id] 的形状是 (num_pages, 1, local_kv_heads, head_dim),多了一个维度 1
  • 解决view(self._storage_shape) 去掉这个维度,变成 (num_pages, local_kv_heads, head_dim)

学习总结

核心收获

  1. 两层架构

    • BaseKVCache(存储层)+ BaseCacheManager(管理层)
    • 分离关注点,提高可扩展性
  2. 并发安全

    • lock/unlock 机制防止驱逐正在使用的缓存
    • match → lock → insert → unlock 流程
  3. 前缀复用

    • match_prefix 找到可复用的部分
    • insert_prefix 避免重复存储
  4. 驱逐策略

    • 驱逐粒度是"前缀",不是"page"
    • evictable vs protected 区分
  5. 完整性检查

    • check_integrity 检测缓存损坏
    • 及时发现引用计数、结构错误等问题
  6. MHAKVCache 实现

    • 统一的 kv_buffer(一次分配,减少碎片)
    • 支持 TP(按 KV 头切分)
    • 支持两种布局(最终统一成 LayerFirst)
    • page_size = 1(简化实现)
    • store_kv 调用自定义 CUDA kernel
  7. TP 的切分原则

    • 沿着"可并行"的维度切分
    • Attention 按 head 切分
    • 不切分 batch_size, seq_len, head_dim

下一步

  • 学习 naive_manager.py:BaseCacheManager 的简单实现(无缓存管理)
  • 学习 radix_manager.py:BaseCacheManager 的高级实现(Radix Tree 前缀共享)

12. NaiveCacheManager:极简实现

12.1 核心设计思想

NaiveCacheManager 是 BaseCacheManager 的极简实现不做任何缓存管理

1
2
3
4
5
class NaiveCacheManager(BaseCacheManager):
def __init__(self, device: torch.device):
self.device = device
self.empty_tensor = torch.empty(0, dtype=torch.int32, device=device)
# ❌ 没有任何数据结构记录缓存!

设计哲学

  • ❌ 不匹配前缀(总是返回空)
  • ❌ 不锁定 handle(什么都不做)
  • ❌ 不插入前缀(直接返回全部长度)
  • ❌ 不支持驱逐(抛出异常)
  • ❌ 不检查完整性(什么都不做)

12.2 逐个方法分析

12.2.1 match_prefix:总是返回空

1
2
3
def match_prefix(self, input_ids: torch.Tensor) -> Tuple[NaiveCacheHandle, torch.Tensor]:
_ = input_ids # unused
return NaiveCacheHandle(0), self.empty_tensor

含义

  • 总是返回 cached_len = 0(没有匹配到任何前缀)
  • 返回空的 indices(没有可复用的缓存)

效果:每个请求都需要重新计算,无法复用缓存。

性能影响

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 场景:两个请求有相同的前缀
# Request 1: "Hello world how are you"
# Request 2: "Hello world what is your name"

# 使用 NaiveCacheManager:
handle1, indices1 = match_prefix([Hello, world, how, are, you])
# 返回:cached_len = 0, indices = []
# 效果:需要计算全部 5 个 token

handle2, indices2 = match_prefix([Hello, world, what, is, your, name])
# 返回:cached_len = 0, indices = []
# 效果:需要计算全部 6 个 token
# 💥 "Hello world" 被重复计算了!

# 使用 RadixCacheManager(后面会学):
handle2, indices2 = match_prefix([Hello, world, what, is, your, name])
# 返回:cached_len = 2, indices = [page_5, page_12]
# 效果:只需要计算后 4 个 token
# ✅ "Hello world" 被复用了!

12.2.2 lock_handle:什么都不做

1
2
def lock_handle(self, handle: BaseCacheHandle, unlock: bool = False) -> None:
_ = handle, unlock # unused

为什么可以这样?

  • 因为 NaiveCacheManager 不支持驱逐
  • 没有驱逐,就不需要锁定

12.2.3 insert_prefix:返回全部长度

1
2
3
def insert_prefix(self, input_ids: torch.Tensor, indices: torch.Tensor) -> int:
assert len(indices) == len(input_ids)
return len(indices)

含义

  • 检查 indicesinput_ids 长度相同
  • 返回 len(indices)(表示所有 token 都已在缓存中)

等等,这不对吧?

是的,这是一个反直觉的设计insert_prefix 返回的是"已经在缓存中的长度",返回 len(indices) 意味着:

  • “这些 token 都已经在缓存中了,不需要再插入”

但这是一个谎言!因为 NaiveCacheManager 根本没有缓存任何东西。

为什么要撒这个谎?

因为 NaiveCacheManager 的设计假设:

  • 调用者已经通过 store_kv 存储了 KV 到物理位置
  • insert_prefix 只是"通知"缓存管理器记录这个前缀
  • 但 NaiveCacheManager 不需要记录任何信息(因为不做缓存管理)
  • 所以直接返回 len(indices),告诉调用者"我已经知道了,不用再存储了"

调用者的行为

1
2
3
4
already_cached = insert_prefix(input_ids, indices)  # 返回 len(indices)

# 调用者会释放"已在缓存中"的 page
free_pages(indices[:already_cached]) # 释放所有 page

生命周期

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 1. 分配 page
pages = allocate_pages(5)

# 2. 存储 KV
store_kv(k, v, out_loc=pages, layer_id=0)

# 3. 插入前缀
already_cached = insert_prefix(input_ids, pages) # 返回 5

# 4. 释放"已在缓存中"的 page
free_pages(pages[:already_cached]) # 释放所有 5 个 page

# 5. 请求结束,page 被释放
# ✅ 没有泄漏,但无法复用

12.2.4 evict:不支持驱逐

1
2
3
4
def evict(self, size: int) -> torch.Tensor:
if size == 0:
return self.empty_tensor
raise NotImplementedError("NaiveCacheManager does not support eviction.")

为什么不支持?

根本原因:它不记录任何缓存信息

1
2
3
4
5
6
7
8
9
10
11
12
13
# 如果强行实现 evict 会怎样?
def evict(self, size: int) -> torch.Tensor:
# 问题1:驱逐哪些 page?
# 没有记录任何缓存信息,不知道哪些 page 被使用了

# 问题2:如何选择驱逐策略?
# FIFO?LRU?没有时间戳或访问记录

# 问题3:驱逐后如何更新数据结构?
# 没有数据结构可以更新

# 💥 无法实现!
raise NotImplementedError()

NaiveCacheManager 的模式

  • 分配 page → 使用 → 立即释放
  • 不保留任何缓存
  • 所以不需要驱逐

12.2.5 size_info:总是返回 0

1
2
3
@property
def size_info(self) -> SizeInfo:
return SizeInfo(evictable_size=0, protected_size=0)

含义:没有任何缓存。

12.2.6 check_integrity:什么都不做

1
2
def check_integrity(self) -> None:
pass

为什么?

  • 没有数据结构可以检查

12.3 empty_tensor 的作用

1
2
def __init__(self, device):
self.empty_tensor = torch.empty(0, dtype=torch.int32, device=device)

为什么要预先创建?

原因1:避免重复创建

1
2
3
4
5
6
7
8
9
# ❌ 不好的实现
def match_prefix(self, input_ids):
return NaiveCacheHandle(0), torch.empty(0, dtype=torch.int32, device=self.device)
# 每次调用都创建新的 tensor

# ✅ 好的实现
def match_prefix(self, input_ids):
return NaiveCacheHandle(0), self.empty_tensor
# 复用同一个 tensor

原因2:保证在正确的设备上

1
2
3
4
5
6
# 如果不指定 device,可能在 CPU 上
empty_tensor = torch.empty(0, dtype=torch.int32) # 默认在 CPU

# 后续使用时可能出错
indices = self.empty_tensor # CPU tensor
k_cache[indices] # 💥 错误!k_cache 在 GPU 上

原因3:类型一致性

1
2
# 确保返回的 indices 类型是 int32
# 与其他 CacheManager 返回的类型一致

12.4 适用场景

✅ 适合的场景

  1. 基准测试

    • 对比缓存管理的性能提升
    • 提供 baseline
  2. 调试

    • 简化系统,排除缓存管理的影响
    • 快速定位问题
  3. 单次请求

    • 每个请求都是独立的
    • 不需要复用前缀
  4. 内存充足

    • 不需要驱逐缓存
    • 每个请求都有足够的 page

❌ 不适合的场景

  1. 多轮对话

    • 需要复用 System Prompt
    • 需要复用历史对话
  2. 批量请求

    • 多个请求有相同的前缀
    • 需要前缀共享
  3. 内存不足

    • 需要驱逐缓存释放空间
    • NaiveCacheManager 不支持驱逐
  4. 生产环境

    • 需要高性能
    • 需要前缀复用

12.5 NaiveCacheManager vs RadixCacheManager

特性 NaiveCacheManager RadixCacheManager
前缀匹配 ❌ 总是返回空 ✅ Radix Tree 匹配
前缀复用 ❌ 无法复用 ✅ 自动复用
驱逐策略 ❌ 不支持 ✅ LRU/FIFO
内存管理 ❌ 即用即销毁 ✅ 长期保留
数据结构 ❌ 无 ✅ Radix Tree
性能 ❌ 低 ✅ 高
适用场景 基准测试、调试 生产环境

费曼挑战(续)

挑战11:NaiveCacheManager 的设计思想

问题:用简单的话解释 NaiveCacheManager 的核心设计思想。

解答

  • 核心思想:不做任何缓存管理,提供极简实现
  • 行为:所有方法都是空实现或返回默认值
  • 效果:每个请求都需要重新计算,无法复用前缀

关键类比

  • 就像一个"假的"图书管理系统,不记录任何借阅信息

挑战12:match_prefix 返回空的影响

问题:解释 match_prefix 总是返回空对性能的影响。

解答

  • 影响:无法复用前缀,重复计算
  • 场景:两个请求有相同的前缀 “Hello world”,都需要重新计算
  • 对比:RadixCacheManager 可以复用前缀,只计算一次

关键场景

1
2
3
4
# Request 1: "Hello world how are you"
# Request 2: "Hello world what is your name"
# NaiveCacheManager:计算 5 + 6 = 11 个 token
# RadixCacheManager:计算 5 + 4 = 9 个 token(复用 "Hello world")

挑战13:insert_prefix 的"谎言"

问题:解释 insert_prefix 为什么返回 len(indices),这是什么意思。

解答

  • 返回值含义:“所有 token 都已在缓存中”
  • 为什么是谎言:NaiveCacheManager 根本没有缓存任何东西
  • 为什么这样设计:因为不需要记录任何信息,直接告诉调用者"我已经知道了"

调用者行为

1
2
already_cached = insert_prefix(input_ids, indices)  # 返回 len(indices)
free_pages(indices[:already_cached]) # 释放所有 page

挑战14:为什么不支持 evict?

问题:解释为什么 NaiveCacheManager 不支持
解答

  • 根本原因:它不记录任何缓存信息
  • 无法实现:不知道哪些 page 被使用了,不知道驱逐哪些
  • 不需要驱逐:因为 page 在请求结束时立即释放,不保留任何缓存

关键理解

  • NaiveCacheManager 的模式是"即用即销毁"

挑战15:empty_tensor 的作用

问题:解释为什么要预先创建 empty_tensor。

解答

  • 原因1:避免重复创建(性能优化)
  • 原因2:保证在正确的设备上(GPU)
  • 原因3:类型一致性(int32)

关键代码

1
self.empty_tensor = torch.empty(0, dtype=torch.int32, device=device)

学习总结

核心收获

  1. 两层架构

    • BaseKVCache(存储层)+ BaseCacheManager(管理层)
    • 分离关注点,提高可扩展性
  2. 并发安全

    • lock/unlock 机制防止驱逐正在使用的缓存
    • match → lock → insert → unlock 流程
  3. 前缀复用

    • match_prefix 找到可复用的部分
    • insert_prefix 避免重复存储
  4. 驱逐策略

    • 驱逐粒度是"前缀",不是"page"
    • evictable vs protected 区分
  5. 完整性检查

    • check_integrity 检测缓存损坏
    • 及时发现引用计数、结构错误等问题
  6. MHAKVCache 实现

    • 统一的 kv_buffer(一次分配,减少碎片)
    • 支持 TP(按 KV 头切分)
    • 支持两种布局(最终统一成 LayerFirst)
    • page_size = 1(简化实现)
    • store_kv 调用自定义 CUDA kernel
  7. TP 的切分原则

    • 沿着"可并行"的维度切分
    • Attention 按 head 切分
    • 不切分 batch_size, seq_len, head_dim
  8. NaiveCacheManager 实现

    • 极简实现,不做任何缓存管理
    • 所有方法都是空实现或返回默认值
    • 适合基准测试、调试、单次请求
    • 不适合多轮对话、批量请求、生产环境
  9. NaiveCacheManager 的关键理解

    • 不是"关闭缓存",而是"不管理缓存"
    • 不会内存泄漏,但性能很差
    • 提供 baseline 对比

13. RadixCacheManager:Radix Tree 前缀共享

13.1 什么是 Radix Tree?

Radix Tree(基数树) 是一种压缩的前缀树,用于高效存储和查找共享前缀的字符串。

类比

  • 想象一个文件系统的目录结构
  • 共享的路径只存储一次
  • 例如:/home/user/doc1.txt/home/user/doc2.txt 共享 /home/user/

在 KV Cache 中的应用

  • 多个请求可能有相同的前缀(例如 System Prompt)
  • Radix Tree 可以让这些请求共享前缀的 KV Cache
  • 节省内存,提高性能

例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Request 1: "Hello world"
# Request 2: "Hello there"

# 不使用 Radix Tree(重复存储):
Request 1: [Hello, world] → [page_1, page_2]
Request 2: [Hello, there] → [page_3, page_4]
# 总共:4 个 page

# 使用 Radix Tree(共享前缀):
root
└── [Hello] → [page_1] # 共享
├── [world] → [page_2]
└── [there] → [page_3]
# 总共:3 个 page,节省 1 个 page

13.2 RadixTreeNode:树的节点

1
2
3
4
5
6
7
8
9
10
11
12
class RadixTreeNode:
def __init__(self):
self.children: Dict[int, RadixTreeNode] = {} # 子节点
self._parent: RadixTreeNode | None = None # 父节点
self.ref_count: int = 0 # 引用计数
self.uuid: int # 唯一标识
self.timestamp: int # 时间戳(LRU)

# 存储的数据
self._key: torch.Tensor # token 序列
self._value: torch.Tensor # page indices
self._length: int # 长度

关键字段

  1. children:子节点字典

    • key = 第一个 token ID
    • value = 子节点
  2. ref_count:引用计数

    • 表示有多少个请求正在使用这个节点
    • ref_count > 0:受保护,不能驱逐
    • ref_count = 0:可驱逐
  3. timestamp:时间戳

    • 用于 LRU 驱逐策略
    • 最近访问的节点 timestamp 更大
  4. _key 和 _value

    • _key:token 序列(例如 [Hello, world]
    • _value:page indices(例如 [page_5, page_12]

13.3 核心方法详解

13.3.1 _walk:遍历树找到最长匹配

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def _walk(self, input_ids: torch.Tensor) -> Tuple[RadixTreeNode, int]:
prefix_len = 0
node = self.root_node
tic = time.monotonic_ns()

while prefix_len < len(input_ids):
this_id = int(input_ids[prefix_len].item())
if this_id not in node.children:
return node, prefix_len # 没有子节点,返回

node = node.children[this_id]
match_len = node.get_match_len(input_ids[prefix_len:])
prefix_len += match_len

if match_len != node.length:
node = node._split_at(match_len) # 部分匹配,需要分裂
return node, prefix_len

node.timestamp = tic # 更新时间戳(LRU)

return node, prefix_len

作用:从根节点开始,沿着树向下走,找到最长匹配的前缀。

例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 树结构:
root
└── Node1: key=[Hello, world]

# 查找 "Hello there"
node, prefix_len = _walk([Hello, there])

# 步骤1:从 root 开始
# - this_id = Hello
# - root.children[Hello] = Node1
# - 进入 Node1

# 步骤2:匹配 Node1
# - Node1.key = [Hello, world]
# - 输入 = [Hello, there]
# - match_len = 1(只有 Hello 匹配)

# 步骤3:match_len (1) != node.length (2)
# - 需要分裂!
# - 调用 Node1._split_at(1)

# 返回:node = 分裂后的新节点, prefix_len = 1

13.3.2 _split_at:节点分裂

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def _split_at(self, pos: int) -> RadixTreeNode:
assert 0 < pos < self.length
parent = self.parent

# 创建新节点(前半部分)
new_node = RadixTreeNode(self.timestamp)
new_node.set_key_value(self._key[:pos], self._value[:pos])
new_node.set_parent(parent)
new_node.ref_count = self.ref_count

# 修改当前节点(后半部分)
self.set_key_value(self._key[pos:], self._value[pos:])
self.set_parent(new_node)

return new_node

作用:“细胞分裂”,将一个节点分裂成两个节点。

例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 原节点:
# key = [Hello, world]
# value = [page_1, page_2]
# parent = root

# 调用 _split_at(1):

# 新节点(返回):
# key = [Hello]
# value = [page_1]
# parent = root

# 原节点(修改后):
# key = [world]
# value = [page_2]
# parent = new_node

# 树结构:
# root → new_node[Hello] → self[world]

关键点

  • 新节点继承原节点的 ref_count(因为引用的是同一条路径)
  • 原节点变成新节点的子节点

13.3.3 match_prefix:匹配前缀

1
2
3
4
5
6
7
8
9
10
11
12
13
def match_prefix(self, input_ids: torch.Tensor) -> Tuple[RadixCacheHandle, torch.Tensor]:
node, prefix_len = self._walk(input_ids)
if prefix_len == 0:
return RadixCacheHandle(0, node), self.empty_tensor

# 收集从根到当前节点的所有 value
value_list = []
matched_node = node
while not node.is_root():
value_list.append(node.value)
node = node.parent
value_list.reverse()
return RadixCacheHandle(prefix_len, matched_node), torch.cat(value_list)

作用:找到最长匹配的前缀,返回 handle 和 indices。

为什么需要向上遍历收集 value?

因为整个路径都是匹配的,计算时需要所有信息:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 树结构:
root
└── Node1: [Hello], value=[page_1]
└── Node2: [world], value=[page_2]
└── Node3: [how, are, you], value=[page_3, page_4, page_5]

# 匹配 "Hello world how are you"
node, prefix_len = _walk([Hello, world, how, are, you])
# node = Node3, prefix_len = 5

# 收集 value(从 Node3 向上到 root)
value_list = []
# Node3: [page_3, page_4, page_5]
# Node2: [page_2]
# Node1: [page_1]

value_list.reverse()
# value_list = [[page_1], [page_2], [page_3, page_4, page_5]]

indices = torch.cat(value_list)
# indices = [page_1, page_2, page_3, page_4, page_5]

13.3.4 insert_prefix:插入新前缀

1
2
3
4
5
6
7
8
def insert_prefix(self, input_ids: torch.Tensor, indices: torch.Tensor) -> int:
node, prefix_len = self._walk(input_ids)
if prefix_len < len(input_ids):
new_node = RadixTreeNode()
new.set_key_value(input_ids[prefix_len:], indices[prefix_len:])
new_node.set_parent(node)
self.evictable_size += new_node.length
return prefix_len

作用:插入新的前缀到树中。

返回值:实际匹配的长度(和 NaiveCacheManager 的区别)。

例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 树结构:
root
└── Node1: [Hello, world], value=[page_1, page_2]

# 插入 "Hello world how"
node, prefix_len = _walk([Hello, world, how])
# node = Node1, prefix_len = 2

# prefix_len (2) < len(input_ids) (3)
# 创建新节点
new_node = RadixTreeNode()
new_node.set_key_value([how], [page_3])
new_node.set_parent(Node1)

# 树结构(插入后):
root
└── Node1: [Hello, d], value=[page_1, page_2]
└── Node2: [how], value=[page_3]

# 返回 prefix_len = 2
# 调用者会释放前 2 个 page(已在缓存中)
# 保留 page_3(新插入的)

13.3.5 lock_handle:锁定/解锁节点

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def lock_handle(self, handle: BaseCacheHandle, unlock: bool = False):
node = handle.node
if unlock:
while not node.is_root():
node.ref_count -= 1
if node.ref_count == 0:
self.evictable_size += node.length
self.protected_size -= node.length
node = node.parent
else:
while not node.is_root():
if node.ref_count == 0:
self.evictable_size -= node.length
self.protected_size += node.length
node.ref_count += 1
node = node.parent

作用:锁定从当前节点到根节点的整条路径。

为什么要锁定整条路径?

因为使用当前节点需要从根到当前节点的所有信息:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 树结构:
root
└── Node1: [Hello]
└── Node2: [world]
└── Node3: [how, are, you]

# 如果要使用 Node3,需要:
# - Node1 的 KV("Hello")
# - Node2 的 KV("world")
# - Node3 的 KV("how are you")

# 所以必须锁定整条路径:
lock_handle(handle_to_node3)
# 锁定:Node3 → Node2 → Node1
# 所有节点的 ref_count += 1

如果只锁定 Node3 会怎样?

1
2
3
4
5
6
7
8
9
# 只锁定 Node3
Node3.ref_count = 1 # 受保护
Node2.ref_count = 0 # 可驱逐
Node1.ref_count = 0 # 可驱逐

# 内存不足,触发驱逐
evict(size=10)
# 💥 Node1 或 Node2 可能被驱逐!
# 但 Node3 依赖它们,导致数据损坏!

13.3.6 evict:驱逐缓存

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def evict(self, size: int) -> torch.Tensor:
if size == 0:
return self.empty_tensor

leave_nodes = self._collect_leave_nodes_for_evict()
heapq.heapify(leave_nodes) # 最小堆(按 timestamp)
evicted_indices = []
evicted_size = 0

while evicted_size < size:
node = heapq.heappop(leave_nodes) # 弹出最旧的叶子节点
evicted_size += node.length
evicted_indices.append(node.value)
self.evictable_size -= node.length
parent = node.parent
del parent.children[int(node._key[0].item())]

if parent.is_leaf() and parent.ref_count == 0:
heapq.heappush(leave_nodes, parent) # 父节点也变成叶子了

return torch.cat(evicted_indices)

作用:使用 LRU 策略驱逐最旧的叶子节点。

为什么使用最小堆?

  • 按 timestamp 排序
  • 驱逐最久没使用的节点(LRU)

为什么只驱逐叶子节点?

原因1:叶子节点没有依赖

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 树结构:
root
└── Node1: [Hello]
├── Node2: [world]
└── Node3: [there]

# 如果驱逐 Node1(非叶子节点):
# 💥 Node2 和 Node3 会失去父节点!
# 树结构被破坏!

# 如果驱逐 Node2(叶子节点):
# ✅ 只删除 Node2,不影响其他节点
root
└── Node1: [Hello]
└── Node3: [there]

原因2:驱逐叶子节点后,父节点可能也变成叶子

1
2
3
4
5
6
7
8
9
10
11
12
# 初始状态:
root
└── Node1: [Hello]
└── Node2: [world] # 叶子节点

# 驱逐 Node2:
root
└── Node1: [Hello] # 现在也是叶子节点了!

# 如果 Node1.ref_count == 0,可以继续驱逐
if parent.is_leaf() and parent.ref_count == 0:
heapq.heappush(leave_nodes, parent)

13.4 完整示例

假设有以下场景:

1
2
3
# Request 1: "Hello world"
# Request 2: "Hello there"
# Request 3: "Hello world how are you"

步骤1:插入 Request 1

1
2
3
4
5
6
7
8
9
# 初始状态:
root

# 插入 "Hello world"
insert_prefix([Hello, world], [page_1, page_2])

# 树结构:
root
└── Node1: key=[Hello, world], value=[page_1, page_2]

步骤2:插入 Request 2(分裂)

1
2
3
4
5
6
7
8
9
10
11
# 插入 "Hello there"
insert_prefix([Hello, there], [page_3])

# _walk 发现只有 "Hello" 匹配
# 调用 Node1._split_at(1)

# 树结构(分裂后):
root
└── Node1: key=[Hello], value=[page_1]
├── Node2: key=[world], value=[page_2]
└── Node3: key=[there], value=[page_3]

步骤3:插入 Request 3

1
2
3
4
5
6
7
8
9
10
11
12
13
# 插入 "Hello world how are you"
insert_prefix([Hello, world, how, are, you], [page_4, page_5, page_6])

# _walk 匹配到 Node2("Hello world")
# prefix_len = 2
# 创建新节点存储 "how are you"

# 树结构:
root
└── Node1: key=[Hello], value=[page_1]
├── Node2: key=[world], value=[page_2]
│ └── Node4: key=[how, are, you], value=[page_4, page_5, page_6]
└── Node3: key=[there], value=[page_3]

步骤4:Request 1 和 2 结束

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Request 1 unlock
# 解锁路径:Node2 → Node1
Node2.ref_count = 0
Node1.ref_count = 1 # 还有 Request 3 在用

# Request 2 unlock
# 解锁路径:Node3 → Node1
Node3.ref_count = 0
Node1.ref_count = 0

# Request 3 还在运行
# 锁定路径:Node4 → Node2 → Node1
Node4.ref_count = 1
Node2.ref_count = 1
Node1.ref_count = 1

# 引用计数状态:
Node1.ref_count = 1 # 受保护
Node2.ref_count = 1 # 受保护
Node3.ref_count = 0 # 可驱逐
Node4.ref_count = 1 # 受保护

步骤5:驱逐 2 个 page

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 可驱逐的叶子节点:Node3 (length=1)
# 驱逐 Node3
evicted_size = 1

# 💥 还需要驱逐 1 个 page,但没有其他可驱逐的节点了!
# 驱逐失败!

# 如果 Request 3 也结束了:
# 可驱逐:Node3 (1 page), Node4 (3 pages)
# 驱逐 Node3 (1 page)
# 驱逐 Node4 (3 pages)
# 总共驱逐 4 pages(超过需求的 2 pages)

# 树结构(驱逐后):
root
└── Node1: key=[Hello], value=[page_1]
└── Node2: key=[world], value=[page_2]

13.5 root 节点的特殊性

1
2
3
def __init__(self, device):
self.root_node = RadixTreeNode()
self.root_node.ref_count = 1 # root is always protected

为什么 ref_count = 1?

防止被驱逐

  • root.ref_count = 1,永远不会被驱逐
  • 保证树的根节点始终存在
  1. 作为所有路径的起点
1
2
3
4
root
├── Node1: [Hello]
├── Node2: [How]
└── Node3: [What]
  1. 简化 lock_handle 逻辑
1
2
3
4
while not node.is_root():  # 遍历到 root 就停止
node.ref_count += 1
node = node.parent
# 不需要锁定 root(它永远受保护)

13.6 完全相同的请求

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 树结构:
root
└── Node1: [Hello, world], value=[page_1, page_2], timestamp=1000

# Request 1: "Hello world"
node, prefix_len = _walk([Hello, world])
# node = Node1, prefix_len = 2
# 更新 timestamp = 2000

# Request 2: "Hello world"(完全相同)
node, prefix_len = _walk([Hello, world])
# node = Node1, prefix_len = 2
# 再次更新 timestamp = 3000

# 树结构(更新后):
root
└── Node1: [Hello, world], value=[page_1, page_2], timestamp=3000

为什么要更新 timestamp?

  • 使用 LRU 驱逐策略
  • 最近使用的节点 timestamp 更大
  • 驱逐时优先驱逐 timestamp 最小的节点
  • 更新 timestamp 可以防止常用的前缀被驱逐

费曼挑战(续)

挑战16:Radix Tree 的核心思想

问题:用简单的话解释 Radix Tree 的核心思想。

解答

  • 核心思想:压缩的前缀树,共享相同的前缀
  • 类比:文件系统的目录结构,共享的路径只存储一次
  • 效果:节省内存,提高性能

关键场景

1
2
3
4
5
# "Hello world" 和 "Hello there" 共享 "Hello"
root
└── [Hello] # 共享
├── [world]
└── [there]

挑战17:ref_count 的作用

问题:解释 ref_count 的作用,以及为什么需要它。

解答

  • 作用:引用计数,表示有多少个请求正在使用这个节点
  • ref_count > 0:受保护,不能驱逐
  • ref_count = 0:可驱逐
  • 为什么需要:防止正在使用的缓存被驱逐

挑战18:_walk 和 _split_at

问题:解释 _walk 方法做了什么,以及为什么需要 _split_at。

解答

  • _walk:从根节点开始,沿着树向下走,找到最长匹配的前缀
  • _split_at:“细胞分裂”,将一个节点分裂成两个节点
  • 为什么需要分裂:当只有部分匹配时,需要分裂节点来共享前缀

关键场景

1
2
3
4
5
6
7
8
# 原节点:[Hello, world]
# 输入:[Hello, there]
# 只有 "Hello" 匹配,需要分裂

# 分裂后:
# [Hello] # 共享
# ├── [world]
# └── [there]

挑战19:lock_handle 锁定整条路径

问题:解释为什么 lock_handle 要锁定从当前节点到根节点的整条路径。

解答

  • 原因:使用当前节点需要从根到当前节点的所有信息
  • 如果只锁定当前节点:父节点可能被驱逐,导致数据损坏

关键场景

1
2
3
# 使用 Node3 需要:
# Node1 的 KV + Node2 的 KV + Node3 的 KV
# 所以必须锁定整条路径:Node3 → Node2 → Node1

挑战20:evict 只驱逐叶子节点

问题:解释为什么 evict 只驱逐叶子节点。

解答

  • 原因1:叶子节点没有依赖,驱逐不会破坏树结构
  • 原因2:驱逐叶子节点后,父节点可能也变成叶子,可以继续驱逐
  • 如果驱逐非叶子节点:子节点会失去父节点,树结构被破坏

关键代码

1
2
if parent.is_leaf() and parent.ref_count == 0:
heapq.heappush(leave_nodes, parent)