概述

阶段八学习 Mini-SGLang 的高级特性,主要是自定义 CUDA Kernels。这些 Kernels 通过 C++/CUDA 实现,提供比 PyTorch 更高的性能。

核心问题:为什么需要自定义 CUDA Kernels?

答案

  1. 性能优化:PyTorch 的通用实现有额外开销
  2. 内存效率:直接操作 GPU 内存,避免中间拷贝
  3. 并行优化:针对特定场景优化并行策略
  4. 内存对齐:利用 GPU 缓存行对齐提高带宽

1. Radix 树操作:kernel/radix.py

1.1 核心思想

问题:在 Radix Tree 中,需要频繁比较两个 token 序列是否相等。Python 的比较很慢。

解决方案:用 C++ 实现快速比较,通过 TVM FFI 调用。

1.2 关键代码

1
2
3
4
5
6
7
@lru_cache(maxsize=None)
def _load_radix_module() -> Module:
return load_aot("radix", cpp_files=["radix.cpp"])

def fast_compare_key(x: torch.Tensor, y: torch.Tensor) -> int:
# compare 2 1-D int cpu tensors for equality
return _load_radix_module().fast_compare_key(x, y)

1.3 为什么需要?

Python 实现(慢)

1
2
3
4
5
6
7
8
9
10
11
12
13
def compare_python(x, y):
if len(x) != len(y):
return False
for i in range(len(x)):
if x[i] != y[i]: # 每次都要:
return False # 1. 解释器开销
# 2. 类型检查
# 3. 对象访问
return True

# 假设比较 1000 个 token
# 每次比较:~100ns(Python 解释器开销)
# 总时间:1000 * 100ns = 100μs

C++ 实现(快)

1
2
3
4
5
6
7
8
9
10
11
12
13
int fast_compare_key(Tensor x, Tensor y) {
if (x.size() != y.size()) return -1;

int* x_data = x.data_ptr<int>();
int* y_data = y.data_ptr<int>();
int n = x.size();

// 直接内存比较,可能用 SIMD
return memcmp(x_data, y_data, n * sizeof(int));
}

// 时间:~1μs(直接内存操作)
// 提速:100 倍

1.4 @lru_cache 的作用

1
2
3
4
5
6
7
8
9
10
11
@lru_cache(maxsize=None)
def _load_radix_module() -> Module:
return load_aot("radix", cpp_files=["radix.cpp"])

# 第一次调用
module1 = _load_radix_module() # 编译 + 加载(慢,~100ms)

# 第二次调用
module2 = _load_radix_module() # 直接返回缓存(快,~1μs)

# module1 is module2 == True(同一个对象)

load_aot 的工作流程

1
2
3
4
5
6
7
8
9
10
11
12
13
def load_aot(name, cpp_files):
# 1. 检查是否已编译
binary_path = f"{name}.so"
if os.path.exists(binary_path):
# 检查源文件是否更新
if is_up_to_date(binary_path, cpp_files):
return load_binary(binary_path) # 直接加载

# 2. 编译 C++ 代码
compile_cpp(cpp_files, output=binary_path)

# 3. 加载二进制
return load_binary(binary_path)

关键

  • AOT(Ahead-Of-Time):提前编译
  • 缓存:二进制文件最新就不再编译,直接加载
  • @lru_cache:避免重复加载

1.5 使用场景

RadixCacheManager 中比较 token 序列:

1
2
3
4
5
6
7
8
def _common_prefix_len(self, x: torch.Tensor, y: torch.Tensor) -> int:
min_len = min(len(x), len(y))

# 使用 C++ 实现
for i in range(min_len):
if fast_compare_key(x[i:i+1], y[i:i+1]) != 0:
return i
return min_len

性能提升

1
2
3
4
# 场景:1000 个请求,每个请求平均比较 10 次
# Python:1000 * 10 * 100μs = 1s
# C++:1000 * 10 * 1μs = 10ms
# 提速:100 倍

2. KV Cache 存储:kernel/store.py

2.1 核心思想

问题:在 Attention 计算后,需要将 K 和 V 存储到 KV Cache。PyTorch 的索引操作很慢。

解决方案:用自定义 CUDA Kernel 实现高性能存储。

2.2 关键代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
@lru_cache(maxsize=None)
def _jit_store_module(
element_size: int,
*,
config: KernelConfig = DEFAULT_INDEX_KERNEL_CONFIG,
) -> Module:
args = make_cpp_args(element_size, *config)
return load_jit(
"store",
*args,
cuda_files=["store.cu"],
cuda_wrappers=[("launch", f"StoreKernel<{args}>::run")],
)

def store_cache(
k_cache: torch.Tensor, # KV Cache 的 K 部分
v_cache: torch.Tensor, # KV Cache 的 V 部分
indices: torch.Tensor, # 存储位置索引
k: torch.Tensor, # 新计算的 K
v: torch.Tensor, # 新计算的 V
) -> None:
num_tokens = k_cache.shape[0]
k_cache = k_cache.view(num_tokens, -1) # 展平为 2D
v_cache = v_cache.view(num_tokens, -1)
element_size = k_cache.shape[1] * k_cache.element_size()
module = _jit_store_module(element_size)
module.launch(k_cache, v_cache, indices, k, v)

2.3 JIT vs AOT

AOT(Ahead-Of-Time,提前编译)

