Mini-Infer (18): 编排导入流程 — ModelImporterAttributeHelper

引言:从架构到实现

在之前的博客中,我们设计了 OnnxParser 的顶层入口,定义了 ImporterContextOperatorRegistry 的接口,并实现了 WeightImporter 来解析数据。

现在,我们需要将这些组件真正“运转”起来。

本篇,我们将实现两个核心组件:

  1. AttributeHelper:一个极其使用的工具类,用于解决 ONNX Protobuf 属性访问繁琐的问题。
  2. ModelImporter:整个导入过程的“总指挥”。它负责按照正确的顺序(权重 -> 输入 -> 节点 -> 输出)编排导入流程,并将解析任务分发给注册表。

1. AttributeHelper:优雅地解析属性

ONNX 的 NodeProto 使用 Key-Value 的列表来存储算子属性(如卷积的 strides, pads)。使用原生的 Protobuf API 来查找和读取这些属性非常啰嗦且容易出错。

我们需要一个包装器来简化这个过程。

封装痛点

在原生 Protobuf 中,获取一个名为 kernel_shape 的属性可能需要写十几行代码来遍历 attribute 列表、检查名字、检查类型。

优雅实现

AttributeHelper 将这些逻辑封装在内部,对外提供了极其简洁的接口。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// mini_infer/importers/attribute_helper.cpp

AttributeHelper::AttributeHelper(const onnx::NodeProto& node) : node_(node) {}

// 核心查找逻辑
const onnx::AttributeProto* AttributeHelper::find_attribute(const std::string& name) const {
for (int i = 0; i < node_.attribute_size(); ++i) {
const auto& attr = node_.attribute(i);
if (attr.name() == name) {
return &attr;
}
}
return nullptr;
}

// 类型安全的访问器 (带默认值)
int64_t AttributeHelper::get_int(const std::string& name, int64_t default_value) const {
const auto* attr = find_attribute(name);
if (attr && attr->has_i()) {
return attr->i();
}
return default_value;
}

std::vector<int64_t> AttributeHelper::get_ints(const std::string& name) const {
const auto* attr = find_attribute(name);
if (attr && attr->ints_size() > 0) {
return std::vector<int64_t>(attr->ints().begin(), attr->ints().end());
}
return {};
}

有了这个工具,未来的算子导入代码将变得非常干净: int stride = attrs.get_int("stride", 1);


2. ModelImporter:导入流程的总指挥

ModelImporter 是解析器的核心引擎。它的 import_model 方法负责协调所有的资源,确保图构建的正确性。

我们将详细分析 import_graph 的四个关键步骤。

步骤 1:导入 Initializers (权重)

权重必须最先导入,因为后续算子(如 Conv)在创建时就需要访问它们。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// mini_infer/importers/model_importer.cpp

core::Status ModelImporter::import_initializers(...) {
for (int i = 0; i < graph_proto.initializer_size(); ++i) {
const auto& tensor_proto = graph_proto.initializer(i);
const std::string& name = tensor_proto.name();

// 调用我们之前实现的 WeightImporter
std::string error_msg;
auto tensor = WeightImporter::import_tensor(tensor_proto, error_msg);

if (!tensor) { /* 错误处理 */ }

// 注册到 Context 中,标记为 Weight
ctx.register_weight(name, tensor);
}
return core::Status::SUCCESS;
}

步骤 2:导入 Inputs & Outputs

这一步主要是在 Context 中注册占位符 Tensor。

  • Inputs: 遍历 graph.input。如果一个输入名字不是权重(!ctx.is_weight(name)),那它就是模型的真实输入,我们需要创建一个空的 Tensor 并注册。
  • Outputs: 遍历 graph.output,同样注册空的 Tensor 占位符。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
// mini_infer/importers/model_importer.cpp
core::Status ModelImporter::import_inputs(...) {
for (int i = 0; i < graph_proto.input_size(); ++i) {
const std::string& name = graph_proto.input(i).name();
// 如果是权重,跳过(因为 Initializer 已经处理过了)
if (ctx.is_weight(name)) continue;

if (!ctx.has_tensor(name)) {
auto input_tensor = std::make_shared<core::Tensor>();
ctx.register_tensor(name, input_tensor);
}
}
return core::Status::SUCCESS;
}

步骤 3:导入 Nodes (核心循环)

这是将 ONNX NodeProto 转换为 Mini-Infer 节点的关键步骤。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// mini_infer/importers/model_importer.cpp

core::Status ModelImporter::import_node(const onnx::NodeProto& node, ImporterContext& ctx) {
const std::string& op_type = node.op_type();

// 1. 检查是否支持
if (!registry_->is_supported(op_type)) {
return core::Status::ERROR_NOT_IMPLEMENTED;
}

// 2. 获取对应的 Importer (从 Registry 中)
auto importer = registry_->get_importer(op_type);

// 3. 执行导入 (委托给具体的 Importer)
return importer->import_operator(ctx, node);
}

步骤 4:图的收尾工作

在所有节点导入完成后,我们需要确保 Graph 对象的状态是完整的:

  1. 创建节点占位符:对于 Graph 的 Input/Output 名字,如果在导入过程中没有对应的 Node 生成(虽然这很少见,但为了鲁棒性),我们需要手动调用 graph->create_node()
  2. 设置 Graph 输入输出:调用 graph->set_inputs()graph->set_outputs(),确立图的边界。
1
2
3
4
5
6
7
8
9
10
11
12
// mini_infer/importers/model_importer.cpp :: import_graph 尾部

// Set graph inputs and outputs
{
std::vector<std::string> input_names;
// ... 收集 input names ...
ctx.get_graph()->set_inputs(input_names);

std::vector<std::string> output_names;
// ... 收集 output names ...
ctx.get_graph()->set_outputs(output_names);
}

3. 总结

本篇我们完成了 Mini-Infer 模型导入功能的逻辑闭环

  1. AttributeHelper:为后续的具体算子导入提供了极其便利的工具,让我们不用再和复杂的 Protobuf API 纠缠。
  2. ModelImporter:实现了严谨的导入流程控制。它像一个流水线管理员,确保了权重、输入、算子按照正确的顺序被解析和注册。

至此,我们的 ONNX Parser 已经具备了处理完整模型文件的能力——除了具体算子的转换逻辑