学习文件distributed/info.py, distributed/impl.py, layers/linear.py, message/backend.py


1. 为什么需要分布式推理?

挑战1:模型太大

  • 大模型参数量巨大(例如 Llama-70B 有 700 亿参数)
  • 单个 GPU 显存不够(例如 A100 只有 80GB)
  • 需要多 GPU 并行

挑战2:推理速度慢

  • 单 GPU 计算速度有限
  • 需要并行计算提高吞吐量

解决方案:Tensor Parallelism(TP)

核心思想

  • 将模型的张量(权重、激活)切分到多个 GPU
  • 每个 GPU 计算一部分
  • 通过通信合并结果

类比

  • 就像一个大任务分给多个工人
  • 每个工人负责一部分
  • 最后汇总结果

2. DistributedInfo:分布式信息

2.1 核心概念

1
2
3
4
5
6
7
8
9
10
@dataclass(frozen=True)
class DistributedInfo:
rank: int # 当前进程的编号(0, 1, 2, ...)
size: int # 总进程数(例如 4 个 GPU)

def __post_init__(self):
assert 0 <= self.rank < self.size

def is_primary(self) -> bool:
return self.rank == 0 # rank 0 是主进程

关键字段

  1. rank:当前进程的编号

    • 例如:4 个 GPU,rank = 0, 1, 2, 3
    • rank 0 通常是主进程(primary)
  2. size:总进程数

    • 例如:4 个 GPU,size = 4
  3. is_primary():判断是否是主进程

    • 主进程通常负责初始化、日志输出、保存模型等

2.2 全局变量:_TP_INFO

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
_TP_INFO: DistributedInfo | None = None

def set_tp_info(rank: int, size: int) -> None:
global _TP_INFO
if _TP_INFO is not None:
raise RuntimeError("TP info has been set")
_TP_INFO = DistributedInfo(rank, size)

def get_tp_info() -> DistributedInfo:
if _TP_INFO is None:
raise RuntimeError("TP info has not been set")
return _TP_INFO

def try_get_tp_info() -> DistributedInfo | None:
return _TP_INFO

为什么需要全局变量?

  1. 共享信息

    • 分布式信息在整个程序中都需要使用
    • 避免到处传递参数
  2. 全局唯一

    • 确保所有代码使用相同的 TP 信息
    • 避免不一致

三个函数的区别

函数 作用 未设置时
set_tp_info 设置 TP 信息(只能调用一次) -
get_tp_info 获取 TP 信息 抛出异常
try_get_tp_info 尝试获取 TP 信息 返回 None

2.3 完整的启动流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# main.py
import os
from minisgl.distributed import set_tp_info, get_tp_info

def main():
# 步骤1:从环境变量读取 rank 和 world_size
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))

# 步骤2:设置 TP 信息(只设置一次)
set_tp_info(rank=rank, size=world_size)

# 步骤3:后续代码使用 get_tp_info()
tp_info = get_tp_info()

if tp_info.is_primary():
print(f"Starting with {tp_info.size} GPUs")

print(f"Process {tp_info.rank} is ready")

# 初始化模型
model = Model() # Model 内部会调用 get_tp_info()

# 运行推理
output = model.forward(input)

if __name__ == "__main__":
main()

启动

1
2
3
4
5
6
# 使用 torchrun 启动 4 个进程
torchrun --nproc_per_node=4 main.py

# torchrun 会自动设置环境变量:
# - RANK (或 LOCAL_RANK)
# - WORLD_SIZE

输出

1
2
3
4
5
Starting with 4 GPUs
Process 0 is ready
Process 1 is ready
Process 2 is ready
Process 3 is ready

3. DistributedImpl:分布式通信

3.1 两个核心操作

1
2
3
4
5
6
7
@dataclass
class DistributedImpl(ABC):
@abstractmethod
def all_reduce(self, x: torch.Tensor) -> torch.Tensor: ...

@abstractmethod
def all_gather(self, x: torch.Tensor) -> torch.Tensor: ...

3.1.1 all_reduce:所有进程求和

1
2
3
4
5
6
7
8
9
10
11
# 4 个 GPU,每个 GPU 有一个值
# GPU 0: x = [1, 2]
# GPU 1: x = [3, 4]
# GPU 2: x = [5, 6]
# GPU 3: x = [7, 8]

# all_reduce 后,每个 GPU 都得到总和
# GPU 0: x = [16, 20] # [1+3+5+7, 2+4+6+8]
# GPU 1: x = [16, 20]
# GPU 2: x = [16, 20]
# GPU 3: x = [16, 20]

用途

  • 梯度同步(训练时)
  • 全局统计(例如求平均值)

应用场景1:梯度同步

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 4 个 GPU,每个 GPU 计算一部分数据的梯度
# GPU 0: grad = [0.1, 0.2]
# GPU 1: grad = [0.3, 0.4]
# GPU 2: grad = [0.5, 0.6]
# GPU 3: grad = [0.7, 0.8]

# all_reduce 后,每个 GPU 都得到总梯度
all_reduce(grad)
# GPU 0: grad = [1.6, 2.0]
# GPU 1: grad = [1.6, 2.0]
# GPU 2: grad = [1.6, 2.0]
# GPU 3: grad = [1.6, 2.0]

# 然后每个 GPU 用相同的梯度更新参数
optimizer.step()

应用场景2:求平均值

1
2
3
4
5
6
7
8
9
10
11
12
# 4 个 GPU,每个 GPU 处理一部分数据
# GPU 0: loss = 1.0
# GPU 1: loss = 2.0
# GPU 2: loss = 3.0
# GPU 3: loss = 4.0