1
2
3
4
5
// 需要为每个 element_size 编译一个版本
void store_64KB(...);
void store_80KB(...);
void store_160KB(...);
// 问题:element_size 可能有无数种组合

JIT(Just-In-Time,即时编译)

1
2
3
4
5
6
7
8
9
10
11
12
// 运行时根据 element_size 编译
template<int ElementSize>
void store(...);

// 第一次调用 element_size=64KB
store<65536>(...) // 编译 + 缓存

// 第二次调用 element_size=64KB
store<65536>(...) // 直接使用缓存

// 第一次调用 element_size=80KB
store<81920>(...) // 编译 + 缓存

优势

  • 只编译需要的版本
  • 支持任意 element_size
  • 编译一次,缓存结果(@lru_cache

2.4 element_size 的计算

1
2
3
4
5
6
7
8
9
10
11
# 假设 Llama-7B
# k_cache: [num_pages * page_size, num_layers, num_kv_heads, head_dim]
# = [16000, 32, 8, 128]

k_cache = k_cache.view(num_tokens, -1)
# 展平后:[16000, 32 * 8 * 128] = [16000, 32768]

element_size = k_cache.shape[1] * k_cache.element_size()
# = 32768 * 2 (假设 float16)
# = 65536 字节
# = 64 KB

不同模型的 element_size

模型 num_layers num_kv_heads head_dim dtype element_size
Llama-7B 32 8 128 float16 64 KB
Llama-13B 40 8 128 float16 80 KB
Llama-70B 80 8 128 float16 160 KB

2.5 为什么要 view(num_tokens, -1)

原始形状

1
2
k_cache: [num_pages * page_size, num_layers, num_kv_heads, head_dim]
= [16000, 32, 8, 128]

展平后

1
2
k_cache: [num_tokens, num_layers * num_kv_heads * head_dim]
= [16000, 32768]

CUDA Kernel 实现对比

多维索引(复杂)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
__global__ void store_4d(
float* k_cache, // [num_tokens, num_layers, num_kv_heads, head_dim]
int* indices,
float* k,
int num_layers,
int num_kv_heads,
int head_dim
) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int idx = indices[tid];

// 需要计算 4D 索引
for (int l = 0; l < num_layers; l++) {
for (int h = 0; h < num_kv_heads; h++) {
for (int d = 0; d < head_dim; d++) {
int src_offset = tid * num_layers * num_kv_heads * head_dim
+ l * num_kv_heads * head_dim
+ h * head_dim
+ d;
int dst_offset = idx * num_layers * num_kv_heads * head_dim
+ l * num_kv_heads * head_dim
+ h * head_dim
+ d;
k_cache[dst_offset] = k[src_offset];
}
}
}
}

2D 索引(简单)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
__global__ void store_2d(
char* k_cache, // [num_tokens, element_size]
int* indices,
char* k,
int element_size
) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int idx = indices[tid];

// 直接内存拷贝
memcpy(
k_cache + idx * element_size,
k + tid * element_size,
element_size
);
}

优势

  1. 代码简单:不需要计算多维索引
  2. 性能更好:直接内存拷贝,利用 memcpy 优化
  3. 通用性强:不需要知道具体的维度

2.6 完整的数据流

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# 1. Attention 计算
k, v = attention.forward(x)
# k: [batch_size, num_layers, num_kv_heads, head_dim]
# v: [batch_size, num_layers, num_kv_heads, head_dim]

# 2. 获取存储位置
indices = table_manager.get_indices(batch)
# indices: [batch_size] # 每个 token 存储到哪个位置

# 3. 存储到 KV Cache
store_cache(k_cache, v_cache, indices, k, v)

# 内部流程
# 3.1 展平
k_cache = k_cache.view(num_tokens, -1) # [16000, 32768]
v_cache = v_cache.view(num_tokens, -1)

# 3.2 计算 element_size
element_size = 32768 * 2 = 65536 字节

# 3.3 JIT 编译(第一次)或使用缓存(后续)
module = _jit_store_module(65536)

# 3.4 启动 CUDA Kernel
module.launch(k_cache, v_cache, indices, k, v)
# GPU 并行拷贝:
# k_cache[indices[0]] = k[0]
# k_cache[indices[1]] = k[1]
# ...

2.7 性能对比

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 假设
# - batch_size = 100
# - element_size = 64KB

# PyTorch 实现
k_cache[indices] = k
v_cache[indices] = v
# 时间:100 * 10μs = 1ms(索引开销)

# CUDA Kernel 实现
store_cache(k_cache, v_cache, indices, k, v)
# 时间:100μs(并行拷贝)

# 提速:10 倍

3. Embedding 查找:kernel/index.py

3.1 核心思想

问题:在 LLM 中,需要根据 token ID 查找对应的 Embedding。PyTorch 的索引操作在大 Embedding 表上很慢。

解决方案:用自定义 CUDA Kernel 实现高性能索引。

3.2 关键代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
@lru_cache(maxsize=None)
def _jit_index_module(
element_size: int,
*,
num_splits: int = 1,
config: KernelConfig = DEFAULT_INDEX_KERNEL_CONFIG,
) -> Module:
args = make_cpp_args(element_size, num_splits, *config)
return load_jit(
"index",
*args,
cuda_files=["index.cu"],
cuda_wrappers=[("launch", f"IndexKernel<{args}>::run")],
)

