Mini-Infer (17): 深入字节流 — `WeightImporter` 与权重加载
Mini-Infer (17): 深入字节流 — WeightImporter 与权重加载
1. ONNX 的数据存储格式
ONNX 在存储权重时有两种模式:
- Raw Data (二进制流):这是最常用、最高效的模式。所有数据被打包成一个字节流 (
std::string) 存储在raw_data字段中。这需要我们进行memcpy。 - Typed Data (类型化数组):对于较小的张量,ONNX 可能会直接使用 Protobuf 的重复字段(如
float_data,int32_data)。这需要我们遍历并逐个赋值。
WeightImporter 必须能无缝处理这两种情况。
2. WeightImporter 核心逻辑
A. 数据类型转换
首先,我们需要将 ONNX 的数据类型(onnx::TensorProto::FLOAT)映射到 Mini-Infer 的类型(core::DataType::FLOAT32)。
1 | core::DataType WeightImporter::convert_data_type(int onnx_dtype, std::string& error_message) { |
B. import_tensor: 总入口
这是 WeightImporter 的主函数。它遵循以下步骤:
- 提取 Shape:从
tensor_proto.dims()中读取维度。 - 创建 Tensor:根据 Shape 和 DataType,在
Mini-Infer中创建一个新的Tensor对象(这会分配内存)。 - 分发加载逻辑:判断是
raw_data还是typed_data,调用对应的辅助函数。
1 | std::shared_ptr<core::Tensor> WeightImporter::import_tensor(...) { |
C. import_raw_data: 内存拷贝
对于 raw_data,我们只需要一次 std::memcpy。这是加载大模型(如 ResNet, BERT)时性能的关键。
1 | bool WeightImporter::import_raw_data(...) { |
D. import_typed_data: 类型分发
对于 typed_data,我们需要根据 DataType 访问 Protobuf 中不同的字段(如 float_data 或 int32_data)。
1 | bool WeightImporter::import_typed_data(...) { |
3. 为什么这很重要?
WeightImporter 看似简单,但它是模型加载过程中最容易出错的地方。
- 大小端问题:
raw_data通常假定是小端序(Little Endian)。 - 内存对齐:直接
memcpy到void*指针通常是安全的,但如果涉及 SIMD 优化,可能需要考虑对齐问题。
通过将这些复杂的底层逻辑封装在 WeightImporter 中,我们的上层 ModelImporter 代码变得异常干净和健壮。
4. 总结与展望
现在,我们有了:
- Parser: 读取文件。
- Context: 管理状态。
- Registry: 分发算子。
- WeightImporter: 加载数据。
我们的 ONNX 导入器已经具备了处理数据的能力。但是,它还不知道如何处理算子。它能读懂权重,但读不懂 Conv 节点里的 strides 和 pads 属性。
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 James的成长之路!
评论





