Mini-Infer (14): 迈向 ONNX — `Flatten` 算子与零拷贝视图
Mini-Infer (14): 迈向 ONNX — Flatten 算子与零拷贝视图
1. 为什么需要 Flatten?
在 CNN 网络(如 LeNet-5, VGG)中,数据流通常是这样的: Conv/Pool (4D Tensor) -> Flatten -> Linear (2D Matrix)
Linear 层(全连接层)通常期望输入是一个二维矩阵 [Batch, Features]。而卷积层的输出是四维张量 [Batch, Channel, Height, Width]。
Flatten 的作用就是把 [N, C, H, W] “拍扁” 成 [N, C*H*W]。
2. 算子定义:对齐 ONNX 标准
ONNX 对 Flatten 的定义非常灵活:它有一个 axis 参数。
- 输入:张量
T - 参数:
axis(默认为 1) - 输出:一个 2D 张量。
- 维度 0:输入张量从维度 0 到
axis-1的乘积。 - 维度 1:输入张量从维度
axis到最后的乘积。
- 维度 0:输入张量从维度 0 到
举例:输入 [2, 3, 4, 5],axis=1。
- 输出维度 0:
2(只有维度 0) - 输出维度 1:
3*4*5 = 60 - 结果:
[2, 60]
你的 FlattenParam 和 infer_shape 完美地实现了这一逻辑:
1 | // mini_infer/operators/flatten.cpp |
3. 核心技术:Zero-Copy View (零拷贝视图)
Flatten 算子有一个特殊的性质:它不改变数据,只改变数据的“解释方式”(Shape)。
如果在 Flatten::forward 中,我们申请一块新内存,然后把输入数据 memcpy 过去,那就是巨大的浪费!
Mini-Infer 的解决方案:Tensor::view
我们需要在 Tensor 类中支持一种机制,创建一个共享同一块内存但拥有不同 Shape 的新 Tensor 对象。
(注:这需要你在 core/tensor.h 中添加 view 方法)
1 | // 假设的 Tensor::view 实现 (在 core/tensor.cpp 中) |
有了这个机制,Flatten::forward 变得极其高效:
1 | // mini_infer/operators/flatten.cpp |
4. 总结
Flatten 算子的实现虽然代码量不大,但意义重大:
- 架构完善:它推动了
Tensor类支持 View 语义(零拷贝 reshape),这是高性能推理框架的标配(PyTorch/NumPy 也是如此)。 - ONNX 兼容:它严格遵循 ONNX
Flatten算子的定义,为后续的ONNX Parser开发扫清了障碍。 - 连接层级:它打通了卷积层(4D)和全连接层(2D)之间的任督二脉,使得我们可以自动构建 LeNet-5 网络,而不再需要手动写
reshape代码。
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 James的成长之路!
评论