def indexing(
weights: torch.Tensor, # Embedding 表
indices: torch.Tensor, # token ID
*,
output: torch.Tensor | None = None,
vocab_range: Tuple[int, int] | None = None, # (start, length)
) -> torch.Tensor:
if output is None:
output = weights.new_empty(indices.shape[0], weights.shape[1])

element_size = weights.shape[1] * weights.element_size()

# 根据 element_size 选择 num_splits
if element_size % 2048 == 0:
num_splits = 4
elif element_size % 1024 == 0:
num_splits = 2
else:
num_splits = 1

module = _jit_index_module(element_size, num_splits=num_splits)
module.launch(weights, indices, output, vocab_range)
return output

3.3 num_splits 的作用

问题:Embedding 向量很大(例如 4096 维),单个线程处理太慢。

解决方案:将每个向量分成多份,多个线程并行处理。

1
2
3
4
5
6
7
// 假设 hidden_size = 4096,num_splits = 4
// 每个 split 处理 1024 个元素

// 线程 0: 拷贝 weights[idx][0:1024] → output[i][0:1024]
// 线程 1: 拷贝 weights[idx][1024:2048] → output[i][1024:2048]
// 线程 2: 拷贝 weights[idx][2048:3072] → output[i][2048:3072]
// 线程 3: 拷贝 weights[idx][3072:4096] → output[i][3072:4096]

性能影响

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 假设 hidden_size = 4096,dtype = float16
element_size = 4096 * 2 = 8192 字节

# num_splits = 1(单线程)
# 每个线程拷贝 8192 字节
# 时间:~10μs

# num_splits = 2(2 个线程)
# 每个线程拷贝 4096 字节
# 时间:~5μs(并行)

# num_splits = 4(4 个线程)
# 每个线程拷贝 2048 字节
# 时间:~2.5μs(并行)

为什么不是越多越好?

1
2
3
4
5
6
7
# num_splits = 8(8 个线程)
# 每个线程拷贝 1024 字节
# 时间:~3μs(线程调度开销)

# num_splits = 16)
# 每个线程拷贝 512 字节
# 时间:~5μs(线程调度开销 > 并行收益)

最优选择

  • 2048 字节/线程:平衡并行度和开销
  • 内存对齐:2048 字节是 GPU 缓存行的倍数

3.4 内存对齐的重要性

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// 假设 GPU 缓存行大小 = 128 字节

// 未对齐(element_size = 8000 字节)
// 线程 0: 读取 [0, 2000) → 跨越 16 个缓存行
// 线程 1: 读取 [2000, 4000) → 跨越 16 个缓存行
// 线程 2: 读取 [4000, 6000) → 跨越 16 个缓存行
// 线程 3: 读取 [6000, 8000) → 跨越 16 个缓存行
// 问题:缓存行未对齐,效率低

// 对齐(element_size = 8192 字节)
// 线程 0: 读取 [0, 2048) → 正好 16 个缓存行
// 线程 1: 读取 [2048, 4096) → 正好 16 个缓存行
// 线程 2: 读取 [4096, 6144) → 正好 16 个缓存行
// 线程 3: 读取 [6144, 8192) → 正好 16 个缓存行
// 优势:缓存行对齐,效率高

好的 element_size(对齐 2048)

1
8192, 10240, 12288, 14336, 16384, ...

不好的 element_size(未对齐)

1
8000, 10000, 12000, 14000, 16000, ...

3.5 vocab_range 在 TP 中的使用

用途:只查找 Embedding 表的一部分。

场景:Tensor Parallelism(TP)

1
2
3
4
5
6
7
8
9
10
11
12
13
# 假设 vocab_size = 32000,tp_size = 4
# 每个 GPU 负责一部分词表

# GPU 0: vocab_range = (0, 8000)
# GPU 1: vocab_range = (8000, 8000)
# GPU 2: vocab_range = (16000, 8000)
# GPU 3: vocab_range = (24000, 8000)

# 查找 token_id = 10000
# GPU 0: 跳过(10000 不在 [0, 8000))
# GPU 1: 查找 weights[10000 - 8000] = weights[2000]
# GPU 2: 跳过
# GPU 3: 跳过

完整流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# 假设 vocab_size = 32000,tp_size = 4,batch_size = 10
# token_ids = [100, 8500, 16500, 24500, 1000, 9000, 17000, 25000, 5000, 20000]

# GPU 0: vocab_range = (0, 8000)
output_0 = indexing(
weights_0, # [8000, 4096]
token_ids, # [10]
vocab_range=(0, 8000)
)
# 结果:
# output_0[0] = weights_0[100] # 100 在 [0, 8000)
# output_0[1] = 0 # 8500 不在 [0, 8000)
# output_0[2] = 0 # 16500 不在 [0, 8000)
# ...
# output_0[4] = weights_0[1000] # 1000 在 [0, 8000)
# output_0[8] = weights_0[5000] # 5000 在 [0, 8000)

# GPU 1: vocab_range = (8000, 8000)
output_1 = indexing(
weights_1, # [8000, 4096]
token_ids,
vocab_range=(8000, 8000)
)
# 结果:
# output_1[0] = 0
# output_1[1] = weights_1[8500 - 8000] = weights_1[500]
# output_1[5] = weights_1[9000 - 8000] = weights_1[1000]