# all_reduce 求和
all_reduce(loss)
# 每个 GPU: loss = 10.0

# 除以 GPU 数量得到平均值
loss = loss / 4 # 2.5

3.1.2 all_gather:收集所有进程的数据

1
2
3
4
5
6
7
8
9
10
11
# 4 个 GPU,每个 GPU 有一个值
# GPU 0: x = [1, 2]
# GPU 1: x = [3, 4]
# GPU 2: x = [5, 6]
# GPU 3: x = [7, 8]

# all_gather 后,每个 GPU 都得到所有数据
# GPU 0: x = [1, 2, 3, 4, 5, 6, 7, 8]
# GPU 1: x = [1, 2, 3, 4, 5, 6, 7, 8]
# GPU 2: x = [1, 2, 3, 4, 5, 6, 7, 8]
# GPU 3: x = [1, 2, 3, 4, 5, 6, 7, 8]

用途

  • 收集分布式计算的结果
  • 例如:每个 GPU 计算一部分输出,最后合并

应用场景1:收集分布式计算结果

1
2
3
4
5
6
7
8
9
# 4 个 GPU,每个 GPU 计算一部分输出
# GPU 0: output = [1, 2] (batch_size=2)
# GPU 1: output = [3, 4] (batch_size=2)
# GPU 2: output = [5, 6] (batch_size=2)
# GPU 3: output = [7, 8] (batch_size=2)

# all_gather 收集所有输出
output = all_gather(output)
# 每个 GPU: output = [1, 2, 3, 4, 5, 6, 7, 8] (batch_size=8)

应用场景2:Tensor Parallelism 的输出合并

1
2
3
4
5
6
7
8
9
# 4 个 GPU,每个 GPU 计算一部分 head
# GPU 0: output = [batch_size, 8, head_dim] (heads 0-7)
# GPU 1: output = [batch_size, 8, head_dim] (heads 8-15)
# GPU 2: output = [batch_size, 8, head_dim] (heads 16-23)
# GPU 3: output = [batch_size, 8, head_dim] (heads 24-31)

# all_gather 收集所有 head
output = all_gather(output)
# 每个 GPU: output = [batch_size, 32, head_dim] (所有 heads)

3.2 all_reduce vs all_gather

操作 作用 输出大小 用途
all_reduce 所有进程求和 不变 梯度同步、求平均值
all_gather 收集所有数据 扩大 tp_size 倍 收集分布式计算结果

关键区别

  • all_reduce:相同的数据,需要求和/求平均
  • all_gather:不同的数据,需要拼接

4. TorchDistributedImpl:PyTorch 原生实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
@dataclass
class TorchDistributedImpl(DistributedImpl):
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
tp_size = dist.get_world_size()
if tp_size == 1:
return x # 单 GPU,直接返回
dist.all_reduce(x, op=dist.ReduceOp.SUM)
return x

def all_gather(self, x: torch.Tensor) -> torch.Tensor:
tp_size = dist.get_world_size()
if tp_size == 1:
return x # 单 GPU,直接返回
shape = list(x.shape)
shape[0] = shape[0] * tp_size # 第一维扩大 tp_size 倍
out = torch.empty(shape, dtype=x.dtype, device=x.device)
dist.all_gather_into_tensor(out, x)
return out

特点

  • 使用 PyTorch 原生的 torch.distributed
  • 简单、稳定、兼容性好
  • 性能一般

为什么 all_gather 需要扩大第一维?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 输入
# GPU 0: x.shape = [2, 128] (batch_size=2, hidden_dim=128)
# GPU 1: x.shape = [2, 128]
# GPU 2: x.shape = [2, 128]
# GPU 3: x.shape = [2, 128]

# all_gather 实现
tp_size = 4
shape = list(x.shape) # [2, 128]
shape[0] = shape[0] * tp_size # [2*4, 128] = [8, 128]
out = torch.empty(shape, dtype=x.dtype, device=x.device)
dist.all_gather_into_tensor(out, x)

# 输出
# 每个 GPU: out.shape = [8, 128]
# out = [
# x_gpu0[0], x_gpu0[1], # 来自 GPU 0
# x_gpu1[0], x_gpu1[1], # 来自 GPU 1
# x_gpu2[0], x_gpu2[1], # 来自 GPU 2
# x_gpu3[0], x_gpu3[1], # 来自 GPU 3
# ]

为什么只扩大第一维?

因为 all_gather 是沿着第一维(batch 维)拼接:

1
2
3
4
5
6
7
8
9
10
11
# 输入
x_gpu0 = [[1, 2], [3, 4]] # shape: [2, 2]
x_gpu1 = [[5, 6], [7, 8]] # shape: [2, 2]

# all_gather 后
out = [
[1, 2], # 来自 GPU 0
[3, 4], # 来自 GPU 0
[5, 6], # 来自 GPU 1
[7, 8], # 来自 GPU 1
] # shape: [4, 2] (第一维扩大 2 倍)

5. PyNCCLDistributedImpl:NCCL 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
@dataclass
class PyNCCLDistributedImpl(DistributedImpl):
comm: PyNCCLCommunicator

def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
self.comm.all_reduce(x, "sum")
return x

def all_gather(self, x: torch.Tensor) -> torch.Tensor:
world_size = get_tp_info().size
output_shape = list(x.shape)
output_shape[0] *= world_size
result = x.new_empty(output_shape)
self.comm.all_gather(result, x)
return result

特点

  • 使用自定义的 NCCL 绑定
  • 直接调用 NCCL C++ API
  • 没有 Python 开销
  • 性能更高

NCCL 是什么?

  • NVIDIA Collective Communications Library
  • NVIDIA 提供的高性能 GPU 间通信库
  • 专门优化了 GPU 间的数据传输

