无题
1 | title: 'Mini-Infer (19): 内置算子导入实战 — Conv, ReLU, Pooling 与 Flatten' |
Mini-Infer (19): 内置算子导入实战 — Conv, ReLU, Pooling 与 Flatten
1. 核心逻辑:OperatorImporter 的标准范式
在实现具体算子之前,我们总结出一套标准的导入流程(范式):
- 解析属性 (Parse Attributes):使用
AttributeHelper获取 strides, pads 等参数。 - 获取输入 (Get Inputs):从
ImporterContext中查找输入 Tensor。如果是权重(Initializer),需要特殊处理。 - 创建算子 (Create Operator):构建
Mini-Infer的Operator对象(如Conv2D)。 - 构建节点 (Build Node):创建
Graph::Node,并将 Operator 绑定上去。 - 连接边 (Connect Edges):将输入 Tensor 连接到当前 Node。
- 注册输出 (Register Output):创建输出 Tensor 并注册到 Context,供后续节点使用。
所有算子的导入器都遵循这个“六步走”战略。
2. 算子实战:Conv (最复杂的例子)
ConvImporter 是最典型的代表。它不仅参数多,而且需要处理权重(Input 1)。
1 | // mini_infer/importers/builtin_operators.cpp -> ConvImporter |
3. 算子实战:Gemm (矩阵乘法)
ONNX 的 Gemm 算子定义非常通用:Y = alpha * A * B + beta * C。 但在推理中,最常见的情况是全连接层:Y = A * B^T + C (transB=1)。
我们的 GemmImporter 做了一个聪明的映射:如果检测到它是全连接模式,就将其映射为 Mini-Infer 的 Linear 算子。
1 | // GemmImporter |
4. 算子实战:Reshape 作为 Flatten
这是一个有趣的特例。在 LeNet-5 的 ONNX 导出中,通常看不到 Flatten 算子,取而代之的是 Reshape。
Reshape 算子通常有两个输入:data 和 shape。 但在 LeNet-5 中,这个 Reshape 的作用就是把 [N, 16, 4, 4] 变成 [N, 256]。
为了简化实现,我们在 ReshapeImporter 中做了一个“投机取巧”的处理:如果遇到 Reshape,我们暂且把它当做 Flatten(axis=1) 来处理(针对 LeNet-5 优化)。
(注:在完善的框架中,应该实现通用的 Reshape 算子,读取 shape 输入并执行 view 操作。但在这里,为了快速跑通 LeNet-5,这种映射是允许的。)
1 | // ReshapeImporter |
5. 注册中心:register_builtin_operators
最后,我们将所有编写好的 Importer 类注册到全局。
1 | void register_builtin_operators(OperatorRegistry& registry) { |
6. 总结与里程碑
至此,我们现在拥有的能力:
- Core: 完备的 Tensor/Memory 系统。
- Engine: 高性能的推理引擎。
- Operators: 实现了 Conv, ReLU, Pool, Linear, Flatten。
- Parser: 能够读取
.onnx文件,解析权重,并将上述算子自动组装成计算图。
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 James的成长之路!
评论