# GPU 2: vocab_range = (16000, 8000)
output_2 = indexing(
weights_2, # [8000, 4096]
token_ids,
vocab_range=(16000, 8000)
)
# 结果:
# output_2[2] = weights_2[16500 - 16000] = weights_2[500]
# output_2[6] = weights_2[17000 - 16000] = weights_2[1000]
# output_2[9] = weights_2[20000 - 16000] = weights_2[4000]

# GPU 3: vocab_range = (24000, 8000)
output_3 = indexing(
we # [8000, 4096]
token_ids,
vocab_range=(24000, 8000)
)
# 结果:
# output_3[3] = weights_3[24500 - 24000] = weights_3[500]
# output_3[7] = weights_3[25000 - 24000] = weights_3[1000]

# 最后 all_reduce 求和
output = output_0 + output_1 + output_2 + output_3
# 每个位置只有一个 GPU 有非零值,其他都是 0
# 所以求和后得到正确的结果

关键点

  • 每个 GPU 只处理自己范围内的 token
  • 其他 token 输出 0
  • 最后 all_reduce 求和得到完整结果

3.6 完整的 TP Embedding 层

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class TPEmbedding(nn.Module):
def __init__(self, vocab_size, hidden_size, tp_size, tp_rank):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.tp_size = tp_size
self.tp_rank = tp_rank

# 每个 GPU 只存储一部分词表
self.local_vocab_size = vocab_size // tp_size
self.vocab_start = tp_rank * self.local_vocab_size
self.vocab_end = self.vocab_start + self.local_vocab_size

# 本地权重
self.weight = torch.empty(self.local_vocab_size, hidden_size)

def forward(self, input_ids):
# 1. 本地查找
output = indexing(
self.weight,
input_ids,
vocab_range=(self.vocab_start, self.local_vocab_size)
)

# 2. All-Reduce 求和
if self.tp_size > 1:
output = all_reduce(output)

return output

3.7 性能对比

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 假设
# - vocab_size = 32000
# - hidden_size = 4096
# - batch_size = 100
# - dtype = float16

# PyTorch 实现
output = weights[indices]
# 时间:~100μs(索引开销)

# CUDA Kernel 实现
output = indexing(weights, indices)
# 时间:~10μs(并行拷贝)

# 提速:10 倍

4. NCCL Python 绑定:kernel/pynccl.py

4.1 核心思想

问题:PyTorch 的分布式通信(torch.distributed)有额外开销。

解决方案:直接调用 NCCL(NVIDIA Collective Communications Library),绕过 PyTorch 的封装。

4.2 关键代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
@lru_cache(maxsize=None)
def _load_nccl_module() -> Module:
return load_aot("pynccl", cuda_files=["pynccl.cu"], extra_ldflags=["-lnccl"])

@lru_cache(maxsize=None)
def _get_pynccl_wrapper_cls():
import tvm_ffi

@tvm_ffi.register_object("minisgl.NCCLWrapper")
class PyNCCLImpl(tvm_ffi.Object):
def __init__(self, *args):
self.__ffi_init__(*args)

return PyNCCLImpl

def init_pynccl(
*,
tp_rank: int,
tp_size: int,
tp_cpu_group: torch.distributed.ProcessGroup,
max_size_bytes: int = 0,
) -> PyNCCLCommunicator:
import torch

max_size_bytes = min(max_size_bytes, ENV.PYNCCL_MAX_BUFFER_SIZE.value)

module = _load_nccl_module()
cls = _get_pynccl_wrapper_cls()

if tp_rank == 0:
id_list = [module.create_nccl_uid()]
torch.distributed.broadcast_object_list(
id_list,
src=0,
group=tp_cpu_group,
)
else:
id_list = [None]
torch.distributed.broadcast_object_list(
id_list,
src=0,
group=tp_cpu_group,
)

nccl_id = id_list[0]
assert not nccl_id is None, f"Failed to get NCCL unique ID on {tp_rank = }"

# bypass type checking for the FFI object
return cls(tp_rank, tp_size, max_size_bytes, nccl_id) # type: ignore

4.3 NCCL 初始化流程

步骤 1:Rank 0 创建 NCCL Unique ID

1
2
3
if tp_rank == 0:
id_list = [module.create_nccl_uid()]
# 创建一个全局唯一的 ID,用于所有进程建立通信

步骤 2:广播 NCCL ID 到所有进程

1
2
3
4
5
6
7
torch.distributed.broadcast_object_list(
id_list,
src=0,
group=tp_cpu_group,
)
# Rank 0 广播 ID 给所有其他进程
# 其他进程接收这个 ID

步骤 3:所有进程使用相同的 ID 初始化 NCCL

1
2
3
4
nccl_id = id_list[0]
return cls(tp_rank, tp_size, max_size_bytes, nccl_id)
# 每个进程用相同的 nccl_id 初始化
# NCCL 内部会建立进程间的通信通道

为什么需要 NCCL Unique ID?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 假设 4 个 GPU,4 个进程

# Rank 0 创建 ID
nccl_id = "abc123"

# 广播给所有进程
# Rank 0: nccl_id = "abc123"
# Rank 1: nccl_id = "abc123"
# Rank 2: nccl_id = "abc123"
# Rank 3: nccl_id = "abc123"

