Mini-Infer (16): 模型导入的核心 — `ImporterContext` 与 `OperatorRegistry`
Mini-Infer (16): 模型导入的核心 — ImporterContext 与 OperatorRegistry
引言:从 Protobuf 到 Graph
在上一篇中,我们实现了 OnnxParser,它能将 .onnx 文件反序列化为 Protobuf 对象。
但这仅仅是第一步。onnx::ModelProto 是一棵复杂的语法树,充满了 Node、Initializer 和 ValueInfo。我们需要一个强大的机制将这些“死数据”转化为 Mini-Infer 中“活的” Graph 对象。
本篇,我们将构建模型导入的两个核心组件:
ImporterContext: 一个“共享黑板”,用于在导入过程中追踪所有的 Tensor 和权重,解决 ONNX 基于名字的连接问题。OperatorRegistry: 一个“算子工厂”,负责根据 ONNX 的op_type(如 “Conv”)找到对应的导入逻辑。
1. ImporterContext: 连接一切的桥梁
ONNX 的图结构是基于名字 (String) 的,而 Mini-Infer 的图结构是基于指针 (Pointer) 的。
- ONNX:
Node A输出 “X”,Node B输入 “X”。 - Mini-Infer:
Node A的output_tensor指针需要被传递给Node B。
ImporterContext 的首要职责就是维护这张名字 -> 指针的映射表。
核心接口设计 (importer_context.h)
1 | class ImporterContext { |
实现细节 (operator_importer.cpp)
在实现中,我们将 weight 同时注册为 tensor。这样做的好处是,当算子需要输入时(无论是来自上一层的激活值,还是来自权重的常量),都可以统一调用 get_tensor() 获取,简化了逻辑。
1 | void ImporterContext::register_weight(const std::string& name, std::shared_ptr<core::Tensor> weight) { |
2. OperatorRegistry: 插件化的算子导入
我们不希望在 ModelImporter 中写一个巨大的 if-else 来处理所有的 ONNX 算子。
我们借鉴 TensorRT 的设计,采用注册表模式。每一个 ONNX 算子都有一个独立的 Importer 类,并通过工厂函数注册到 OperatorRegistry 中。
抽象基类:OperatorImporter
所有算子导入器都必须继承自这个基类。它的核心是 import_operator 方法,负责将一个 ONNX NodeProto 转换为 Mini-Infer 的节点。
1 | class OperatorImporter { |
注册表实现:OperatorRegistry
注册表本质上是一个 std::unordered_map,将 op_type 字符串映射到 ImporterFactory(一个创建导入器实例的函数)。
1 | class OperatorRegistry { |
注册宏:REGISTER_ONNX_OPERATOR
为了简化注册过程,我们提供了一个宏。这使得在未来的 builtin_operators.cpp 中注册算子变得非常简洁。
1 |
3. 自动加载内置算子
OperatorRegistry 的构造函数会自动调用 register_builtin_operators()。这是一个全局函数(将在未来的博客中实现),用于批量注册所有内置的算子导入器(如 Conv, Relu, Flatten)。
1 | // operator_importer.cpp |
4. 总结
本篇我们构建了 ONNX 导入器的“神经系统”:
ImporterContext:解决了 ONNX 名字到Mini-Infer指针的映射难题,并统一管理了 Tensor 和 Weight。OperatorRegistry:实现了算子导入逻辑的解耦。添加新算子支持不再需要修改核心代码,只需编写新的 Importer 并注册即可。
这套架构具有极高的可扩展性。未来如果我们需要支持更多的 ONNX 算子,或者支持自定义算子,只需要扩充注册表即可。