性能对比

实现 小数据 大数据
TorchDistributedImpl 基准 基准
PyNCCLDistributedImpl 快 10-20% 快 20-30%

6. DistributedCommunicator:插件系统

1
2
3
4
5
6
7
8
class DistributedCommunicator:
plugins: List[DistributedImpl] = [TorchDistributedImpl()]

def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
return self.plugins[-1].all_reduce(x)

def all_gather(self, x: torch.Tensor) -> torch.Tensor:
return self.plugins[-1].all_gather(x)

设计模式:插件系统

  • plugins 是一个列表,存储多个实现
  • 默认使用 TorchDistributedImpl
  • 可以添加 PyNCCLDistributedImpl(性能更高)
  • 总是使用最后一个插件(plugins[-1]

为什么这样设计?

1
2
3
4
5
6
7
8
9
10
11
12
13
# 初始状态
plugins = [TorchDistributedImpl()]

# 如果启用 PyNCCL
enable_pynccl_distributed(...)
# plugins = [TorchDistributedImpl(), PyNCCLDistributedImpl()]

# 使用时
all_reduce(x) # 使用 plugins[-1],即 PyNCCLDistributedImpl

# 如果禁用
destroy_distributed()
# plugins = []

优点

  1. 灵活切换实现

    • 可以动态添加/删除插件
    • 不影响现有代码
  2. 向后兼容

    • 默认使用 PyTorch(兼容性好)
    • 可选启用 NCCL(性能高)
  3. 优先使用高性能实现

    • 最后添加的通常是高性能实现
    • 例如:PyNCCL 比 PyTorch 快

7. enable_pynccl_distributed:启用 PyNCCL

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def enable_pynccl_distributed(
tp_info: DistributedInfo,
tp_cpu_group: torch.distributed.ProcessGroup,
max_bytes: int
) -> None:
if tp_info.size == 1:
return # 单 GPU,不需要

from minisgl.kernel import init_pynccl

comm = init_pynccl(
tp_rank=tp_info.rank,
tp_size=tp_info.size,
tp_cpu_group=tp_cpu_group,
max_size_bytes=max_bytes,
)

DistributedCommunicator.plugins.append(PyNCCLDistributedImpl(comm))

作用

  • 初始化 PyNCCL 通信器
  • 添加到插件列表

完整的启动流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# main.py
import torch.distributed as dist
from minisgl.distributed import set_tp_info, get_tp_info, enable_pynccl_distributed

def main():
# 步骤1:初始化 PyTorch 分布式
dist.init_process_group(backend="nccl")

# 步骤2:设置 TP 信息
rank = dist.get_rank()
world_size = dist.get_world_size()
set_tp_info(rank=rank, size=world_size)

# 步骤3:启用 PyNCCL(可选,提高性能)
tp_info = get_tp_info()
tp_cpu_group = dist.new_group() # 创建 CPU 通信组
max_bytes = 1024 * 1024 * 1024 # 1GB
enable_pynccl_distributed(tp_info, tp_cpu_group, max_bytes)

# 步骤4:初始化模型
model = Model()

# 步骤5:运行推理
output = model.forward(input)

if __name__ == "__main__":
main()

费曼挑战

挑战1:rank 和 size

问题:用简单的话解释 rank 和 size 是什么。

解答

  • rank:当前进程的编号(0, 1, 2, …)
  • size:总进程数(例如 4 个 GPU)
  • 类比:银行有多个编号的窗口,rank 是窗口编号,size 是总窗口数

挑战2:is_primary() 的作用

问题:解释 is_primary() 的作用,主进程通常做什么。

解答

  • 作用:判断是否是主进程(rank 0)
  • 主进程职责
    • 初始化(只需要做一次的操作)
    • 日志输出(避免重复打印)
    • 保存模型(只需要保存一次)
    • 错误处理(只需要报告一次)

挑战3:全局变量的好处

问题:解释为什么 _TP_INFO 是全局变量。

解答

  • 好处1:共享信息,避免到处传递参数
  • 好处2:全局唯一,确保所有代码使用相同的 TP 信息
  • 好处3:代码简洁,不需要在每个函数中传递 tp_info

对比

1
2
3
4
5
6
7
8
9
10
# ❌ 不用全局变量
model = Model(tp_info)
attention = Attention(tp_info)
output = model.forward(x, tp_info)

# ✅ 用全局变量
set_tp_info(rank=0, size=4)
model = Model()
attention = Attention()
output = model.forward(x)

挑战4:set_tp_info 只能调用一次

问题:解释为什么 set_tp_info 只能调用一次。

解答

  • 原因:rank 应该是固定的(由启动脚本决定)
  • 如果允许多次调用
    • rank 可能在运行时改变
    • 导致张量切分错误
    • 导致通信错误
    • 导致数据损坏

挑战5:get_tp_info vs try_get_tp_info

问题:解释 get_tp_info 和 try_get_tp_info 的区别。

解答

  • get_tp_info:确保 TP 信息已设置,否则抛出异常
    • 适用场景:必须有 TP 信息才能继续
  • try_get_tp_info:尝试获取 TP 信息,未设置则返回 None
    • 适用场景:TP 信息是可选的,可以在单 GPU 和多 GPU 模式下都工作

挑战6:all_reduce vs all_gather

问题:解释 all_reduce 和 all_gather 的区别。

解答

  • all_reduce:所有进程求和,输出大小不变
    • 用途:梯度同步、求平均值
    • 适用:相同的数据,需要求和/求平均
  • all_gather:收集所有数据,输出大小扩大 tp_size 倍
    • 用途:收集分布式计算结果
    • 适用:不同的数据,需要拼接

关键场景

1
2
3
4
5
6
7
8
9
# all_reduce:梯度同步
# GPU 0: grad = [0.1, 0.2]
# GPU 1: grad = [0.3, 0.4]
# all_reduce → 每个 GPU: grad = [0.4, 0.6]

# all_gather:收集输出
# GPU 0: output = [1, 2]
# GPU 1: output = [3, 4]
# all_gather → 每个 GPU: output = [1, 2, 3, 4]

挑战7:TorchDistributedImpl vs PyNCCLDistributedImpl

问题:解释两个实现的区别,为什么需要两个。

解答

  • TorchDistributedImpl:PyTorch 原生实现
    • 优点:简单、稳定、兼容性好
    • 缺点:性能一般(有 Python 开销)
  • PyNCCLDistributedImpl:自定义 NCCL 实现
    • 优点:性能高(直接调用 NCCL C++ API)
    • 缺点:需要额外的依赖

为什么需要两个

  • 兼容性:不是所有环境都有 NCCL
  • 灵活性:开发时用 PyTorch(简单),生产时用 NCCL(高性能)

挑战8:插件系统

问题:解释 DistributedCommunicator 的插件系统是怎么工作的。

解答

  • 设计:plugins 是一个列表,存储多个实现
  • 默认:使用 TorchDistributedImpl
  • 启用 PyNCCL:添加 PyNCCLDistributedImpl 到列表
  • 使用:总是使用最后一个插件(plugins[-1])

为什么使用 plugins[-1]

  • 优先使用高性能实现(最后添加的)
  • 向后兼容(如果没有启用 PyNCCL,使用 PyTorch)
  • 灵活切换(可以动态添加/删除插件)

挑战9:all_gather 扩大第一维

问题:解释为什么 all_gather 需要扩大第一维。

解答

  • 原因:all_gather 是沿着第一维(batch 维)拼接
  • 例子
    1
    2
    3
    # GPU 0: x.shape = [2, 128]
    # GPU 1: x.shape = [2, 128]
    # all_gather → out.shape = [4, 128] (第一维扩大 2 倍)

为什么只扩大第一维

  • 因为 all_gather 是按行拼接,不是按列拼接

挑战10:enable_pynccl_distributed 的调用时机

问题:解释 enable_pynccl_distributed 做了什么,什么时候调用。

解答

  • 做了什么
    • 初始化 PyNCCL 通信器
    • 添加到插件列表
  • 什么时候调用
    • 在初始化模型之前
    • 在设置 TP 信息之后
    • 只调用一次

8. 分布式线性层:layers/linear.py

8.1 核心思想:张量切分

问题:一个线性层 y = Wx + b,如果 W 太大(例如 [4096, 4096]),单个 GPU 放不下怎么办?

解决方案:将 W 切分到多个 GPU,每个 GPU 计算一部分,最后合并结果。

8.2 两种切分方式

8.2.1 列并行(Column Parallel)

1
2
3
4
5
6
7
8
9
# 沿着输出维度切分(切分权重矩阵的行)
# 假设 W = [4096, 4096],4 个 GPU

# GPU 0: W[0:1024, :] → 输出 y[0:1024]
# GPU 1: W[1024:2048, :] → 输出 y[1024:2048]
# GPU 2: W[2048:3072, :] → 输出 y[2048:3072]
# GPU 3: W[3072:4096, :] → 输出 y[3072:4096]

# 结果:每个 GPU 的输出是独立的,不需要通信

特点

  • 切分输出维度
  • 每个 GPU 计算不同的输出元素
  • 不需要通信(结果是独立的)

应用

  • LinearColParallelMerged:通用列并行
  • LinearQKVMerged:QKV 投影(专门优化)

8.2.2 行并行(Row Parallel)

1
2
3
4
5
6
7
8
9
10
# 沿着输入维度切分(切分权重矩阵的列)
# 假设 W = [4096, 4096],4 个 GPU

# GPU 0: W[:, 0:1024] → 输出 y_partial_0
# GPU 1: W[:, 1024:2048] → 输出 y_partial_1
# GPU 2: W[:, 2048:3072] → 输y_partial_2
# GPU 3: W[:, 3072:4096] → 输出 y_partial_3

# 结果:y = y_partial_0 + y_partial_1 + y_partial_2 + y_partial_3
# 需要 all_reduce 求和

特点

  • 切分输入维度
  • 每个 GPU 计算相同的输出元素的一部分
  • 需要通信(需要 all_reduce 求和)

应用

  • LinearOProj:Attention 输出投影
  • LinearRowParallel:通用行并行

8.3 为什么需要两种切分方式?

数学原理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 列并行:按输出维度切分
y = Wx
# 切分 W 的行
y[0:n] = W[0:n, :] @ x
y[n:2n] = W[n:2n, :] @ x
# 结果:拼接(不需要求和)

# 行并行:按输入维度切分
y = Wx
# 切分 W 的列
y_partial_0 = W[:, 0:n] @ x[0:n]
y_partial_1 = W[:, n:2n] @ x[n:2n]
# 结果:求和(需要 all_reduce)
y = y_partial_0 + y_partial_1

8.4 核心代码

8.4.1 基类:_LinearTPImpl

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class _LinearTPImpl(BaseOP):
def __init__(
self,
full_isize: int, # 完整的输入维度
full_osize: int, # 完整的输出维度
local_isize: int, # 当前 GPU 的输入维度
local_osize: int, # 当前 GPU 的输出维度
has_bias: bool,
):
self.weight = torch.empty(local_osize, local_isize)
self.bias = torch.empty(local_osize) if has_bias else None

def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight, self.bias)