# 所有进程用相同的 ID 初始化
# NCCL 内部会:
# 1. 建立 Rank 0 ↔ Rank 1 的通信通道
# 2. 建立 Rank 0 ↔ Rank 2 的通信通道
# 3. 建立 Rank 0 ↔ Rank 3 的通信通道
# 4. 建立 Rank 1 ↔ Rank 2 的通信通道
# 5. 建立 Rank 1 ↔ Rank 3 的通信通道
# 6. 建立 Rank 2 ↔ Rank 3 的通信通道

4.4 PyNCCLCommunicator 接口

1
2
3
4
5
6
7
8
9
10
11
12
class PyNCCLCommunicator:
def all_reduce(self, input: torch.Tensor, op: Literal["sum"]) -> None:
# 所有进程求和
pass

def all_gather(self, output: torch.Tensor, input: torch.Tensor) -> None:
# 收集所有进程的数据
pass

def get_buffer(self) -> int:
# 获取通信缓冲区地址
pass

使用场景

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 在 distributed/impl.py 中使用
class PyNCCLDistributedImpl(DistributedImpl):
comm: PyNCCLCommunicator

def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
self.comm.all_reduce(x, "sum")
return x

def all_gather(self, x: torch.Tensor) -> torch.Tensor:
world_size = get_tp_info().size
output_shape = list(x.shape)
output_shape[0] *= world_size
result = x.new_empty(output_shape)
self.comm.all_gather(result, x)
return result

4.5 PyTorch vs NCCL 性能对比

PyTorch 实现

1
2
3
4
5
6
7
8
# torch.distributed.all_reduce
import torch.distributed as dist

def all_reduce_pytorch(x):
dist.all_reduce(x, op=dist.ReduceOp.SUM)
return x

# 时间:~100μs(包含 PyTorch 封装开销)

NCCL 实现

1
2
3
4
5
6
# 直接调用 NCCL
def all_reduce_nccl(x):
comm.all_reduce(x, "sum")
return x

# 时间:~50μs(绕过 PyTorch 封装)

提速:2 倍

为什么 NCCL 更快?

  1. 无 PyTorch 封装开销:直接调用 NCCL C++ API
  2. 无 Python GIL:C++ 实现,不受 GIL 限制
  3. 优化的通信算法:NCCL 针对 GPU 优化

4.6 max_size_bytes 的作用

1
max_size_bytes = min(max_size_bytes, ENV.PYNCCL_MAX_BUFFER_SIZE.value)

用途:限制 NCCL 通信缓冲区的大小。

为什么需要?

  • NCCL 需要预分配通信缓冲区
  • 缓冲区太大会浪费 GPU 内存
  • 缓冲区太小会导致通信失败

默认值

1
2
# 假设 ENV.PYNCCL_MAX_BUFFER_SIZE = 1GB
max_size_bytes = 1 * 1024 * 1024 * 1024 # 1GB

费曼挑战

挑战 1:为什么需要自定义 CUDA Kernels?

问题:解释为什么需要自定义 CUDA Kernels,而不是直接用 PyTorch。

解答

  • 性能优化:PyTorch 的通用实现有额外开销(索引、类型检查)
  • 内存效率:直接操作 GPU 内存,避免中间拷贝
  • 并行优化:针对特定场景优化并行策略(num_splits)
  • 内存对齐:利用 GPU 缓存行对齐提高带宽

例子

1
2
3
4
5
6
7
# PyTorch:k_cache[indices] = k
# 时间:1ms(索引开销)

# CUDA Kernel:store_cache(k_cache, v_cache, indices, k, v)
# 时间:100μs(并行拷贝)

# 提速:10 倍

挑战 2:JIT vs AOT

问题:解释 JIT 和 AOT 的区别,以及为什么 store.pyindex.py 用 JIT。

解答

  • AOT(Ahead-Of-Time):提前编译,需要为每个参数组合编译一个版本
  • JIT(Just-In-Time):运行时编译,根据参数动态编译
  • 为什么用 JIT:不同模型的 element_size 不同,JIT 可以支持任意组合

对比

1
2
3
4
5
6
7
8
9
10
# AOT:需要编译无数个版本
store_64KB(...)
store_80KB(...)
store_160KB(...)
# 问题:element_size 可能有无数种

# JIT:运行时编译
store<ElementSize>(...)
# 第一次:编译 + 缓存
# 后续:直接使用缓存

挑战 3:element_size 的计算

问题:解释 element_size 是什么,以及如何计算。

解答

  • 定义:单个 token 的字节数
  • 计算element_size = 向量维度 * 数据类型字节数

例子

1
2
3
4
5
6
7
8
9
10
11
12
# Llama-7B
# k_cache: [num_tokens, num_layers, num_kv_heads, head_dim]
# = [16000, 32, 8, 128]

# 展平后
k_cache = k_cache.view(num_tokens, -1)
# [16000, 32 * 8 * 128] = [16000, 32768]

# element_size
element_size = 32768 * 2 # float16 = 2 字节
= 65536 字节
= 64 KB

挑战 4:num_splits 的选择

问题:解释为什么根据 element_size 选择不同的 num_splits

解答

  • 目标:每个线程处理 2048 字节(最优)
  • 原因
    • 2048 字节对齐 GPU 缓存行
    • 平衡并行度和线程调度开销

计算

