Mini-Infer (16): 模型导入的核心 — ImporterContextOperatorRegistry

引言:从 Protobuf 到 Graph

上一篇中,我们实现了 OnnxParser,它能将 .onnx 文件反序列化为 Protobuf 对象。

但这仅仅是第一步。onnx::ModelProto 是一棵复杂的语法树,充满了 NodeInitializerValueInfo。我们需要一个强大的机制将这些“死数据”转化为 Mini-Infer 中“活的” Graph 对象。

本篇,我们将构建模型导入的两个核心组件:

  1. ImporterContext: 一个“共享黑板”,用于在导入过程中追踪所有的 Tensor 和权重,解决 ONNX 基于名字的连接问题。
  2. OperatorRegistry: 一个“算子工厂”,负责根据 ONNX 的 op_type(如 “Conv”)找到对应的导入逻辑。

1. ImporterContext: 连接一切的桥梁

ONNX 的图结构是基于名字 (String) 的,而 Mini-Infer 的图结构是基于指针 (Pointer) 的。

  • ONNX: Node A 输出 “X”,Node B 输入 “X”。
  • Mini-Infer: Node Aoutput_tensor 指针需要被传递给 Node B

ImporterContext 的首要职责就是维护这张名字 -> 指针的映射表。

核心接口设计 (importer_context.h)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class ImporterContext {
public:
// ...
// 1. Tensor 管理:名字 -> 指针
void register_tensor(const std::string& name, std::shared_ptr<core::Tensor> tensor);
std::shared_ptr<core::Tensor> get_tensor(const std::string& name);
bool has_tensor(const std::string& name) const;

// 2. 权重管理:Initializer 也是 Tensor,但需要特殊标记
void register_weight(const std::string& name, std::shared_ptr<core::Tensor> weight);
std::shared_ptr<core::Tensor> get_weight(const std::string& name);
bool is_weight(const std::string& name) const;

// 3. 图构建:向 Graph 添加节点
void add_node(std::shared_ptr<graph::Node> node);
// ...
};

实现细节 (operator_importer.cpp)

在实现中,我们将 weight 同时注册为 tensor。这样做的好处是,当算子需要输入时(无论是来自上一层的激活值,还是来自权重的常量),都可以统一调用 get_tensor() 获取,简化了逻辑。

1
2
3
4
5
void ImporterContext::register_weight(const std::string& name, std::shared_ptr<core::Tensor> weight) {
weights_[name] = weight;
// 同时注册到 tensors_ 表中,方便统一查找
register_tensor(name, weight);
}

2. OperatorRegistry: 插件化的算子导入

我们不希望在 ModelImporter 中写一个巨大的 if-else 来处理所有的 ONNX 算子。

我们借鉴 TensorRT 的设计,采用注册表模式。每一个 ONNX 算子都有一个独立的 Importer 类,并通过工厂函数注册到 OperatorRegistry 中。

抽象基类:OperatorImporter

所有算子导入器都必须继承自这个基类。它的核心是 import_operator 方法,负责将一个 ONNX NodeProto 转换为 Mini-Infer 的节点。

1
2
3
4
5
6
7
8
9
10
11
12
class OperatorImporter {
public:
virtual ~OperatorImporter() = default;

// 核心接口:导入单个节点
virtual core::Status import_operator(
ImporterContext& ctx,
const onnx::NodeProto& node
) = 0;

virtual const char* get_op_type() const = 0;
};

注册表实现:OperatorRegistry

注册表本质上是一个 std::unordered_map,将 op_type 字符串映射到 ImporterFactory(一个创建导入器实例的函数)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class OperatorRegistry {
public:
using ImporterFactory = std::function<std::unique_ptr<OperatorImporter>()>;

// 注册算子
void register_operator(const std::string& op_type, ImporterFactory factory) {
importers_[op_type] = factory;
}

// 获取导入器
std::unique_ptr<OperatorImporter> get_importer(const std::string& op_type) {
auto it = importers_.find(op_type);
if (it != importers_.end()) {
return it->second(); // 调用工厂函数创建实例
}
return nullptr;
}
// ...
private:
std::unordered_map<std::string, ImporterFactory> importers_;
};

注册宏:REGISTER_ONNX_OPERATOR

为了简化注册过程,我们提供了一个宏。这使得在未来的 builtin_operators.cpp 中注册算子变得非常简洁。

1
2
3
4
#define REGISTER_ONNX_OPERATOR(op_type, importer_class) \
registry.register_operator(op_type, []() -> std::unique_ptr<OperatorImporter> { \
return std::make_unique<importer_class>(); \
})

3. 自动加载内置算子

OperatorRegistry 的构造函数会自动调用 register_builtin_operators()。这是一个全局函数(将在未来的博客中实现),用于批量注册所有内置的算子导入器(如 Conv, Relu, Flatten)。

1
2
3
4
5
// operator_importer.cpp
void OperatorRegistry::register_builtin_operators() {
// 调用全局注册函数,将 *this (registry) 传进去
::mini_infer::importers::register_builtin_operators(*this);
}

4. 总结

本篇我们构建了 ONNX 导入器的“神经系统”:

  1. ImporterContext:解决了 ONNX 名字到 Mini-Infer 指针的映射难题,并统一管理了 Tensor 和 Weight。
  2. OperatorRegistry:实现了算子导入逻辑的解耦。添加新算子支持不再需要修改核心代码,只需编写新的 Importer 并注册即可。

这套架构具有极高的可扩展性。未来如果我们需要支持更多的 ONNX 算子,或者支持自定义算子,只需要扩充注册表即可。