关键字段

  • full_isize:完整的输入维度(例如 409osize`:完整的输出维度(例如 4096)
  • local_isize:当前 GPU 的输入维度(例如 1024)
  • local_osize:当前 GPU 的输出维度(例如 1024)

8.4.2 列并行:LinearQKVMerged

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class LinearQKVMerged(_LinearTPImpl):
def __init__(
self,
hidden_size: int,
head_dim: int,
num_qo_heads: int,
num_kv_heads: int,
has_bias: bool,
):
tp_info = get_tp_info()

GQA_ratio = divide_even(num_qo_heads, num_kv_heads)
local_num_kv = divide_even(num_kv_heads, tp_info.size)
full_isize = hidden_size
full_osize = (GQA_ratio + 2) * num_kv_heads * head_dim
local_isize = hidden_size
local_osize = (GQA_ratio + 2) * local_num_kv * head_dim
super().__init__(full_isize, full_osize, local_isize, local_osize, has_bias)

为什么是 (GQA_ratio + 2)

1
2
3
4
5
6
7
8
9
10
11
12
# GQA(Grouped Query Attention)
# 例如:num_qo_heads=32, num_kv_heads=8
# GQA_ratio = 32 / 8 = 4(每个 KV 头对应 4 个 Q 头)

# 每个 GPU 负责 local_num_kv 个 KV 头
# 对应的 Q 头数:GQA_ratio * local_num_kv
# 对应的 K 头数:local_num_kv
# 对应的 V 头数:local_num_kv

# 总输出维度:
# (GQA_ratio * local_num_kv + local_num_kv + local_num_kv) * head_dim
# = (GQA_ratio + 2) * local_num_kv * head_dim

例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 假设 4 个 GPU
num_qo_heads = 32
num_kv_heads = 8
tp_size = 4

GQA_ratio = 32 / 8 = 4
local_num_kv = 8 / 4 = 2

# 每个 GPU 负责
local_num_q = GQA_ratio * local_num_kv = 4 * 2 = 8
local_num_k = local_num_kv = 2
local_num_v = local_num_kv = 2

# 总共:8Q + 2K + 2V = 12 个头
# 输出维度:12 * head_dim

8.4.3 行并行:LinearOProj

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class LinearOProj(_LinearTPImpl):
def __init__(self, input_size: int, output_size: int, has_bias: bool):
tp_info = get_tp_info()
full_isize = input_size
full_osize = output_size
local_isize = divide_even(input_size, tp_info.size)
local_osize = output_size
self._comm = DistributedCommunicator()
self._tp_size = tp_info.size
super().__init__(full_isize, full_osize, local_isize, local_osize, has_bias)

def forward(self, x: torch.Tensor) -> torch.Tensor:
y = F.linear(x, self.weight, self.bias)
if self._tp_size > 1:
y = self._comm.all_reduce(y) # 合并所有 GPU 的结果
return y

关键

  • 切分输入维度:local_isize = input_size / tp_size
  • 输出维度不变:local_osize = output_size
  • 需要 all_reduce:合并所有 GPU 的部分结果

8.5 完整的 Attention 数据流

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# 假设 4 个 GPU,hidden_size=4096,num_qo_heads=32,num_kv_heads=8

# 输入:x [batch, 4096](每个 GPU 都有完整的 x)

# 1. QKV 投影(列并行)
GPU 0: QKV_0 = W_0 @ x → Q[0:8], K[0:2], V[0:2] (1536 维)
GPU 1: QKV_1 = W_1 @ x → Q[8:16], K[2:4], V[2:4] (1536 维)
GPU 2: QKV_2 = W_2 @ x → Q[16:24], K[4:6], V[4:6] (1536 维)
GPU 3: QKV_3 = W_3 @ x → Q[24:32], K[6:8], V[6:8] (1536 维)
# 通信:0 次

# 2. Attention 计算(每个 GPU 独立)
GPU 0: attn_0 = Attention(Q[0:8], K[0:2], V[0:2]) (1024 维)
GPU 1: attn_1 = Attention(Q[8:16], K[2:4], V[2:4]) (1024 维)
GPU 2: attn_2 = Attention(Q[16:24], K[4:6], V[4:6]) (1024 维)
GPU 3: attn_3 = Attention(Q[24:32], K[6:8], V[6:8]) (1024 维)
# 通信:0 次

# 3. 输出投影(行并行)
GPU 0: y_0 = W_o_0 @ attn_0 (4096 维)
GPU 1: y_1 = W_o_1 @ attn_1 (4096 维)
GPU 2: y_2 = W_o_2 @ attn_2 (4096 维)
GPU 3: y_3 = W_o_3 @ attn_3 (4096 维)
# 通信:0 次

# 4. All-Reduce(合并结果)
y = all_reduce(y_0 + y_1 + y_2 + y_3) (4096 维)
# 通信:1 次

# 输出:y [batch, 4096](每个 GPU 都有完整的 y)

关键优势

  1. 只通信一次:在最后的 all_reduce
  2. 完全并行:QKV 和 Attention 都是独立计算
  3. 数据局部性:每个 GPU 只处理自己的 head
  4. 内存效率:all_reduce 是原地操作

8.6 为什么不用其他方案?

方案对比

方案 QKV OProj 通信次数 问题
方案1(实际采用) 列并行 行并行 1 次 all_reduce
方案2 列并行 列并行 1 次 all_gather 需要收集所有 head,内存开销大
方案3 行并行 行并行 2 次 all_reduce 通信次数多
方案4 行并行 列并行 1 次 all_reduce Attention 无法并行

方案1 的优势

  • 通信次数最少(1 次)
  • 数据局部性最好(每个 GPU 独立计算)
  • 内存效率最高(all_reduce 原地操作)

费曼挑战(续)

挑战11:列并行 vs 行并行

问题:解释列并行和行并行的区别。

解答

  • 列并行:切分输出维度,每个 GPU 计算不同的输出元素,不需要通信
  • 行并行:切分输入维度,每个 GPU 计算相同的输出元素的一部分,需要 all_reduce 求和

数学原理

1
2
3
4
5
# 列并行:按输出维度切分
y[0:n] = W[0:n, :] @ x (拼接)

# 行并行:按输入维度切分
y = W[:, 0:n] @ x[0:n] + W[:, n:2n] @ x[n:2n] (求和)

挑战12:为什么行并行需要 all_reduce

问题:解释为什么行并行需要 all_reduce,而列并行不需要。

解答

  • 行并行:每个 GPU 计算的是部分和,需要 all_reduce 求总和
  • 列并行:每个 GPU 计算的是独立的输出,不需要通信

类比

  • 行并行:4 个人分别计算 1+2, 3+4, 5+6, 7+8,最后求和得到 36
  • 列并行:4 个人分别计算 1+2, 3+4, 5+6, 7+8,结果是 [3, 7, 11, 15](独立)

挑战13:LinearQKVMerged 的 (GQA_ratio + 2)

问题:解释为什么 LinearQKVMerged 的输出维度是 (GQA_ratio + 2) * local_num_kv * head_dim

解答

  • GQA_ratio:每个 KV 头对应的 Q 头数(例如 4)
  • local_num_kv:每个 GPU 负责的 KV 头数(例如 2)
  • 输出
    • Q 头数:GQA_ratio * local_num_kv = 4 * 2 = 8
    • K 头数:local_num_kv = 2
    • V 头数:local_num_kv = 2
    • 总共:8 + 2 + 2 = 12 个头
    • 输出维度:12 * head_dim = (4 + 2) * 2 * head_dim

挑战14:为什么 Attention 用"列并行 + 行并行"

问题:解释为什么 Attention 层要用"列并行(QKV)+ 行并行(OProj)"的组合。

解答

  • QKV 列并行:每个 GPU 负责不同的 head,完全独立计算
  • Attention 并行:每个 GPU 独立计算自己的 head
  • OProj 行并行:数据已经切分了,直接用行并行,最后 all_reduce
  • 优势:只通信一次,数据局部性好,内存效率高

对比其他方案

  • 如果 OProj 用列并行,需要先 all_gather 收集所有 head(内存开销大)
  • 如果 QKV 用行并行,需要先 all_reduce,然后 Attention 无法并行

挑战15:权重大小计算

问题:假设 hidden_size=4096, num_qo_heads=32, num_kv_heads=8, head_dim=128, tp_size=4,计算每个 GPU 的权重大小。

解答

LinearQKVMerged

1
2
3
4
5
6
7
8
# 完整权重:[4096, (32 + 8 + 8) * 128] = [4096, 6144]
# 列并行切分:[4096, 6144 / 4] = [4096, 1536]

# 或者用公式
local_num_kv = 8 / 4 = 2
GQA_ratio = 32 / 8 = 4
local_osize = (4 + 2) * 2 * 128 = 1536
# 权重:[4096, 1536]

LinearOProj

1
2
# 完整权重:[4096, 4096]
# 行并行切分:[4096, 4096 / 4] = [4096, 1024]

9. 进程间通信:message/backend.py

9.1 核心思想:消息传递

问题:Mini-SGLang 有 3 个进程(Frontend、Tokenizer、Scheduler),它们如何通信?

解决方案:定义统一的消息格式,通过 ZMQ 传递。

9.2 消息类型

9.2.1 BaseBackendMsg:基类

1
2
3
4
5
6
7
8
@dataclass
class BaseBackendMsg:
def encoder(self) -> Dict:
return serialize_type(self) # 序列化为 JSON

@staticmethod
def decoder(json: Dict) -> BaseBackendMsg:
return deserialize_type(globals(), json) # 反序列化

作用

  • 所有消息的基类
  • 提供序列化/反序列化功能(用于 ZMQ 传输)

为什么需要序列化?

  • ZMQ 只能传输字节流
  • 需要将 Python 对象转换为 JSON

为什么用 globals()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# globals() 返回当前模块的全局命名空间
# 包含所有定义的类:UserMsg, ExitMsg, BatchBackendMsg 等

# JSON 数据
json_data = {
"__type__": "UserMsg",
"uid": 123,
"input_ids": [...],
}

# deserialize_type 的工作流程
type_name = json_data["__type__"] # "UserMsg"
cls = globals()[type_name] # 找到 UserMsg 类
obj = cls(**json_data) # 创建对象

9.2.2 BatchBackendMsg:批量消息

1
2
3
@dataclass
class BatchBackendMsg(BaseBackendMsg):
data: List[BaseBackendMsg]

作用

  • 批量发送多个消息
  • 减少通信次数,提高吞吐量

例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# ❌ 不用批量(3 次通信)
send(UserMsg(uid=1, ...)) # 10μs
send(UserMsg(uid=2, ...)) # 10μs
send(UserMsg(uid=3, ...)) # 10μs
# 总时间:30μs

# ✅ 用批量(1 次通信)
send(BatchBackendMsg(data=[
UserMsg(uid=1, ...),
UserMsg(uid=2, ...),
UserMsg(uid=3, ...),
])) # 10μs
# 总时间:10μs
# 提速:3 倍

9.2.3 ExitMsg:退出消息

1
2
3
@dataclass
class ExitMsg(BaseBackendMsg):
pass

作用

  • 通知进程退出
  • 优雅关闭

例子

1
2
3
# Frontend 发送退出消息给 Tokenizer
exit_msg = ExitMsg()
zmq_socket.send_json(exit_msg.encoder())

9.2.4 UserMsg:用户请求

1
2
3
4
5
@dataclass
class UserMsg(BaseBackendMsg):
uid: int # 用户请求 ID
input_ids: torch.Tensor # CPU 1D int32 tensor(已分词)
sampling_params: SamplingParams # 采样参数

作用

  • 表示一个用户请求
  • 从 Tokenizer 发送到 Scheduler

为什么 input_ids 是 CPU Tensor?

  • ZMQ 传输的是 CPU 数据
  • GPU Tensor 无法直接序列化
  • Scheduler 收到后会复制到 GPU

为什么 input_ids 是 1D?

  • 每个 UserMsg 表示一个请求
  • 一个请求的 token 序列是 1D 的
  • 例如:[1, 2, 3, 4, 5](5 个 token)

9.3 完整的通信流程

场景:3 个用户请求同时到达

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# ========== Frontend 进程 ==========
# 1. 接收 3 个 HTTP 请求
request_A = {"prompt": "Hello", "temperature": 0.7}
request_B = {"prompt": "World", "temperature": 0.8}
request_C = {"prompt": "!", "temperature": 0.9}

# 2. 批量发送给 Tokenizer(通过 ZMQ)
# 注意:这里还不是 BatchBackendMsg,只是普通的 JSON
zmq_socket.send_json([
{"uid": 1, "prompt": "Hello", "temperature": 0.7},
{"uid": 2, "prompt": "World", "temperature": 0.8},
{"uid": 3, "prompt": "!", "temperature": 0.9},
])

# ========== Tokenizer 进程 ==========
# 3. 接收 Frontend 的消息
data_list = zmq_socket.recv_json()

# 4. 分词
user_msgs = []
for data in data_list:
input_ids = tokenizer.encode(data["prompt"])
user_msg = UserMsg(
uid=data["uid"],
input_ids=torch.tensor(input_ids, dtype=torch.int32),
sampling_params=SamplingParams(temperature=data["temperature"]),
)
user_msgs.append(user_msg)

# 结果
# user_msgs[0]: UserMsg(uid=1, input_ids=[1, 2, 3], ...)
# user_msgs[1]: UserMsg(uid=2, input_ids=[4, 5, 6], ...)
# user_msgs[2]: UserMsg(uid=3, input_ids=[7], ...)

# 5. 创建 BatchBackendMsg
batch_msg = BatchBackendMsg(data=user_msgs)

# 6. 序列化并发送给 Scheduler
json_data = batch_msg.encoder()
# 结果:
# {
# "__type__": "BatchBackendMsg",
# "data": [
# {"__type__": "UserMsg", "uid": 1, "input_ids": [1, 2, 3], ...},
# {"__type__": "UserMsg", "uid": 2, "input_ids": [4, 5, 6], ...},
# {"__type__": "UserMsg", "uid": 3, "input_ids": [7], ...},
# ]
# }
zmq_socket.send_json(json_data)

# ========== Scheduler 进程 ==========
# 7. 接收 Tokenizer 的消息
json_data = zmq_socket.recv_json()

# 8. 反序列化
batch_msg = BaseBackendMsg.decoder(json_data)
# 类型:BatchBackendMsg

# 9. 提取所有 UserMsg
user_msgs = batch_msg.data
# user_msgs[0]: UserMsg(uid=1, input_ids=[1, 2, 3], ...)
# user_msgs[1]: UserMsg(uid=2, input_ids=[4, 5, 6], ...)
# user_msgs[2]: UserMsg(uid=3, input_ids=[7], ...)

# 10. 拼成一个 batch(padding)
input_ids_list = [msg.input_ids for msg in user_msgs]
# [[1, 2, 3], [4, 5, 6], [7]]

# Padding 到相同长度
max_len = max(len(ids) for ids in input_ids_list) # 3
padded_input_ids = []
for ids in input_ids_list:
padded = torch.cat([ids, torch.zeros(max_len - len(ids), dtype=torch.int32)])
padded_input_ids.append(padded)

batch_input_ids = torch.stack(padded_input_ids).to("cuda")
# shape: [3, 3]
# [[1, 2, 3],
# [4, 5, 6],
# [7, 0, 0]] # padding

# 11. 批量推理
output_tokens = model.forward(batch_input_ids)
# shape: [3, 1]
# [[10],
# [11],
# [12]]

# 12. 分发结果
# uid=1 → token=10
# uid=2 → token=11
# uid=3 → token=12

9.4 关键设计

9.4.1 Frontend → Tokenizer 不用 BatchBackendMsg

为什么?

  • Frontend 发送的是原始文本,还没有分词
  • BatchBackendMsg 是用于 BaseBackendMsg 的批量
  • Frontend 只需要发送普通的 JSON 列表

9.4.2 Tokenizer → Scheduler 用 BatchBackendMsg

为什么?

  • Tokenizer 已经创建了 UserMsg 对象
  • 需要批量发送多个 UserMsg
  • 使用 BatchBackendMsg 包装

9.4.3 Scheduler 需要 Padding

为什么?

  • 3 个请求的长度不同:[1,2,3], [4,5,6], [7]
  • GPU 批量计算需要相同长度
  • 用 0 填充短的序列

Padding 示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 原始
input_ids_A = [1, 2, 3] # 长度 3
input_ids_B = [4, 5, 6] # 长度 3
input_ids_C = [7] # 长度 1

# Padding 后
input_ids_A = [1, 2, 3] # 长度 3
input_ids_B = [4, 5, 6] # 长度 3
input_ids_C = [7, 0, 0] # 长度 3(填充 0)

# 拼成 batch
batch_input_ids = torch.tensor([
[1, 2, 3],
[4, 5, 6],
[7, 0, 0],
]) # shape: [3, 3]

9.5 性能优化

9.5.1 批量通信

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# ❌ 非批量(3 次通信)
send(UserMsg(uid=1, input_ids=[1,2,3], ...)) # 10μs
send(UserMsg(uid=2, input_ids=[4,5,6], ...)) # 10μs
send(UserMsg(uid=3, input_ids=[7], ...)) # 10μs
# 总时间:30μs

# ✅ 批量(1 次通信)
send(BatchBackendMsg(data=[
UserMsg(uid=1, input_ids=[1,2,3], ...),
UserMsg(uid=2, input_ids=[4,5,6], ...),
UserMsg(uid=3, input_ids=[7], ...),
])) # 10μs
# 总时间:10μs
# 提速:3 倍

9.5.2 GPU 批量计算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# ❌ 非批量(3 次推理)
output_A = model.forward([1, 2, 3]) # 1ms
output_B = model.forward([4, 5, 6]) # 1ms
output_C = model.forward([7, 0, 0]) # 1ms
# 总时间:3ms

# ✅ 批量(1 次推理)
output = model.forward([
[1, 2, 3],
[4, 5, 6],
[7, 0, 0],
]) # 1.2ms(批量计算更高效)
# 总时间:1.2ms
# 提速:2.5 倍

9.6 完整流程图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
Frontend (3 个 HTTP 请求)
↓ (ZMQ: JSON 列表)
Tokenizer
↓ (分词)
UserMsg(uid=1, input_ids=[1,2,3])
UserMsg(uid=2, input_ids=[4,5,6])
UserMsg(uid=3, input_ids=[7])
↓ (包装)
BatchBackendMsg(data=[...])
↓ (ZMQ: JSON)
Scheduler
↓ (反序列化)
batch_msg.data = [UserMsg, UserMsg, UserMsg]
↓ (Padding)
batch_input_ids = [[1,2,3], [4,5,6], [7,0,0]]
↓ (GPU 批量推理)
output_tokens = [[10], [11], [12]]
↓ (分发结果)
uid=1 → token=10
uid=2 → token=11
uid=3 → token=12

费曼挑战(续)

挑战16:为什么需要序列化

问题:解释为什么需要序列化/反序列化。

解答

  • ZMQ 只能传输字节流:不能直接传输 Python 对象
  • 序列化:将 Python 对象转换为 JSON(Dict)
  • 反序列化:将 JSON 转换回 Python 对象
  • 作用:实现跨进程通信

流程

1
2
3
4
5
6
7
8
# 发送端
msg = UserMsg(uid=123, ...)
json_data = msg.encoder() # Python 对象 → JSON
zmq_socket.send_json(json_data) # JSON → 字节流

# 接收端
json_data = zmq_socket.recv_json() # 字节流 → JSON
msg = BaseBackendMsg.decoder(json_data) # JSON → Python 对象

挑战17:BatchBackendMsg 的作用

问题:解释 BatchBackendMsg 的作用和优势。

解答

  • 作用:批量发送多个消息
  • 优势
    • 减少通信次数(3 次 → 1 次)
    • 提高吞吐量(提速 3 倍)
    • 减少 ZMQ 开销

对比

1
2
3
# 非批量:3 次通信,30μs
# 批量:1 次通信,10μs
# 提速:3 倍

挑战18:为什么 input_ids 是 CPU Tensor

问题:解释为什么 UserMsg 的 input_ids 是 CPU Tensor。

解答

  • ZMQ 传输的是 CPU 数据:无法直接传输 GPU Tensor
  • GPU Tensor 无法序列化:需要先复制到 CPU
  • Scheduler 收到后会复制到 GPU:用于推理

流程

1
2
3
4
5
6
7
8
# Tokenizer
input_ids_cpu = torch.tensor([1, 2, 3], dtype=torch.int32) # CPU
user_msg = UserMsg(uid=123, input_ids=input_ids_cpu, ...)
send(user_msg)

# Scheduler
user_msg = recv()
input_ids_gpu = user_msg.input_ids.to("cuda") # CPU → GPU

挑战19:为什么需要 Padding

问题:解释为什么 Scheduler 需要 Padding。

解答

  • 问题:3 个请求的长度不同(3, 3, 1)
  • GPU 批量计算需要相同长度:Tensor 必须是矩形
  • 解决方案:用 0 填充短的序列

例子

1
2
3
4
5
# 原始(长度不同)
[[1, 2, 3], [4, 5, 6], [7]] # 无法拼成 Tensor

# Padding 后(长度相同)
[[1, 2, 3], [4, 5, 6], [7, 0, 0]] # 可以拼成 Tensor

挑战20:完整的消息流

问题:描述 3 个用户请求从 Frontend 到 Scheduler 的完整消息流。

解答

  1. Frontend → Tokenizer

    • 发送 JSON 列表(原始文本)
    • 不用 BatchBackendMsg
  2. Tokenizer 处理

    • 分词,创建 UserMsg
    • 包装成 BatchBackendMsg
  3. Tokenizer → Scheduler

    • 序列化 BatchBackendMsg
    • 通过 ZMQ 发送
  4. Scheduler 处理

    • 反序列化 BatchBackendMsg
    • 提取 UserMsg
    • Padding 到相同长度
    • 批量推理

关键

  • 批量通信(减少通信次数)
  • 批量推理(提高 GPU 利用率)
  • Padding(满足 GPU 批量计算要求)