1
2
3
4
5
6
7
8
9
10
11
12
title: 'Mini-Infer (19): 内置算子导入实战 — Conv, ReLU, Pooling 与 Flatten'
tags:
- AI Infra
categories:
- AI Infra
- Mini-Infer
cover: 'https://imgs.james-blog.top/imgs/dd34fa5844ac016d966f0adde655cee7.png'
toc: true
mathjax: true
abbrlink: 112
date: 2025-12-03 22:40:00
description: 'Mini-Infer (19): 内置算子导入实战 — Conv, ReLU, Pooling 与 Flatten'

Mini-Infer (19): 内置算子导入实战 — Conv, ReLU, Pooling 与 Flatten

1. 核心逻辑:OperatorImporter 的标准范式

在实现具体算子之前,我们总结出一套标准的导入流程(范式):

  1. 解析属性 (Parse Attributes):使用 AttributeHelper 获取 strides, pads 等参数。
  2. 获取输入 (Get Inputs):从 ImporterContext 中查找输入 Tensor。如果是权重(Initializer),需要特殊处理。
  3. 创建算子 (Create Operator):构建 Mini-InferOperator 对象(如 Conv2D)。
  4. 构建节点 (Build Node):创建 Graph::Node,并将 Operator 绑定上去。
  5. 连接边 (Connect Edges):将输入 Tensor 连接到当前 Node。
  6. 注册输出 (Register Output):创建输出 Tensor 并注册到 Context,供后续节点使用。

所有算子的导入器都遵循这个“六步走”战略。

2. 算子实战:Conv (最复杂的例子)

ConvImporter 是最典型的代表。它不仅参数多,而且需要处理权重(Input 1)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// mini_infer/importers/builtin_operators.cpp -> ConvImporter

core::Status ConvImporter::import_operator(ImporterContext& ctx, const onnx::NodeProto& node) {
AttributeHelper attrs(node);

// 1. 解析属性:kernel_shape, strides, pads...
auto kernel_shape = attrs.get_ints("kernel_shape");
auto strides = to_int_vector(attrs.get_ints("strides"));
// ...

// 2. 获取输入:Input (X), Weight (W), Bias (B)
std::string input_name = node.input(0);
// ... 检查 Tensor 是否存在 ...

// 3. 构建 Conv2DParam
operators::Conv2DParam param;
param.kernel_h = kernel_shape[0];
param.stride_h = strides[0];
// ...

// 4. 创建 Operator 和 Node
auto op = std::make_shared<operators::Conv2D>(param);
auto graph_node = ctx.get_graph()->create_node(node.output(0));
graph_node->set_operator(op);

// 5. 连接输入 & 6. 注册输出 ...
return core::Status::SUCCESS;
}

3. 算子实战:Gemm (矩阵乘法)

ONNX 的 Gemm 算子定义非常通用:Y = alpha * A * B + beta * C。 但在推理中,最常见的情况是全连接层:Y = A * B^T + C (transB=1)。

我们的 GemmImporter 做了一个聪明的映射:如果检测到它是全连接模式,就将其映射为 Mini-InferLinear 算子。

1
2
3
4
5
6
7
8
9
10
11
12
13
// GemmImporter
int64_t transB = attrs.get_int("transB", 0);

if (transA == 0 && transB == 1 && alpha == 1.0f) {
// 映射为 Linear 算子
operators::LinearParam param;
param.use_bias = use_bias;
auto op = std::make_shared<operators::Linear>(param);
// ...
} else {
// 复杂的 Gemm 暂不支持,或者可以映射为通用的 MatMul
return core::Status::ERROR_NOT_IMPLEMENTED;
}

4. 算子实战:Reshape 作为 Flatten

这是一个有趣的特例。在 LeNet-5 的 ONNX 导出中,通常看不到 Flatten 算子,取而代之的是 Reshape

Reshape 算子通常有两个输入:datashape。 但在 LeNet-5 中,这个 Reshape 的作用就是把 [N, 16, 4, 4] 变成 [N, 256]

为了简化实现,我们在 ReshapeImporter 中做了一个“投机取巧”的处理:如果遇到 Reshape,我们暂且把它当做 Flatten(axis=1) 来处理(针对 LeNet-5 优化)。

(注:在完善的框架中,应该实现通用的 Reshape 算子,读取 shape 输入并执行 view 操作。但在这里,为了快速跑通 LeNet-5,这种映射是允许的。)

1
2
3
4
5
// ReshapeImporter
ctx.log_info("Reshape operator -> treat as Flatten(axis=1) for LeNet-5");

operators::FlattenParam param(1); // axis=1
auto op = std::make_shared<operators::Flatten>(param);

5. 注册中心:register_builtin_operators

最后,我们将所有编写好的 Importer 类注册到全局。

1
2
3
4
5
6
void register_builtin_operators(OperatorRegistry& registry) {
REGISTER_ONNX_OPERATOR("Conv", ConvImporter);
REGISTER_ONNX_OPERATOR("Relu", ReluImporter);
REGISTER_ONNX_OPERATOR("MaxPool", MaxPoolImporter);
// ...
}

6. 总结与里程碑

至此,我们现在拥有的能力:

  1. Core: 完备的 Tensor/Memory 系统。
  2. Engine: 高性能的推理引擎。
  3. Operators: 实现了 Conv, ReLU, Pool, Linear, Flatten。
  4. Parser: 能够读取 .onnx 文件,解析权重,并将上述算子自动组装成计算图。