学习文件 :distributed/info.py, distributed/impl.py, layers/linear.py, message/backend.py
1. 为什么需要分布式推理?
挑战1:模型太大
大模型参数量巨大(例如 Llama-70B 有 700 亿参数)
单个 GPU 显存不够(例如 A100 只有 80GB)
需要多 GPU 并行
挑战2:推理速度慢
解决方案:Tensor Parallelism(TP)
核心思想 :
将模型的张量(权重、激活)切分到多个 GPU
每个 GPU 计算一部分
通过通信合并结果
类比 :
就像一个大任务分给多个工人
每个工人负责一部分
最后汇总结果
2. DistributedInfo:分布式信息
2.1 核心概念
1 2 3 4 5 6 7 8 9 10 @dataclass(frozen=True ) class DistributedInfo : rank: int size: int def __post_init__ (self ): assert 0 <= self .rank < self .size def is_primary (self ) -> bool : return self .rank == 0
关键字段 :
rank :当前进程的编号
例如:4 个 GPU,rank = 0, 1, 2, 3
rank 0 通常是主进程(primary)
size :总进程数
is_primary() :判断是否是主进程
2.2 全局变量:_TP_INFO
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 _TP_INFO: DistributedInfo | None = None def set_tp_info (rank: int , size: int ) -> None : global _TP_INFO if _TP_INFO is not None : raise RuntimeError("TP info has been set" ) _TP_INFO = DistributedInfo(rank, size) def get_tp_info () -> DistributedInfo: if _TP_INFO is None : raise RuntimeError("TP info has not been set" ) return _TP_INFO def try_get_tp_info () -> DistributedInfo | None : return _TP_INFO
为什么需要全局变量?
共享信息 :
分布式信息在整个程序中都需要使用
避免到处传递参数
全局唯一 :
三个函数的区别 :
函数
作用
未设置时
set_tp_info
设置 TP 信息(只能调用一次)
-
get_tp_info
获取 TP 信息
抛出异常
try_get_tp_info
尝试获取 TP 信息
返回 None
2.3 完整的启动流程
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 import osfrom minisgl.distributed import set_tp_info, get_tp_infodef main (): rank = int (os.environ.get("RANK" , 0 )) world_size = int (os.environ.get("WORLD_SIZE" , 1 )) set_tp_info(rank=rank, size=world_size) tp_info = get_tp_info() if tp_info.is_primary(): print (f"Starting with {tp_info.size} GPUs" ) print (f"Process {tp_info.rank} is ready" ) model = Model() output = model.forward(input ) if __name__ == "__main__" : main()
启动 :
1 2 3 4 5 6 torchrun --nproc_per_node=4 main.py
输出 :
1 2 3 4 5 Starting with 4 GPUs Process 0 is ready Process 1 is ready Process 2 is ready Process 3 is ready
3. DistributedImpl:分布式通信
3.1 两个核心操作
1 2 3 4 5 6 7 @dataclass class DistributedImpl (ABC ): @abstractmethod def all_reduce (self, x: torch.Tensor ) -> torch.Tensor: ... @abstractmethod def all_gather (self, x: torch.Tensor ) -> torch.Tensor: ...
3.1.1 all_reduce:所有进程求和
用途 :
应用场景1:梯度同步
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 all_reduce(grad) optimizer.step()
应用场景2:求平均值
1 2 3 4 5 6 7 8 9 10 11 12 all_reduce(loss) loss = loss / 4
3.1.2 all_gather:收集所有进程的数据
用途 :
收集分布式计算的结果
例如:每个 GPU 计算一部分输出,最后合并
应用场景1:收集分布式计算结果
1 2 3 4 5 6 7 8 9 output = all_gather(output)
应用场景2:Tensor Parallelism 的输出合并
1 2 3 4 5 6 7 8 9 output = all_gather(output)
3.2 all_reduce vs all_gather
操作
作用
输出大小
用途
all_reduce
所有进程求和
不变
梯度同步、求平均值
all_gather
收集所有数据
扩大 tp_size 倍
收集分布式计算结果
关键区别 :
all_reduce :相同的数据,需要求和/求平均
all_gather :不同的数据,需要拼接
4. TorchDistributedImpl:PyTorch 原生实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 @dataclass class TorchDistributedImpl (DistributedImpl ): def all_reduce (self, x: torch.Tensor ) -> torch.Tensor: tp_size = dist.get_world_size() if tp_size == 1 : return x dist.all_reduce(x, op=dist.ReduceOp.SUM) return x def all_gather (self, x: torch.Tensor ) -> torch.Tensor: tp_size = dist.get_world_size() if tp_size == 1 : return x shape = list (x.shape) shape[0 ] = shape[0 ] * tp_size out = torch.empty(shape, dtype=x.dtype, device=x.device) dist.all_gather_into_tensor(out, x) return out
特点 :
使用 PyTorch 原生的 torch.distributed
简单、稳定、兼容性好
性能一般
为什么 all_gather 需要扩大第一维?
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 tp_size = 4 shape = list (x.shape) shape[0 ] = shape[0 ] * tp_size out = torch.empty(shape, dtype=x.dtype, device=x.device) dist.all_gather_into_tensor(out, x)
为什么只扩大第一维?
因为 all_gather 是沿着第一维(batch 维)拼接:
1 2 3 4 5 6 7 8 9 10 11 x_gpu0 = [[1 , 2 ], [3 , 4 ]] x_gpu1 = [[5 , 6 ], [7 , 8 ]] out = [ [1 , 2 ], [3 , 4 ], [5 , 6 ], [7 , 8 ], ]
5. PyNCCLDistributedImpl:NCCL 实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 @dataclass class PyNCCLDistributedImpl (DistributedImpl ): comm: PyNCCLCommunicator def all_reduce (self, x: torch.Tensor ) -> torch.Tensor: self .comm.all_reduce(x, "sum" ) return x def all_gather (self, x: torch.Tensor ) -> torch.Tensor: world_size = get_tp_info().size output_shape = list (x.shape) output_shape[0 ] *= world_size result = x.new_empty(output_shape) self .comm.all_gather(result, x) return result
特点 :
使用自定义的 NCCL 绑定
直接调用 NCCL C++ API
没有 Python 开销
性能更高
NCCL 是什么?
NVIDIA Collective Communications Library
NVIDIA 提供的高性能 GPU 间通信库
专门优化了 GPU 间的数据传输
性能对比 :
实现
小数据
大数据
TorchDistributedImpl
基准
基准
PyNCCLDistributedImpl
快 10-20%
快 20-30%
6. DistributedCommunicator:插件系统
1 2 3 4 5 6 7 8 class DistributedCommunicator : plugins: List [DistributedImpl] = [TorchDistributedImpl()] def all_reduce (self, x: torch.Tensor ) -> torch.Tensor: return self .plugins[-1 ].all_reduce(x) def all_gather (self, x: torch.Tensor ) -> torch.Tensor: return self .plugins[-1 ].all_gather(x)
设计模式:插件系统
plugins 是一个列表,存储多个实现
默认使用 TorchDistributedImpl
可以添加 PyNCCLDistributedImpl(性能更高)
总是使用最后一个插件(plugins[-1])
为什么这样设计?
1 2 3 4 5 6 7 8 9 10 11 12 13 plugins = [TorchDistributedImpl()] enable_pynccl_distributed(...) all_reduce(x) destroy_distributed()
优点 :
灵活切换实现 :
向后兼容 :
默认使用 PyTorch(兼容性好)
可选启用 NCCL(性能高)
优先使用高性能实现 :
最后添加的通常是高性能实现
例如:PyNCCL 比 PyTorch 快
7. enable_pynccl_distributed:启用 PyNCCL
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 def enable_pynccl_distributed ( tp_info: DistributedInfo, tp_cpu_group: torch.distributed.ProcessGroup, max_bytes: int ) -> None : if tp_info.size == 1 : return from minisgl.kernel import init_pynccl comm = init_pynccl( tp_rank=tp_info.rank, tp_size=tp_info.size, tp_cpu_group=tp_cpu_group, max_size_bytes=max_bytes, ) DistributedCommunicator.plugins.append(PyNCCLDistributedImpl(comm))
作用 :
完整的启动流程 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 import torch.distributed as distfrom minisgl.distributed import set_tp_info, get_tp_info, enable_pynccl_distributeddef main (): dist.init_process_group(backend="nccl" ) rank = dist.get_rank() world_size = dist.get_world_size() set_tp_info(rank=rank, size=world_size) tp_info = get_tp_info() tp_cpu_group = dist.new_group() max_bytes = 1024 * 1024 * 1024 enable_pynccl_distributed(tp_info, tp_cpu_group, max_bytes) model = Model() output = model.forward(input ) if __name__ == "__main__" : main()
费曼挑战
挑战1:rank 和 size
问题 :用简单的话解释 rank 和 size 是什么。
解答 :
rank :当前进程的编号(0, 1, 2, …)
size :总进程数(例如 4 个 GPU)
类比 :银行有多个编号的窗口,rank 是窗口编号,size 是总窗口数
挑战2:is_primary() 的作用
问题 :解释 is_primary() 的作用,主进程通常做什么。
解答 :
作用 :判断是否是主进程(rank 0)
主进程职责 :
初始化(只需要做一次的操作)
日志输出(避免重复打印)
保存模型(只需要保存一次)
错误处理(只需要报告一次)
挑战3:全局变量的好处
问题 :解释为什么 _TP_INFO 是全局变量。
解答 :
好处1 :共享信息,避免到处传递参数
好处2 :全局唯一,确保所有代码使用相同的 TP 信息
好处3 :代码简洁,不需要在每个函数中传递 tp_info
对比 :
1 2 3 4 5 6 7 8 9 10 model = Model(tp_info) attention = Attention(tp_info) output = model.forward(x, tp_info) set_tp_info(rank=0 , size=4 ) model = Model() attention = Attention() output = model.forward(x)
挑战4:set_tp_info 只能调用一次
问题 :解释为什么 set_tp_info 只能调用一次。
解答 :
原因 :rank 应该是固定的(由启动脚本决定)
如果允许多次调用 :
rank 可能在运行时改变
导致张量切分错误
导致通信错误
导致数据损坏
挑战5:get_tp_info vs try_get_tp_info
问题 :解释 get_tp_info 和 try_get_tp_info 的区别。
解答 :
get_tp_info :确保 TP 信息已设置,否则抛出异常
try_get_tp_info :尝试获取 TP 信息,未设置则返回 None
适用场景:TP 信息是可选的,可以在单 GPU 和多 GPU 模式下都工作
挑战6:all_reduce vs all_gather
问题 :解释 all_reduce 和 all_gather 的区别。
解答 :
all_reduce :所有进程求和,输出大小不变
用途:梯度同步、求平均值
适用:相同的数据,需要求和/求平均
all_gather :收集所有数据,输出大小扩大 tp_size 倍
用途:收集分布式计算结果
适用:不同的数据,需要拼接
关键场景 :
挑战7:TorchDistributedImpl vs PyNCCLDistributedImpl
问题 :解释两个实现的区别,为什么需要两个。
解答 :
TorchDistributedImpl :PyTorch 原生实现
优点:简单、稳定、兼容性好
缺点:性能一般(有 Python 开销)
PyNCCLDistributedImpl :自定义 NCCL 实现
优点:性能高(直接调用 NCCL C++ API)
缺点:需要额外的依赖
为什么需要两个 :
兼容性:不是所有环境都有 NCCL
灵活性:开发时用 PyTorch(简单),生产时用 NCCL(高性能)
挑战8:插件系统
问题 :解释 DistributedCommunicator 的插件系统是怎么工作的。
解答 :
设计 :plugins 是一个列表,存储多个实现
默认 :使用 TorchDistributedImpl
启用 PyNCCL :添加 PyNCCLDistributedImpl 到列表
使用 :总是使用最后一个插件(plugins[-1])
为什么使用 plugins[-1] :
优先使用高性能实现(最后添加的)
向后兼容(如果没有启用 PyNCCL,使用 PyTorch)
灵活切换(可以动态添加/删除插件)
挑战9:all_gather 扩大第一维
问题 :解释为什么 all_gather 需要扩大第一维。
解答 :
原因 :all_gather 是沿着第一维(batch 维)拼接
例子 :
为什么只扩大第一维 :
因为 all_gather 是按行拼接,不是按列拼接
挑战10:enable_pynccl_distributed 的调用时机
问题 :解释 enable_pynccl_distributed 做了什么,什么时候调用。
解答 :
做了什么 :
什么时候调用 :
在初始化模型之前
在设置 TP 信息之后
只调用一次
8. 分布式线性层:layers/linear.py
8.1 核心思想:张量切分
问题 :一个线性层 y = Wx + b,如果 W 太大(例如 [4096, 4096]),单个 GPU 放不下怎么办?
解决方案 :将 W 切分到多个 GPU,每个 GPU 计算一部分,最后合并结果。
8.2 两种切分方式
8.2.1 列并行(Column Parallel)
特点 :
切分输出维度
每个 GPU 计算不同的输出元素
不需要通信 (结果是独立的)
应用 :
LinearColParallelMerged:通用列并行
LinearQKVMerged:QKV 投影(专门优化)
8.2.2 行并行(Row Parallel)
特点 :
切分输入维度
每个 GPU 计算相同的输出元素的一部分
需要通信 (需要 all_reduce 求和)
应用 :
LinearOProj:Attention 输出投影
LinearRowParallel:通用行并行
8.3 为什么需要两种切分方式?
数学原理 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 y = Wx y[0 :n] = W[0 :n, :] @ x y[n:2n] = W[n:2n, :] @ x y = Wx y_partial_0 = W[:, 0 :n] @ x[0 :n] y_partial_1 = W[:, n:2n] @ x[n:2n] y = y_partial_0 + y_partial_1
8.4 核心代码
8.4.1 基类:_LinearTPImpl
1 2 3 4 5 6 7 8 9 10 11 12 13 14 class _LinearTPImpl (BaseOP ): def __init__ ( self, full_isize: int , full_osize: int , local_isize: int , local_osize: int , has_bias: bool , ): self .weight = torch.empty(local_osize, local_isize) self .bias = torch.empty(local_osize) if has_bias else None def forward (self, x: torch.Tensor ) -> torch.Tensor: return F.linear(x, self .weight, self .bias)
关键字段 :
full_isize:完整的输入维度(例如 409osize`:完整的输出维度(例如 4096)
local_isize:当前 GPU 的输入维度(例如 1024)
local_osize:当前 GPU 的输出维度(例如 1024)
8.4.2 列并行:LinearQKVMerged
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 class LinearQKVMerged (_LinearTPImpl ): def __init__ ( self, hidden_size: int , head_dim: int , num_qo_heads: int , num_kv_heads: int , has_bias: bool , ): tp_info = get_tp_info() GQA_ratio = divide_even(num_qo_heads, num_kv_heads) local_num_kv = divide_even(num_kv_heads, tp_info.size) full_isize = hidden_size full_osize = (GQA_ratio + 2 ) * num_kv_heads * head_dim local_isize = hidden_size local_osize = (GQA_ratio + 2 ) * local_num_kv * head_dim super ().__init__(full_isize, full_osize, local_isize, local_osize, has_bias)
为什么是 (GQA_ratio + 2)?
例子 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 num_qo_heads = 32 num_kv_heads = 8 tp_size = 4 GQA_ratio = 32 / 8 = 4 local_num_kv = 8 / 4 = 2 local_num_q = GQA_ratio * local_num_kv = 4 * 2 = 8 local_num_k = local_num_kv = 2 local_num_v = local_num_kv = 2
8.4.3 行并行:LinearOProj
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 class LinearOProj (_LinearTPImpl ): def __init__ (self, input_size: int , output_size: int , has_bias: bool ): tp_info = get_tp_info() full_isize = input_size full_osize = output_size local_isize = divide_even(input_size, tp_info.size) local_osize = output_size self ._comm = DistributedCommunicator() self ._tp_size = tp_info.size super ().__init__(full_isize, full_osize, local_isize, local_osize, has_bias) def forward (self, x: torch.Tensor ) -> torch.Tensor: y = F.linear(x, self .weight, self .bias) if self ._tp_size > 1 : y = self ._comm.all_reduce(y) return y
关键 :
切分输入维度:local_isize = input_size / tp_size
输出维度不变:local_osize = output_size
需要 all_reduce:合并所有 GPU 的部分结果
8.5 完整的 Attention 数据流
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 GPU 0 : QKV_0 = W_0 @ x → Q[0 :8 ], K[0 :2 ], V[0 :2 ] (1536 维) GPU 1 : QKV_1 = W_1 @ x → Q[8 :16 ], K[2 :4 ], V[2 :4 ] (1536 维) GPU 2 : QKV_2 = W_2 @ x → Q[16 :24 ], K[4 :6 ], V[4 :6 ] (1536 维) GPU 3 : QKV_3 = W_3 @ x → Q[24 :32 ], K[6 :8 ], V[6 :8 ] (1536 维) GPU 0 : attn_0 = Attention(Q[0 :8 ], K[0 :2 ], V[0 :2 ]) (1024 维) GPU 1 : attn_1 = Attention(Q[8 :16 ], K[2 :4 ], V[2 :4 ]) (1024 维) GPU 2 : attn_2 = Attention(Q[16 :24 ], K[4 :6 ], V[4 :6 ]) (1024 维) GPU 3 : attn_3 = Attention(Q[24 :32 ], K[6 :8 ], V[6 :8 ]) (1024 维) GPU 0 : y_0 = W_o_0 @ attn_0 (4096 维) GPU 1 : y_1 = W_o_1 @ attn_1 (4096 维) GPU 2 : y_2 = W_o_2 @ attn_2 (4096 维) GPU 3 : y_3 = W_o_3 @ attn_3 (4096 维) y = all_reduce(y_0 + y_1 + y_2 + y_3) (4096 维)
关键优势 :
只通信一次 :在最后的 all_reduce
完全并行 :QKV 和 Attention 都是独立计算
数据局部性 :每个 GPU 只处理自己的 head
内存效率 :all_reduce 是原地操作
8.6 为什么不用其他方案?
方案对比
方案
QKV
OProj
通信次数
问题
方案1(实际采用)
列并行
行并行
1 次 all_reduce
无
方案2
列并行
列并行
1 次 all_gather
需要收集所有 head,内存开销大
方案3
行并行
行并行
2 次 all_reduce
通信次数多
方案4
行并行
列并行
1 次 all_reduce
Attention 无法并行
方案1 的优势 :
通信次数最少(1 次)
数据局部性最好(每个 GPU 独立计算)
内存效率最高(all_reduce 原地操作)
费曼挑战(续)
挑战11:列并行 vs 行并行
问题 :解释列并行和行并行的区别。
解答 :
列并行 :切分输出维度,每个 GPU 计算不同的输出元素,不需要通信
行并行 :切分输入维度,每个 GPU 计算相同的输出元素的一部分,需要 all_reduce 求和
数学原理 :
1 2 3 4 5 y[0 :n] = W[0 :n, :] @ x (拼接) y = W[:, 0 :n] @ x[0 :n] + W[:, n:2n] @ x[n:2n] (求和)
挑战12:为什么行并行需要 all_reduce
问题 :解释为什么行并行需要 all_reduce,而列并行不需要。
解答 :
行并行 :每个 GPU 计算的是部分和 ,需要 all_reduce 求总和
列并行 :每个 GPU 计算的是独立的输出 ,不需要通信
类比 :
行并行 :4 个人分别计算 1+2, 3+4, 5+6, 7+8,最后求和得到 36
列并行 :4 个人分别计算 1+2, 3+4, 5+6, 7+8,结果是 [3, 7, 11, 15](独立)
挑战13:LinearQKVMerged 的 (GQA_ratio + 2)
问题 :解释为什么 LinearQKVMerged 的输出维度是 (GQA_ratio + 2) * local_num_kv * head_dim。
解答 :
GQA_ratio :每个 KV 头对应的 Q 头数(例如 4)
local_num_kv :每个 GPU 负责的 KV 头数(例如 2)
输出 :
Q 头数:GQA_ratio * local_num_kv = 4 * 2 = 8
K 头数:local_num_kv = 2
V 头数:local_num_kv = 2
总共:8 + 2 + 2 = 12 个头
输出维度:12 * head_dim = (4 + 2) * 2 * head_dim
挑战14:为什么 Attention 用"列并行 + 行并行"
问题 :解释为什么 Attention 层要用"列并行(QKV)+ 行并行(OProj)"的组合。
解答 :
QKV 列并行 :每个 GPU 负责不同的 head,完全独立计算
Attention 并行 :每个 GPU 独立计算自己的 head
OProj 行并行 :数据已经切分了,直接用行并行,最后 all_reduce
优势 :只通信一次,数据局部性好,内存效率高
对比其他方案 :
如果 OProj 用列并行,需要先 all_gather 收集所有 head(内存开销大)
如果 QKV 用行并行,需要先 all_reduce,然后 Attention 无法并行
挑战15:权重大小计算
问题 :假设 hidden_size=4096, num_qo_heads=32, num_kv_heads=8, head_dim=128, tp_size=4,计算每个 GPU 的权重大小。
解答 :
LinearQKVMerged :
1 2 3 4 5 6 7 8 local_num_kv = 8 / 4 = 2 GQA_ratio = 32 / 8 = 4 local_osize = (4 + 2 ) * 2 * 128 = 1536
LinearOProj :
9. 进程间通信:message/backend.py
9.1 核心思想:消息传递
问题 :Mini-SGLang 有 3 个进程(Frontend、Tokenizer、Scheduler),它们如何通信?
解决方案 :定义统一的消息格式,通过 ZMQ 传递。
9.2 消息类型
9.2.1 BaseBackendMsg:基类
1 2 3 4 5 6 7 8 @dataclass class BaseBackendMsg : def encoder (self ) -> Dict : return serialize_type(self ) @staticmethod def decoder (json: Dict ) -> BaseBackendMsg: return deserialize_type(globals (), json)
作用 :
所有消息的基类
提供序列化/反序列化功能(用于 ZMQ 传输)
为什么需要序列化?
ZMQ 只能传输字节流
需要将 Python 对象转换为 JSON
为什么用 globals()?
1 2 3 4 5 6 7 8 9 10 11 12 13 14 json_data = { "__type__" : "UserMsg" , "uid" : 123 , "input_ids" : [...], } type_name = json_data["__type__" ] cls = globals ()[type_name] obj = cls(**json_data)
9.2.2 BatchBackendMsg:批量消息
1 2 3 @dataclass class BatchBackendMsg (BaseBackendMsg ): data: List [BaseBackendMsg]
作用 :
例子 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 send(UserMsg(uid=1 , ...)) send(UserMsg(uid=2 , ...)) send(UserMsg(uid=3 , ...)) send(BatchBackendMsg(data=[ UserMsg(uid=1 , ...), UserMsg(uid=2 , ...), UserMsg(uid=3 , ...), ]))
9.2.3 ExitMsg:退出消息
1 2 3 @dataclass class ExitMsg (BaseBackendMsg ): pass
作用 :
例子 :
1 2 3 exit_msg = ExitMsg() zmq_socket.send_json(exit_msg.encoder())
9.2.4 UserMsg:用户请求
1 2 3 4 5 @dataclass class UserMsg (BaseBackendMsg ): uid: int input_ids: torch.Tensor sampling_params: SamplingParams
作用 :
表示一个用户请求
从 Tokenizer 发送到 Scheduler
为什么 input_ids 是 CPU Tensor?
ZMQ 传输的是 CPU 数据
GPU Tensor 无法直接序列化
Scheduler 收到后会复制到 GPU
为什么 input_ids 是 1D?
每个 UserMsg 表示一个请求
一个请求的 token 序列是 1D 的
例如:[1, 2, 3, 4, 5](5 个 token)
9.3 完整的通信流程
场景:3 个用户请求同时到达
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 request_A = {"prompt" : "Hello" , "temperature" : 0.7 } request_B = {"prompt" : "World" , "temperature" : 0.8 } request_C = {"prompt" : "!" , "temperature" : 0.9 } zmq_socket.send_json([ {"uid" : 1 , "prompt" : "Hello" , "temperature" : 0.7 }, {"uid" : 2 , "prompt" : "World" , "temperature" : 0.8 }, {"uid" : 3 , "prompt" : "!" , "temperature" : 0.9 }, ]) data_list = zmq_socket.recv_json() user_msgs = [] for data in data_list: input_ids = tokenizer.encode(data["prompt" ]) user_msg = UserMsg( uid=data["uid" ], input_ids=torch.tensor(input_ids, dtype=torch.int32), sampling_params=SamplingParams(temperature=data["temperature" ]), ) user_msgs.append(user_msg) batch_msg = BatchBackendMsg(data=user_msgs) json_data = batch_msg.encoder() zmq_socket.send_json(json_data) json_data = zmq_socket.recv_json() batch_msg = BaseBackendMsg.decoder(json_data) user_msgs = batch_msg.data input_ids_list = [msg.input_ids for msg in user_msgs] max_len = max (len (ids) for ids in input_ids_list) padded_input_ids = [] for ids in input_ids_list: padded = torch.cat([ids, torch.zeros(max_len - len (ids), dtype=torch.int32)]) padded_input_ids.append(padded) batch_input_ids = torch.stack(padded_input_ids).to("cuda" ) output_tokens = model.forward(batch_input_ids)
9.4 关键设计
9.4.1 Frontend → Tokenizer 不用 BatchBackendMsg
为什么?
Frontend 发送的是原始文本 ,还没有分词
BatchBackendMsg 是用于 BaseBackendMsg 的批量
Frontend 只需要发送普通的 JSON 列表
9.4.2 Tokenizer → Scheduler 用 BatchBackendMsg
为什么?
Tokenizer 已经创建了 UserMsg 对象
需要批量发送多个 UserMsg
使用 BatchBackendMsg 包装
9.4.3 Scheduler 需要 Padding
为什么?
3 个请求的长度不同:[1,2,3], [4,5,6], [7]
GPU 批量计算需要相同长度
用 0 填充短的序列
Padding 示例 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 input_ids_A = [1 , 2 , 3 ] input_ids_B = [4 , 5 , 6 ] input_ids_C = [7 ] input_ids_A = [1 , 2 , 3 ] input_ids_B = [4 , 5 , 6 ] input_ids_C = [7 , 0 , 0 ] batch_input_ids = torch.tensor([ [1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 0 , 0 ], ])
9.5 性能优化
9.5.1 批量通信
1 2 3 4 5 6 7 8 9 10 11 12 13 14 send(UserMsg(uid=1 , input_ids=[1 ,2 ,3 ], ...)) send(UserMsg(uid=2 , input_ids=[4 ,5 ,6 ], ...)) send(UserMsg(uid=3 , input_ids=[7 ], ...)) send(BatchBackendMsg(data=[ UserMsg(uid=1 , input_ids=[1 ,2 ,3 ], ...), UserMsg(uid=2 , input_ids=[4 ,5 ,6 ], ...), UserMsg(uid=3 , input_ids=[7 ], ...), ]))
9.5.2 GPU 批量计算
1 2 3 4 5 6 7 8 9 10 11 12 13 14 output_A = model.forward([1 , 2 , 3 ]) output_B = model.forward([4 , 5 , 6 ]) output_C = model.forward([7 , 0 , 0 ]) output = model.forward([ [1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 0 , 0 ], ])
9.6 完整流程图
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 Frontend (3 个 HTTP 请求) ↓ (ZMQ: JSON 列表) Tokenizer ↓ (分词) UserMsg(uid=1, input_ids=[1,2,3]) UserMsg(uid=2, input_ids=[4,5,6]) UserMsg(uid=3, input_ids=[7]) ↓ (包装) BatchBackendMsg(data=[...]) ↓ (ZMQ: JSON) Scheduler ↓ (反序列化) batch_msg.data = [UserMsg, UserMsg, UserMsg] ↓ (Padding) batch_input_ids = [[1,2,3], [4,5,6], [7,0,0]] ↓ (GPU 批量推理) output_tokens = [[10], [11], [12]] ↓ (分发结果) uid=1 → token=10 uid=2 → token=11 uid=3 → token=12
费曼挑战(续)
挑战16:为什么需要序列化
问题 :解释为什么需要序列化/反序列化。
解答 :
ZMQ 只能传输字节流 :不能直接传输 Python 对象
序列化 :将 Python 对象转换为 JSON(Dict)
反序列化 :将 JSON 转换回 Python 对象
作用 :实现跨进程通信
流程 :
1 2 3 4 5 6 7 8 msg = UserMsg(uid=123 , ...) json_data = msg.encoder() zmq_socket.send_json(json_data) json_data = zmq_socket.recv_json() msg = BaseBackendMsg.decoder(json_data)
挑战17:BatchBackendMsg 的作用
问题 :解释 BatchBackendMsg 的作用和优势。
解答 :
作用 :批量发送多个消息
优势 :
减少通信次数(3 次 → 1 次)
提高吞吐量(提速 3 倍)
减少 ZMQ 开销
对比 :
问题 :解释为什么 UserMsg 的 input_ids 是 CPU Tensor。
解答 :
ZMQ 传输的是 CPU 数据 :无法直接传输 GPU Tensor
GPU Tensor 无法序列化 :需要先复制到 CPU
Scheduler 收到后会复制到 GPU :用于推理
流程 :
1 2 3 4 5 6 7 8 input_ids_cpu = torch.tensor([1 , 2 , 3 ], dtype=torch.int32) user_msg = UserMsg(uid=123 , input_ids=input_ids_cpu, ...) send(user_msg) user_msg = recv() input_ids_gpu = user_msg.input_ids.to("cuda" )
挑战19:为什么需要 Padding
问题 :解释为什么 Scheduler 需要 Padding。
解答 :
问题 :3 个请求的长度不同(3, 3, 1)
GPU 批量计算需要相同长度 :Tensor 必须是矩形
解决方案 :用 0 填充短的序列
例子 :
1 2 3 4 5 [[1 , 2 , 3 ], [4 , 5 , 6 ], [7 ]] [[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 0 , 0 ]]
挑战20:完整的消息流
问题 :描述 3 个用户请求从 Frontend 到 Scheduler 的完整消息流。
解答 :
Frontend → Tokenizer :
发送 JSON 列表(原始文本)
不用 BatchBackendMsg
Tokenizer 处理 :
分词,创建 UserMsg
包装成 BatchBackendMsg
Tokenizer → Scheduler :
序列化 BatchBackendMsg
通过 ZMQ 发送
Scheduler 处理 :
反序列化 BatchBackendMsg
提取 UserMsg
Padding 到相同长度
批量推理
关键 :
批量通信(减少通信次数)
批量推理(提高 GPU 利用率)
Padding(满足 GPU 批量计算要求)