Mini-Infer (14): 迈向 ONNX — Flatten 算子与零拷贝视图

1. 为什么需要 Flatten

在 CNN 网络(如 LeNet-5, VGG)中,数据流通常是这样的: Conv/Pool (4D Tensor) -> Flatten -> Linear (2D Matrix)

Linear 层(全连接层)通常期望输入是一个二维矩阵 [Batch, Features]。而卷积层的输出是四维张量 [Batch, Channel, Height, Width]

Flatten 的作用就是把 [N, C, H, W] “拍扁” 成 [N, C*H*W]

2. 算子定义:对齐 ONNX 标准

ONNX 对 Flatten 的定义非常灵活:它有一个 axis 参数。

  • 输入:张量 T
  • 参数:axis (默认为 1)
  • 输出:一个 2D 张量。
    • 维度 0:输入张量从维度 0 到 axis-1 的乘积。
    • 维度 1:输入张量从维度 axis 到最后的乘积。

举例:输入 [2, 3, 4, 5]axis=1

  • 输出维度 0:2 (只有维度 0)
  • 输出维度 1:3*4*5 = 60
  • 结果:[2, 60]

你的 FlattenParaminfer_shape 完美地实现了这一逻辑:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// mini_infer/operators/flatten.cpp

core::Status Flatten::infer_shape(...) {
// ...
// Calculate output shape: [dim_0 * ... * dim_{axis-1}, dim_axis * ... * dim_n]
int64_t first_dim = 1;
for (int i = 0; i < axis; ++i) {
first_dim *= dims[i];
}

int64_t second_dim = 1;
for (size_t i = axis; i < dims.size(); ++i) {
second_dim *= dims[i];
}

output_shapes.push_back(core::Shape({first_dim, second_dim}));
// ...
}

3. 核心技术:Zero-Copy View (零拷贝视图)

Flatten 算子有一个特殊的性质:它不改变数据,只改变数据的“解释方式”(Shape)

如果在 Flatten::forward 中,我们申请一块新内存,然后把输入数据 memcpy 过去,那就是巨大的浪费!

Mini-Infer 的解决方案:Tensor::view

我们需要在 Tensor 类中支持一种机制,创建一个共享同一块内存但拥有不同 Shape 的新 Tensor 对象。

(注:这需要你在 core/tensor.h 中添加 view 方法)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// 假设的 Tensor::view 实现 (在 core/tensor.cpp 中)
std::shared_ptr<Tensor> Tensor::view(const Shape& new_shape) {
if (new_shape.numel() != this->shape_.numel()) {
return nullptr; // 元素总数必须一致
}

// 创建一个新 Tensor
auto new_tensor = std::make_shared<Tensor>();
new_tensor->shape_ = new_shape;
new_tensor->dtype_ = this->dtype_;

// 【关键】共享 data_ 指针!
// shared_ptr 的拷贝构造函数会增加引用计数
new_tensor->data_ = this->data_;

return new_tensor;
}

有了这个机制,Flatten::forward 变得极其高效:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// mini_infer/operators/flatten.cpp

core::Status Flatten::forward(...) {
// ... 计算 output_shapes ...

// Flatten is just a view change - create a view with different shape
// This is a ZERO-COPY operation! (shares the same underlying data)
auto output = input->view(output_shapes[0]);

if (!output) return core::Status::ERROR_INVALID_ARGUMENT;

outputs.clear();
outputs.push_back(output);
return core::Status::SUCCESS;
}

4. 总结

Flatten 算子的实现虽然代码量不大,但意义重大:

  1. 架构完善:它推动了 Tensor 类支持 View 语义(零拷贝 reshape),这是高性能推理框架的标配(PyTorch/NumPy 也是如此)。
  2. ONNX 兼容:它严格遵循 ONNX Flatten 算子的定义,为后续的 ONNX Parser 开发扫清了障碍。
  3. 连接层级:它打通了卷积层(4D)和全连接层(2D)之间的任督二脉,使得我们可以自动构建 LeNet-5 网络,而不再需要手动写 reshape 代码。