1
2
3
4
5
6
7
8
if element_size % 2048 == 0:
num_splits = element_size // 2048
# 例如:8192 // 2048 = 4
elif element_size % 1024 == 0:
num_splits = element_size // 1024
# 例如:4096 // 1024 = 4(但代码中是 2)
else:
num_splits = 1

挑战 5:vocab_range 在 TP 中的作用

问题:解释 vocab_range 如何在 Tensor Parallelism 中使用。

解答

  • 词表切分:每个 GPU 负责一部分词表
  • 本地查找:只查找自己范围内的 token,其他输出 0
  • All-Reduce:求和得到完整结果

流程

1
2
3
4
5
6
7
8
# GPU 0: vocab_range = (0, 8000)
# output_0 = [weights[100], 0, 0, 0]

# GPU 1: vocab_range = (8000, 8000)
# output_1 = [0, weights[500], 0, 0]

# All-Reduce
# output = output_0 + output_1 + output_2 + output_3

挑战 6:NCCL 初始化流程

问题:解释 NCCL 的初始化流程,为什么需要 NCCL Unique ID。

解答

  • 步骤 1:Rank 0 创建 NCCL Unique ID
  • 步骤 2:广播 ID 到所有进程
  • 步骤 3:所有进程用相同的 ID 初始化 NCCL
  • 为什么需要:所有进程需要用相同的 ID 才能建立通信通道

流程

1
2
3
4
5
6
7
8
9
# Rank 0
nccl_id = create_nccl_uid() # "abc123"
broadcast(nccl_id)

# Rank 1, 2, 3
nccl_id = receive() # "abc123"

# 所有进程
init_nccl(nccl_id) # 建立通信通道

挑战 7:PyTorch vs NCCL 性能

问题:解释为什么直接调用 NCCL 比 PyTorch 的 torch.distributed 更快。

解答

  • 无 PyTorch 封装开销:直接调用 NCCL C++ API
  • 无 Python GIL:C++ 实现,不受 GIL 限制
  • 优化的通信算法:NCCL 针对 GPU 优化

性能对比

1
2
3
4
5
6
7
8
9
# PyTorch
torch.distributed.all_reduce(x)
# 时间:~100μs

# NCCL
comm.all_reduce(x, "sum")
# 时间:~50μs

# 提速:2 倍

挑战 8:内存对齐的重要性

问题:解释为什么内存对齐可以提高性能。

解答

  • GPU 缓存行:GPU 以缓存行为单位读取内存(例如 128 字节)
  • 对齐访问:如果数据对齐缓存行,可以减少缓存行读取次数
  • 未对齐访问:如果数据跨越缓存行,需要读取更多缓存行

例子

1
2
3
4
5
6
7
8
9
// 缓存行大小 = 128 字节

// 未对齐(2000 字节)
// 需要读取:2000 / 128 = 16 个缓存行(向上取整)

// 对齐(2048 字节)
// 需要读取:2048 / 128 = 16 个缓存行(正好)

// 但是未对齐可能跨越缓存行边界,导致额外读取

挑战 9:view(num_tokens, -1) 的作用

问题:解释为什么要将多维 Tensor 展平为 2D。

解答

  • 简化 CUDA Kernel:不需要计算多维索引
  • 性能更好:直接内存拷贝,利用 memcpy 优化
  • 通用性强:不需要知道具体的维度

对比

1
2
3
4
5
6
7
8
9
10
11
// 4D 索引(复杂)
for (int l = 0; l < num_layers; l++) {
for (int h = 0; h < num_kv_heads; h++) {
for (int d = 0; d < head_dim; d++) {
// 计算 4D 索引
}
}
}

// 2D 索引(简单)
memcpy(dst, src, element_size);

挑战 10:完整的性能优化策略

问题:总结 Mini-SGLang 的性能优化策略。

解答

1. 自定义 CUDA Kernels

  • fast_compare_key:C++ 实现,提速 100 倍
  • store_cache:并行拷贝,提速 10 倍
  • indexing:并行查找,提速 10 倍

2. JIT 编译

  • 根据 element_size 动态编译
  • 支持任意模型配置
  • 编译一次,缓存结果

3. 并行优化

  • num_splits:将大向量分成多份并行处理
  • 平衡并行度和线程调度开销

4. 内存对齐

  • 2048 字节对齐 GPU 缓存行
  • 提高内存带宽利用率

5. 直接调用 NCCL

  • 绕过 PyTorch 封装
  • 提速 2 倍

6. Tensor Parallelism

  • 词表切分:vocab_range
  • 本地查找 + All-Reduce

挑战 11:NCCL Unique ID 的作用

问题:解释 NCCL Unique ID 的作用,为什么需要它。

解答

  • 作用:作为"会合点"(Rendezvous Point),让所有进程建立通信通道
  • 类比:就像一个"房间号",所有用相同房间号的进程进入同一个房间
  • 为什么需要:所有进程必须用相同的 ID 才能建立通信

流程

1
2
3
4
5
6
7
8
# Rank 0 创建 ID
nccl_id = "abc123"

# 广播给所有进程
# Rank 0, 1, 2, 3: nccl_id = "abc123"

# 所有进程用相同 ID 初始化
# NCCL 内部建立进程间的通信通道

挑战 12:为什么是 Rank 0 创建 ID

问题:解释为什么是 Rank 0 创建 NCCL ID 并广播。

