Mini-Infer (17): 深入字节流 — WeightImporter 与权重加载

1. ONNX 的数据存储格式

ONNX 在存储权重时有两种模式:

  1. Raw Data (二进制流):这是最常用、最高效的模式。所有数据被打包成一个字节流 (std::string) 存储在 raw_data 字段中。这需要我们进行 memcpy
  2. Typed Data (类型化数组):对于较小的张量,ONNX 可能会直接使用 Protobuf 的重复字段(如 float_data, int32_data)。这需要我们遍历并逐个赋值。

WeightImporter 必须能无缝处理这两种情况。

2. WeightImporter 核心逻辑

A. 数据类型转换

首先,我们需要将 ONNX 的数据类型(onnx::TensorProto::FLOAT)映射到 Mini-Infer 的类型(core::DataType::FLOAT32)。

1
2
3
4
5
6
core::DataType WeightImporter::convert_data_type(int onnx_dtype, std::string& error_message) {
switch (onnx_dtype) {
case onnx::TensorProto::FLOAT: return core::DataType::FLOAT32;
// ... 其他类型 ...
}
}

B. import_tensor: 总入口

这是 WeightImporter 的主函数。它遵循以下步骤:

  1. 提取 Shape:从 tensor_proto.dims() 中读取维度。
  2. 创建 Tensor:根据 Shape 和 DataType,在 Mini-Infer 中创建一个新的 Tensor 对象(这会分配内存)。
  3. 分发加载逻辑:判断是 raw_data 还是 typed_data,调用对应的辅助函数。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
std::shared_ptr<core::Tensor> WeightImporter::import_tensor(...) {
// ... 解析 Shape 和 DataType ...

auto tensor = std::make_shared<core::Tensor>(shape, dtype);

if (tensor_proto.has_raw_data()) {
// 模式 1: Raw Data (高效)
import_raw_data(tensor_proto, tensor->data(), expected_size);
} else {
// 模式 2: Typed Data (兼容)
import_typed_data(tensor_proto, tensor->data(), dtype, num_elements);
}

return tensor;
}

C. import_raw_data: 内存拷贝

对于 raw_data,我们只需要一次 std::memcpy。这是加载大模型(如 ResNet, BERT)时性能的关键。

1
2
3
4
5
6
bool WeightImporter::import_raw_data(...) {
const std::string& raw = tensor_proto.raw_data();
if (raw.size() != expected_size) return false;
std::memcpy(data, raw.data(), raw.size());
return true;
}

D. import_typed_data: 类型分发

对于 typed_data,我们需要根据 DataType 访问 Protobuf 中不同的字段(如 float_dataint32_data)。

1
2
3
4
5
6
7
8
9
10
bool WeightImporter::import_typed_data(...) {
switch (dtype) {
case core::DataType::FLOAT32:
for (int i = 0; i < size; ++i) {
float_ptr[i] = tensor_proto.float_data(i);
}
return true;
// ... 其他类型 ...
}
}

3. 为什么这很重要?

WeightImporter 看似简单,但它是模型加载过程中最容易出错的地方。

  • 大小端问题raw_data 通常假定是小端序(Little Endian)。
  • 内存对齐:直接 memcpyvoid* 指针通常是安全的,但如果涉及 SIMD 优化,可能需要考虑对齐问题。

通过将这些复杂的底层逻辑封装在 WeightImporter 中,我们的上层 ModelImporter 代码变得异常干净和健壮。


4. 总结与展望

现在,我们有了:

  1. Parser: 读取文件。
  2. Context: 管理状态。
  3. Registry: 分发算子。
  4. WeightImporter: 加载数据。

我们的 ONNX 导入器已经具备了处理数据的能力。但是,它还不知道如何处理算子。它能读懂权重,但读不懂 Conv 节点里的 stridespads 属性。