Overview
This article takes a close look at Mini-SGLang's model-loading architecture, contrasts it with the traditional approach (e.g. HuggingFace Transformers), and examines the design philosophy of separating structure from weights and the advantages that separation brings.
Key takeaway: Mini-SGLang does not build the model dynamically by parsing a config file. Instead, it defines the network structure by hand first, then loads the weights into it. This design gives the framework a high degree of customizability and room for optimization.
1. Supported Models
1.1 Currently Supported
Mini-SGLang currently supports two model families:
1. Llama family
- Llama-2 (7B, 13B, 70B)
- Llama-3 (8B, 70B)
- Llama-3.1 (8B, 70B, 405B)
- Llama-3.2
- Other Llama-architecture variants (Mistral, Yi, etc.)
2. Qwen3 family
- Qwen3-0.6B
- Qwen3-14B
- Qwen3-32B
1.2 Model Dispatch Logic

```python
def create_model(model_path: str, model_config: ModelConfig) -> BaseLLMModel:
    model_name = model_path.lower()
    if "llama" in model_name:
        from .llama import LlamaForCausalLM
        return LlamaForCausalLM(model_config)
    elif "qwen3" in model_name:
        from .qwen3 import Qwen3ForCausalLM
        return Qwen3ForCausalLM(model_config)
    else:
        raise ValueError(f"Unsupported model: {model_path}")
```
Key points:
- Dispatches on the model path name (simple and direct)
- Easy to extend (adding a new model only takes one more elif)
- Supported parameter scales: 0.6B - 405B
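The dispatch rule can be illustrated without any real model code. The sketch below uses a hypothetical `create_model_name` helper (not a Mini-SGLang API) that returns the class name instead of constructing the model, to show how the substring match behaves:

```python
def create_model_name(model_path: str) -> str:
    # Mirror the create_model dispatch: match on the lowered path
    model_name = model_path.lower()
    if "llama" in model_name:
        return "LlamaForCausalLM"
    elif "qwen3" in model_name:
        return "Qwen3ForCausalLM"
    raise ValueError(f"Unsupported model: {model_path}")

print(create_model_name("Qwen/Qwen3-32B"))           # → Qwen3ForCausalLM
print(create_model_name("meta-llama/Llama-3.1-8B"))  # → LlamaForCausalLM
```

Because the match is on the path string, any checkout directory whose name contains "llama" or "qwen3" also dispatches correctly.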
2. Traditional Approach vs. Mini-SGLang
2.1 Traditional Approach (HuggingFace Transformers)

```python
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-32B")
```
Key points:
- ✅ Simple to use, works out of the box
- ❌ A black box that is hard to customize
- ❌ The structure is generated dynamically from the config file
- ❌ No way to optimize the structure before loading
- ❌ High memory usage (one copy on CPU plus one on GPU)
2.2 Mini-SGLang Approach (Structure and Weights Separated)

```python
hf_config = AutoConfig.from_pretrained("Qwen/Qwen3-32B")
model_config = ModelConfig.from_hf(hf_config)

with torch.device("meta"):
    model = Qwen3ForCausalLM(model_config)

state_dict = load_hf_weight(model_path, device)
model.load_state_dict(state_dict)
```
Key points:
- ✅ Structure and weights are fully decoupled
- ✅ The structure can be optimized before the weights are loaded
- ✅ Custom weight-loading logic is possible
- ✅ Transparent and controllable, easy to debug
- ✅ Low memory usage (meta device)
3. The Complete Model Loading Flow
3.1 Flow Diagram

```
User command
    ↓
python -m minisgl --model "Qwen/Qwen3-32B" --tp 4
    ↓
┌─────────────────────────────────────────┐
│ 1. Entry point: __main__.py             │
│    → launch_server()                    │
└─────────────────────────────────────────┘
    ↓
┌─────────────────────────────────────────┐
│ 2. Parse arguments: server/args.py      │
│    → ServerArgs                         │
│    - model_path: "Qwen/Qwen3-32B"       │
│    - tp_size: 4                         │
│    - dtype: torch.float16               │
└─────────────────────────────────────────┘
    ↓
┌─────────────────────────────────────────┐
│ 3. Spawn processes: server/launch.py    │
│    - 4 Scheduler processes (TP 0-3)     │
│    - 1 Detokenizer process              │
│    - N Tokenizer processes              │
└─────────────────────────────────────────┘
    ↓
┌─────────────────────────────────────────┐
│ 4. Initialize the Scheduler             │
│    → Engine(config)                     │
└─────────────────────────────────────────┘
```
3.2 Configuration Loading

```python
@dataclass(frozen=True)
class EngineConfig:
    model_path: str
    tp_info: DistributedInfo
    dtype: torch.dtype

    @cached_property
    def hf_config(self):
        return cached_load_hf_config(self.model_path)

    @cached_property
    def model_config(self) -> ModelConfig:
        return ModelConfig.from_hf(self.hf_config)
```
Step 1: Load the HuggingFace config

```python
@lru_cache()
def _load_config(model_path: str):
    from transformers import AutoConfig
    return AutoConfig.from_pretrained(model_path)

def cached_load_hf_config(model_path: str) -> LlamaConfig:
    config = _load_config(model_path)
    # Return a fresh copy so callers cannot mutate the cached instance
    return type(config)(**config.to_dict())
```
- Only downloads config.json (a few KB)
- Uses @lru_cache to avoid repeated loading
- Returns a HuggingFace config object (as a fresh copy)
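The copy step in `cached_load_hf_config` matters because `lru_cache` hands back the same object on every call, so a caller that mutates it would corrupt the cache. A minimal stand-alone illustration of the pattern (`_load_settings` / `load_settings` are hypothetical stand-ins, not Mini-SGLang APIs):

```python
from functools import lru_cache

@lru_cache()
def _load_settings(name: str) -> dict:
    # Stands in for an expensive load (e.g. AutoConfig.from_pretrained)
    return {"name": name, "layers": 4}

def load_settings(name: str) -> dict:
    # Return a copy so callers cannot mutate the cached object
    return dict(_load_settings(name))

a = load_settings("demo")
a["layers"] = 99          # mutate the copy only
b = load_settings("demo")
print(b["layers"])        # → 4: the cached original is untouched
```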
Step 2: Convert the config format

```python
@dataclass(frozen=True)
class ModelConfig:
    num_layers: int
    num_qo_heads: int
    num_kv_heads: int
    head_dim: int
    hidden_size: int
    vocab_size: int
    intermediate_size: int
    rms_norm_eps: float
    rotary_config: RotaryConfig
    hidden_act: str
    tie_word_embeddings: bool

    @classmethod
    def from_hf(cls, config: LlamaConfig) -> "ModelConfig":
        num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        return cls(
            num_layers=config.num_hidden_layers,
            num_qo_heads=config.num_attention_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            hidden_size=config.hidden_size,
            vocab_size=config.vocab_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            rms_norm_eps=config.rms_norm_eps,
            tie_word_embeddings=getattr(config, "tie_word_embeddings", False),
            rotary_config=RotaryConfig(
                head_dim=head_dim,
                rotary_dim=head_dim,
                max_position=config.max_position_embeddings,
                base=config.rope_theta,
                scaling=getattr(config, "rope_scaling", None),
            ),
        )
```
Why convert?
- The HuggingFace config contains many fields that are irrelevant to inference
- Mini-SGLang only needs the inference-related parameters
- A unified config format is easier to work with downstream
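The conversion boils down to picking a handful of fields out of a much larger config and freezing them. A toy sketch of the idea (`TinyModelConfig` and the raw dict are illustrative, not the real `ModelConfig`):

```python
from dataclasses import dataclass

# Hypothetical raw HF-style config: many fields, most irrelevant to inference
hf_config = {
    "num_hidden_layers": 64,
    "hidden_size": 5120,
    "num_attention_heads": 64,
    "torch_dtype": "bfloat16",       # irrelevant to the structural config
    "transformers_version": "4.44",  # irrelevant
}

@dataclass(frozen=True)
class TinyModelConfig:
    num_layers: int
    hidden_size: int
    head_dim: int

# Keep only what inference needs; derive head_dim the same way from_hf does
cfg = TinyModelConfig(
    num_layers=hf_config["num_hidden_layers"],
    hidden_size=hf_config["hidden_size"],
    head_dim=hf_config["hidden_size"] // hf_config["num_attention_heads"],
)
print(cfg.head_dim)  # → 80
```

The frozen dataclass also guarantees the config cannot drift after the model structure has been built from it.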
3.3 Building the Model Structure

```python
def __init__(self, config: EngineConfig):
    set_rope_device(self.device)
    with torch.device("meta"), torch_dtype(config.dtype):
        self.model = create_model(config.model_path, config.model_config)
```

**What is a Meta Device?**

```python
# Normal construction: parameters are allocated immediately
model = Qwen3ForCausalLM(config)

# Meta-device construction: only shapes and dtypes are recorded,
# no memory is allocated
with torch.device("meta"):
    model = Qwen3ForCausalLM(config)
```
Advantages of the meta device:
- Saves memory: no need to hold one copy on the CPU and another on the GPU
- Faster loading: skips the CPU → GPU copy
- Enables large models: can set up models that exceed a single GPU's memory
Model structure construction:

```python
class Qwen3ForCausalLM(BaseLLMModel):
    def __init__(self, config: ModelConfig):
        self.model = Qwen3Model(config)
        self.lm_head = ParallelLMHead(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
            tie_word_embeddings=config.tie_word_embeddings,
            tied_embedding=self.model.embed_tokens if config.tie_word_embeddings else None,
        )
        super().__init__()

class Qwen3Model(BaseOP):
    def __init__(self, config: ModelConfig):
        self.embed_tokens = VocabParallelEmbedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
        )
        self.layers = OPList([
            Qwen3DecoderLayer(config, layer_id)
            for layer_id in range(config.num_layers)
        ])
        self.norm = RMSNormFused(
            size=config.hidden_size,
            eps=config.rms_norm_eps,
        )

class Qwen3DecoderLayer(BaseOP):
    def __init__(self, config: ModelConfig, layer_id: int):
        self.input_layernorm = RMSNormFused(
            size=config.hidden_size,
            eps=config.rms_norm_eps,
        )
        self.self_attn = Qwen3Attn(config, layer_id, has_qk_norm=True)
        self.post_attention_layernorm = RMSNormFused(
            size=config.hidden_size,
            eps=config.rms_norm_eps,
        )
        self.mlp = Qwen3MLP(config)
```
Network structure (Qwen3-32B as an example):

```
Qwen3ForCausalLM
├─ embed_tokens: VocabParallelEmbedding
│    vocab_size=151936, hidden_size=5120
│    TP split: each GPU stores 151936/4 = 37984 vocab entries
│
├─ layers: OPList[Qwen3DecoderLayer] × 64
│    Each layer contains:
│    ├─ input_layernorm: RMSNormFused
│    ├─ self_attn: Qwen3Attn (RopeAttn)
│    │    ├─ qkv_proj: LinearQKVMerged
│    │    │    hidden=5120 → q=5120, k=1024, v=1024
│    │    │    TP split: column parallel, 1/4 per GPU
│    │    ├─ q_norm, k_norm: RMSNorm (Qwen3-specific)
│    │    ├─ attn: AttentionLayer
│    │    │    RoPE + FlashAttention/FlashInfer
│    │    └─ o_proj: LinearOProj
│    │         5120 → 5120
│    │         TP split: row parallel + All-Reduce
│    │
│    ├─ post_attention_layernorm: RMSNormFused
│    └─ mlp: Qwen3MLP (GatedMLP)
│         ├─ gate_up_proj: LinearColParallelMerged
│         │    5120 → [27648, 27648]
│         │    TP split: column parallel, 1/4 per GPU
│         ├─ act_fn: silu_and_mul
│         └─ down_proj: LinearRowParallel
│              27648 → 5120
│              TP split: row parallel + All-Reduce
│
├─ norm: RMSNormFused (final LayerNorm)
│
└─ lm_head: ParallelLMHead
     hidden_size=5120 → vocab_size=151936
     TP split: each GPU produces 37984 logits
```
Key points:
- At this stage every parameter lives on the "meta" device
- The structure is fully determined
- But the weights are still empty (uninitialized)
3.4 Weight Loading

```python
def _load_weight_state_dict(self, config: EngineConfig) -> Dict[str, torch.Tensor]:
    if config.use_dummy_weight:
        return {
            k: torch.randn_like(v, device=self.device)
            for k, v in self.model.state_dict().items()
        }
    else:
        return {
            k: v.to(self.dtype)
            for k, v in load_hf_weight(config.model_path, self.device).items()
        }

self.model.load_state_dict(self._load_weight_state_dict(config))
```
The complete weight-loading flow:

```python
def load_hf_weight(model_path: str, device: torch.device) -> Dict[str, torch.Tensor]:
    if os.path.isdir(model_path):
        hf_folder = model_path
    else:
        hf_folder = snapshot_download(
            model_path,
            allow_patterns=["*.safetensors"],
        )
    files = glob.glob(f"{hf_folder}/*.safetensors")

    state_dict: Dict[str, torch.Tensor] = {}
    for file in sorted(files):
        with safetensors.safe_open(file, framework="pt", device="cpu") as f:
            for name in f.keys():
                state_dict[name] = f.get_tensor(name)

    if get_tp_info().size > 1:
        state_dict = _shard_state_dict(state_dict)
    state_dict = {k: v.to(device) for k, v in state_dict.items()}
    return _merge_state_dict(state_dict)
```
4. Weight Optimization: TP Sharding and Merging
4.1 TP Weight Sharding

```python
def _shard_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    shard_state_dict: Dict[str, torch.Tensor] = {}
    tp_info = get_tp_info()
    r = tp_info.rank
    n = tp_info.size

    # Column-parallel weights: split along dim 0 (out_features)
    SPLIT_DIM_0_LIST = [
        ".q_proj", ".k_proj", ".v_proj",
        ".gate_proj", ".up_proj",
    ]
    # Row-parallel weights: split along dim 1 (in_features)
    SPLIT_DIM_1_LIST = [
        ".o_proj", ".down_proj",
    ]

    for key, value in state_dict.items():
        if any(key.count(sub) for sub in SPLIT_DIM_0_LIST):
            shard_state_dict[key] = value.chunk(n, dim=0)[r]
        elif any(key.count(sub) for sub in SPLIT_DIM_1_LIST):
            shard_state_dict[key] = value.chunk(n, dim=1)[r]
        elif key.count("lm_head") or key.count("embed_tokens"):
            # Vocab-parallel weights: slice a contiguous vocab range
            num_embeddings = value.shape[0]
            num_embeddings_per_partition = divide_up(num_embeddings, n)
            vocab_start_idx = r * num_embeddings_per_partition
            vocab_end_idx = min((r + 1) * num_embeddings_per_partition, num_embeddings)
            shard_state_dict[key] = value[vocab_start_idx:vocab_end_idx, :]
        else:
            # Norms, biases, etc. are replicated on every rank
            shard_state_dict[key] = value
    return shard_state_dict
```
TP sharding example (tp_size=4):

```
q_proj.weight: [5120, 5120]
  GPU 0: q_proj.weight: [1280, 5120]
  GPU 1: q_proj.weight: [1280, 5120]
  GPU 2: q_proj.weight: [1280, 5120]
  GPU 3: q_proj.weight: [1280, 5120]

o_proj.weight: [5120, 5120]
  GPU 0: o_proj.weight: [5120, 1280]
  GPU 1: o_proj.weight: [5120, 1280]
  GPU 2: o_proj.weight: [5120, 1280]
  GPU 3: o_proj.weight: [5120, 1280]
```
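The shard shapes follow directly from which dimension is split. A small hypothetical helper (not part of Mini-SGLang) makes the arithmetic explicit:

```python
def shard_shape(shape, n, dim):
    """Shape of each rank's shard when splitting `shape` evenly into n parts along `dim`."""
    out = list(shape)
    out[dim] = shape[dim] // n
    return tuple(out)

print(shard_shape((5120, 5120), 4, 0))    # → (1280, 5120)  column parallel (q_proj)
print(shard_shape((5120, 5120), 4, 1))    # → (5120, 1280)  row parallel (o_proj)
print(shard_shape((151936, 5120), 4, 0))  # → (37984, 5120) vocab parallel (embed_tokens)
```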
Why shard this way?

```
Column parallel (q_proj, gate_proj, ...):
  Y = X @ W
  Split W by columns:        W = [W_0 | W_1 | W_2 | W_3]
  Each GPU computes:         Y_i = X @ W_i
  Concatenate the results:   Y = [Y_0 | Y_1 | Y_2 | Y_3]

Row parallel (o_proj, down_proj, ...):
  Y = X @ W
  Split W by rows:           W = [W_0; W_1; W_2; W_3]
  Each GPU slices the input: X_i = X[:, i*M/4 : (i+1)*M/4]
  Each GPU computes:         Y_i = X_i @ W_i
  All-Reduce sums partials:  Y = Y_0 + Y_1 + Y_2 + Y_3
```
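Both identities above can be checked numerically. The pure-Python sketch below plays out both schemes on a tiny 1×4 input and a 4×4 weight with two "GPUs":

```python
def matmul(A, B):
    # Naive dense matmul over lists of rows
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

X = [[1.0, 2.0, 3.0, 4.0]]                                    # [1, 4]
W = [[float(4 * i + j) for j in range(4)] for i in range(4)]  # [4, 4]
Y = matmul(X, W)                                              # reference result

# Column parallel: split W into column blocks, concatenate partial outputs
W0 = [row[:2] for row in W]
W1 = [row[2:] for row in W]
Y_col = [matmul(X, W0)[0] + matmul(X, W1)[0]]                 # concat along columns

# Row parallel: split W into row blocks, slice X to match, sum partials
Wa, Wb = W[:2], W[2:]
Xa = [X[0][:2]]
Xb = [X[0][2:]]
Ya, Yb = matmul(Xa, Wa), matmul(Xb, Wb)
Y_row = [[a + b for a, b in zip(Ya[0], Yb[0])]]               # "All-Reduce" = elementwise sum

print(Y == Y_col == Y_row)  # → True
```

Note the asymmetry: column parallel needs only a concatenation of outputs, while row parallel needs a sum, which is why the latter maps onto an All-Reduce.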
4.2 Weight Merging

```python
def _merge_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    filtered_state_dict: Dict[str, torch.Tensor] = {}
    for key in list(state_dict.keys()):
        if key.count(".q_proj"):
            # Merge q/k/v into a single qkv projection
            q_proj = state_dict[key]
            k_proj = state_dict[key.replace(".q_proj", ".k_proj")]
            v_proj = state_dict[key.replace(".q_proj", ".v_proj")]
            new_key = key.replace(".q_proj", ".qkv_proj")
            filtered_state_dict[new_key] = torch.cat([q_proj, k_proj, v_proj], dim=0)
            del state_dict[key]
            del state_dict[key.replace(".q_proj", ".k_proj")]
            del state_dict[key.replace(".q_proj", ".v_proj")]
        elif key.count(".gate_proj"):
            # Merge gate/up into a single gate_up projection
            gate_proj = state_dict[key]
            up_proj = state_dict[key.replace(".gate_proj", ".up_proj")]
            new_key = key.replace(".gate_proj", ".gate_up_proj")
            filtered_state_dict[new_key] = torch.cat([gate_proj, up_proj], dim=0)
            del state_dict[key]
            del state_dict[key.replace(".gate_proj", ".up_proj")]
        elif key.count(".k_proj") or key.count(".v_proj") or key.count(".up_proj"):
            # Already consumed by one of the merges above
            continue
        else:
            filtered_state_dict[key] = state_dict[key]
    return filtered_state_dict
```
Advantages of weight merging:

```python
# Before merging: three separate GEMMs
# q_proj.weight: [5120, 5120]
# k_proj.weight: [1024, 5120]
# v_proj.weight: [1024, 5120]
q = x @ q_proj.weight
k = x @ k_proj.weight
v = x @ v_proj.weight

# After merging: one larger GEMM plus a cheap split
# qkv_proj.weight: [7168, 5120]
qkv = x @ qkv_proj.weight
q, k, v = qkv.split([5120, 1024, 1024], dim=-1)
```
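That the merged GEMM is mathematically equivalent can be verified on toy matrices. The sketch below concatenates hypothetical q/k/v weights along the output dimension and splits the fused result back (pure Python, weights stored as [in, out] for readability):

```python
def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

x = [[1.0, 2.0]]                  # [1, hidden], hidden=2
Wq = [[1.0, 0.0], [0.0, 1.0]]     # [hidden, q_out=2]
Wk = [[2.0], [3.0]]               # [hidden, k_out=1]
Wv = [[4.0], [5.0]]               # [hidden, v_out=1]

# Separate projections: three matmuls
q, k, v = matmul(x, Wq)[0], matmul(x, Wk)[0], matmul(x, Wv)[0]

# Merged projection: concatenate weights along the output dim, one matmul, then split
W_qkv = [wq + wk + wv for wq, wk, wv in zip(Wq, Wk, Wv)]
qkv = matmul(x, W_qkv)[0]
q2, k2, v2 = qkv[:2], qkv[2:3], qkv[3:]

print((q, k, v) == (q2, k2, v2))  # → True
```

One large GEMM generally utilizes the GPU better than three small ones and launches a single kernel instead of three, which is the whole point of the merge.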
5. Why Separate Structure from Weights?
5.1 Enables Custom Optimizations
Traditional approach (difficult):

```python
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-32B")

# To merge QKV you must patch the already-built model, layer by layer
for layer in model.model.layers:
    q = layer.self_attn.q_proj.weight
    k = layer.self_attn.k_proj.weight
    v = layer.self_attn.v_proj.weight
    qkv = nn.Linear(...)
    qkv.weight = nn.Parameter(torch.cat([q, k, v], dim=0))
    layer.self_attn.qkv_proj = qkv
    del layer.self_attn.q_proj
    del layer.self_attn.k_proj
    del layer.self_attn.v_proj
```
Mini-SGLang approach (simple):

```python
# The structure is defined merged from the start...
class Qwen3Attn(BaseOP):
    def __init__(self, config, layer_id):
        self.qkv_proj = LinearQKVMerged(...)
        self.attn = AttentionLayer(...)
        self.o_proj = LinearOProj(...)

# ...and the weights are merged at load time
def _merge_state_dict(state_dict):
    ...
```
5.2 Memory Efficiency
Traditional approach (high memory usage):

```python
# Weights materialize on the CPU first, then get copied to the GPU
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-32B")
```
Mini-SGLang approach (low memory usage):

```python
# Structure on the meta device: no memory allocated
with torch.device("meta"):
    model = Qwen3ForCausalLM(config)

# Weights go straight to their target device
state_dict = load_hf_weight(model_path, device)
model.load_state_dict(state_dict)
```
Savings: 96 GB → 32 GB (a 67% reduction)
5.3 Automatic TP Sharding
Traditional approach (manual):

```python
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-32B")

# Shard every parameter by hand after the fact
for name, param in model.named_parameters():
    if "q_proj" in name or "k_proj" in name or "v_proj" in name:
        param.data = param.data.chunk(tp_size, dim=0)[tp_rank]
    elif "o_proj" in name or "down_proj" in name:
        param.data = param.data.chunk(tp_size, dim=1)[tp_rank]
```
Mini-SGLang approach (automatic):

```python
# Sharding happens inside the loader (_shard_state_dict)
state_dict = load_hf_weight(model_path, device)
model.load_state_dict(state_dict)
```
5.4 Flexible Attention Implementations
Traditional approach (fixed implementation):

```python
# The attention implementation is whatever transformers ships
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-32B")
```
Mini-SGLang approach (pluggable backends):

```python
class Qwen3Attn(BaseOP):
    def __init__(self, config, layer_id):
        self.qkv_proj = LinearQKVMerged(...)
        if config.attention_backend == "flashinfer":
            self.attn = FlashInferAttention(...)
        elif config.attention_backend == "flashattention":
            self.attn = FlashAttention(...)
        else:
            self.attn = NaiveAttention(...)
        self.o_proj = LinearOProj(...)
```
6. Effort Required to Support a New Model
6.1 Scenario 1: Fully Compatible Architecture (e.g. Mistral)
Effort: about 10 minutes

```python
# models/mistral.py — reuse the Llama implementation
from .llama import LlamaForCausalLM as MistralForCausalLM

# create_model: add one branch
elif "mistral" in model_name:
    from .mistral import MistralForCausalLM
    return MistralForCausalLM(model_config)
```
6.2 Scenario 2: Minor Differences (e.g. Qwen3's QK Norm)
Effort: about half a day

```python
class Qwen3Attn(BaseOP):
    def __init__(self, config, layer_id):
        self.qkv_proj = LinearQKVMerged(...)
        # Qwen3-specific: per-head RMSNorm on q and k
        self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps)
        self.k_norm = RMSNorm(head_dim, eps=config.rms_norm_eps)
        self.attn = AttentionLayer(
            ...,
            q_norm=self.q_norm,
            k_norm=self.k_norm,
        )
        self.o_proj = LinearOProj(...)
```
6.3 Scenario 3: Moderate Differences (e.g. Falcon's MQA)
Effort: 1-2 days

```python
class FalconAttn(BaseOP):
    def __init__(self, config, layer_id):
        # Multi-Query Attention: a single shared KV head
        assert config.num_kv_heads == 1
        self.qkv_proj = LinearQKVMerged(
            ...,
            num_kv_heads=1,
        )
        self.attn = AttentionLayer(...)
        self.o_proj = LinearOProj(...)
```
6.4 Scenario 4: Major Differences (e.g. BLOOM's ALiBi)
Effort: 3-5 days

```python
class ALiBiPositionEncoding(BaseOP):
    def __init__(self, num_heads):
        self.slopes = self._get_slopes(num_heads)

    def forward(self, q, k):
        ...

class BLOOMAttn(BaseOP):
    def __init__(self, config, layer_id):
        self.qkv_proj = LinearQKVMerged(...)
        self.alibi = ALiBiPositionEncoding(config.num_qo_heads)
        self.attn = AttentionLayer(
            ...,
            position_encoding=self.alibi,
        )
        self.o_proj = LinearOProj(...)
```
7. Summary
7.1 Design Philosophy
Mini-SGLang's core design principles:
- Separate structure from weights: define the structure first, then load the weights
- Explicit over implicit: build the model by hand instead of generating it dynamically
- Optimize up front: apply optimizations (TP sharding, merging, quantization) at weight-loading time
- Transparent and controllable: every step is visible, easy to debug and customize
7.2 Key Advantages

| Dimension | Traditional | Mini-SGLang |
| --- | --- | --- |
| Ease of use | ⭐⭐⭐⭐⭐ | ⭐⭐⭐ |
| Customizability | ⭐⭐ | ⭐⭐⭐⭐⭐ |
| Memory efficiency | ⭐⭐ | ⭐⭐⭐⭐⭐ |
| Performance optimization | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ |
| TP support | ⭐⭐ | ⭐⭐⭐⭐⭐ |
| Transparency | ⭐⭐ | ⭐⭐⭐⭐⭐ |