解答

  • 需要协调者:所有进程必须用相同的 ID
  • 约定俗成:Rank 0 通常是"主进程"
  • 可以是任何 Rank:但需要所有进程都知道是谁

错误做法

1
2
3
4
5
# 每个进程创建自己的 ID
# Rank 0: "abc123"
# Rank 1: "def456"
# Rank 2: "ghi789"
# 问题:ID 不同,无法建立通信

正确做法

1
2
3
# Rank 0 创建并广播
# 所有进程:nccl_id = "abc123"
# 结果:可以建立通信

挑战 13:max_size_bytes 的作用

问题:解释 max_size_bytes 的作用。

解答

  • 作用:限制 NCCL 通信缓冲区的大小
  • 为什么需要
    • NCCL 需要预分配 GPU 内存作为通信缓冲区
    • 缓冲区太大会浪费内存,可能导致 OOM
    • 缓冲区太小会导致通信失败

限制

1
2
max_size_bytes = min(max_size_bytes, ENV.PYNCCL_MAX_BUFFER_SIZE.value)
# 例如:min(10GB, 1GB) = 1GB

分块通信

1
2
3
4
# 如果 Tensor 大小 > max_size_bytes
# 需要分块通信
# 例如:2GB Tensor,1GB 缓冲区
# 分两次:all_reduce(tensor[0:1GB]) + all_reduce(tensor[1GB:2GB])

挑战 14:tp_cpu_group 的作用

问题:解释 tp_cpu_group 的作用。

解答

  • 作用:用于 CPU 端的通信(广播 NCCL ID)
  • 为什么需要
    • NCCL 只能在 GPU 上通信
    • NCCL ID 是 Python 对象(字符串),需要在 CPU 上广播
    • NCCL 还没初始化,无法使用 NCCL 广播

流程

1
2
3
4
5
6
7
# 1. CPU 通信:广播 NCCL ID
torch.distributed.broadcast_object_list(
id_list, src=0, group=tp_cpu_group
)

# 2. GPU 通信:用 ID 初始化 NCCL
comm = PyNCCLImpl(tp_rank, tp_size, max_size_bytes, nccl_id)

挑战 15:PyTorch vs NCCL 性能差异

问题:解释为什么直接调用 NCCL 比 PyTorch 更快。

解答

  • 跳过参数检查:不检查类型、设备、形状
  • 跳过内部调度:不需要选择 backend
  • FFI 开销更小:TVM FFI 比 PyTorch 封装更轻量

性能对比

步骤 PyTorch NCCL 直接调用
Python 函数调用
参数检查
内部调度
调用 NCCL
总时间 100μs 50μs

提速:2 倍


5. 性能测试:benchmark/perf.py

5.1 核心思想

问题:如何准确测量 GPU 代码的性能?

解决方案

  1. 使用 CUDA Event 精确计时
  2. Warmup 避免冷启动
  3. 支持 CUDA Graph 测试
  4. 计算内存带宽

5.2 关键代码

5.2.1 perf_cuda() - CUDA 性能测试

1
2
3
4
5
6
7
def perf_cuda(
f: Callable[[], Any], # 要测试的函数
*,
init_stream: bool = True, # 是否创建新 Stream
repetitions: int = 10, # 重复次数
cuda_graph_repetitions: int | None = 10, # CUDA Graph 重复次数
) -> float:

测试流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# 1. 创建 CUDA Event
tic = torch.cuda.Event(enable_timing=True)
toc = torch.cuda.Event(enable_timing=True)

# 2. Warmup
f() # 预热

# 3. CUDA Graph Capture(可选)
if cuda_graph_repetitions:
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
for _ in range(N):
f()
replay = g.replay
else:
replay = f

# 4. 测量时间
replay() # Warmup
tic.record()
for _ in range(repetitions):
replay()
toc.record()
toc.synchronize()

# 5. 计算平均时间
dur = tic.elapsed_time(toc)
return dur / (N * repetitions)

5.2.2 compare_memory_kernel_perf() - 对比性能

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def compare_memory_kernel_perf(
*,
baseline: Callable[[], Any], # 基线实现
our_impl: Callable[[], Any], # 我们的实现
memory_footprint: int, # 内存占用(字节)
description: str = " ",
) -> Tuple[float, float]:
# 1. 测试基线
rf_cuda(baseline)
bandwidth_0 = memory_footprint / (dur * 1e6) # GB/s

# 2. 测试我们的实现
dur = perf_cuda(our_impl)
bandwidth_1 = memory_footprint / (dur * 1e6) # GB/s

# 3. 输出对比
logger.info(f"{description}Baseline: {dur} ms | {bandwidth_0} GB/s | Our Impl: {dur} ms | {bandwidth_1} GB/s")

return bandwidth_0, bandwidth_1

5.3 为什么用 CUDA Event?

Python time.time() 不准确

1
2
3
4
5
6
7
import time

# 错误做法
start = time.time()
gpu_kernel() # GPU 异步执行
end = time.time()
# 只测量了 CPU 提交任务的时间,不是 GPU 执行时间

CUDA Event 准确

1
2
3
4
5
6
7
# 正确做法
tic.record() # 在 GPU 命令流中
gpu_kernel()
toc.record() # 在 GPU 命令流中插入"结束标记"
toc.synchronize() # 等待 GPU 执行完成

dur = tic.elapsed_time(toc) # GPU 实际执行时间

5.4 为什么需要 Warmup?

原因 1:JIT 编译

1
2
f()  # 第一次:触发 JIT 编译(慢,~100ms)
f() # 第二次:使用缓存(快,~1ms)

原因 2:GPU 缓存预热

1
2
f()  # 第一次:数据加载到缓存(慢)
f() # 第二次:数据已在缓存(快)

原因 3:GPU 频率调整

1
2
f()  # 第一次:GPU 可能处于低频
f() # 第二次:GPU 升频到高频

原因 4:CUDA Context 初始化

1
2
f()  # 第一次:初始化 Context(慢,~10ms)
f() # 第二次:Context 已初始化(快)

5.5 CUDA Graph 的性能提升

不用 CUDA Graph

1
2
3
4
5
for _ in range(100):
kernel_1() GPU 执行
kernel_2() # CPU 提交 → GPU 执行
kernel_3() # CPU 提交 → GPU 执行
# CPU 开销:300 次提交

用 CUDA Graph

1
2
3
4
5
6
7
8
9
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
kernel_1()
kernel_2()
kernel_3()

for _ in range(100):
g.replay() # CPU 提交一次,GPU 重放 3 个 kernel
# CPU 开销:100 次提交

性能提升

  • 减少 CPU → GPU 提交开销
  • 减少 GPU 等待 CPU 时间
  • 提升吞吐量

5.6 带宽计算

公式

1
bandwidth = memory_footprint / (dur * 1e6)  # GB/s

单位转换

1
2
3
4
5
6
7
# memory_footprint: 字节
# dur: 毫秒
# 1e6: 转换因子

bandwidth = memory_footprint / (dur * 1e-3) # 字节/秒
= memory_footprint / (dur * 1e-3) / 1e9 # GB/秒
= memory_footprint / (dur * 1e6) # GB/秒

为什么重要?

1
2
3
4
5
6
7
8
9
10
# GPU 理论带宽(A100)
theoretical_bandwidth = 1555 GB/s

# 实际带宽
actual_bandwidth = 100 GB/s

# 带宽利用率
utilization = 100 / 1555 = 6.4%

# 结论:只利用了 6.4% 的带宽,还有优化空间

优化目标

  • 提高带宽利用率
  • 接近理论带宽(通常能达到 80-90%)

费曼挑战(续)

挑战 16:为什么用 CUDA Event

问题:解释为什么用 CUDA Event 而不是 time.time()

解答

  • GPU 异步执行:CPU 只是提交任务,不等待 GPU 完成
  • time.time() 不准确:只测量了 CPU 提交时间
  • CUDA Event 准确:在 GPU 命令流中插入标记,测量 GPU 实际执行时间

对比

1
2
# time.time():测量 CPU 提交时间(1μs)
# CUDA Event:测量 GPU 执行时间(100μs)

挑战 17:Warmup 的作用

问题:解释为什么需要 Warmup。

解答

  • JIT 编译:第一次调用触发编译
  • GPU 缓存预热:第一次调用加载数据到缓存
  • GPU 频率调整:第一次调用可能处于低频
  • CUDA Context 初始化:第一次 CUDA 操作初始化 Context

结论:第一次调用通常比后续慢,需要 Warmup 避免影响测量


挑战 18:CUDA Graph 的性能提升

问题:解释 CUDA Graph 如何提升性能。

解答

  • 减少 CPU 开销:不用 Graph 需要 300 次提交,用 Graph 只需 100 次
  • 减少 GPU 等待:GPU 不需要等待 CPU 提交下一个 kernel
  • 提升吞吐量:更多时间用于计算,更少时间用于调度

对比

1
2
3
# 不用 Graph:100 * 3 = 300 次提交
# 用 Graph:100 次提交
# 减少:66% 的 CPU 开销

挑战 19:带宽计算公式

问题:解释带宽的计算公式和意义。

解答

  • 公式bandwidth = memory_footprint / (dur * 1e6) (GB/s)
  • 意义:每秒传输的数据量
  • 重要性:衡量带宽利用率,利用率越高性能越好

例子

1
2
3
4
5
6
# memory_footprint = 1GB
# dur = 10ms
# bandwidth = 1e9 / (10 * 1e6) = 100 GB/s

# 理论带宽 = 1555 GB/s (A100)
# 利用率 = 100 / 1555 = 6.4%

挑战 20:完整的性能测试流程

问题:描述完整的性能测试流程。

解答

步骤 1:准备数据和函数

1
2
3
4
5
6
7
8
9
10
# 基线实现
def baseline():
return weights[indices]

# 我们的实现
def our_impl():
return indexing(weights, indices)

# 内存占用
memory_footprint = 2 * batch_size * hidden_size * 2 # 字节

步骤 2:性能测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 1. Warmup
f()

# 2. CUDA Graph Capture(可选)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
f()

# 3. 测量时间
tic.record()
for _ in range(repetitions):
g.replay()
toc.record()
toc.synchronize()

# 4. 计算带宽
dur = tic.elapsed_time(toc) / repetitions
bandwidth = memory_footprint / (dur * 1e6)

步骤 3:对比结果

1
2
3
4
# 输出:
# Baseline: 0.100 ms | 16.384 GB/s
# Our Impl: 0.010 ms | 163.840 GB/s
# 提速:10 倍