Mini-Infer (6): 点亮引擎!实现 infer_shape, ReLU 与 GEMM 抽象
本篇,我们将真正“闭合” Engine 的执行循环。为此,我们必须完成两项核心任务:
实现 infer_shape :这是 Engine 进行“静态内存规划”的钥匙。
实现 forward :编写第一个 Operator(ReLU)的 CPU 计算代码。
我们还将实现一个更复杂的 Linear(全连接)算子,并引出一个全新的、为性能而生的架构层:Kernel 抽象 。
1. 缺失的环节:infer_shape 与内存预分配
在第 5 篇中,我们的 Engine::build() 流水线卡在了 allocate_tensors()。Engine 不知道 Convolution 的输出是多大,也不知道 Linear 的输出是多大。
Operator 基类中的 infer_shape 纯虚函数就是为此而生的“合约”。它要求每个算子必须 有能力“只通过输入的 *Shape*,就计算出输出的 *Shape*” 。
ReLU 是最简单的例子:它不改变形状。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 core::Status ReLU::infer_shape ( const std::vector<core::Shape>& input_shapes, std::vector<core::Shape>& output_shapes) { if (input_shapes.size () != 1 ) { return core::Status::ERROR_INVALID_ARGUMENT; } output_shapes.clear (); output_shapes.push_back (input_shapes[0 ]); return core::Status::SUCCESS; }
Linear(全连接)则更复杂,它必须 改变形状:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 core::Status Linear::infer_shape ( const std::vector<core::Shape>& input_shapes, std::vector<core::Shape>& output_shapes) { const auto & input_shape = input_shapes[0 ]; const auto & weight_shape = input_shapes[1 ]; int in_features = static_cast <int >(input_shape[input_shape.ndim () - 1 ]); int out_features = static_cast <int >(weight_shape[0 ]); std::vector<int64_t > output_dims; for (size_t i = 0 ; i < input_shape.ndim () - 1 ; ++i) { output_dims.push_back (input_shape[i]); } output_dims.push_back (static_cast <int64_t >(out_features)); output_shapes.clear (); output_shapes.push_back (core::Shape (output_dims)); return core::Status::SUCCESS; }
Engine 如何使用它?
现在,Engine::allocate_tensors() 终于可以被实现了(逻辑如下):
遍历 sorted_nodes_(拓扑排序好的节点)。
对于 Input 节点,跳过(因为 Tensor 由用户提供)。
对于其他节点(如 Conv1):
从它的上游节点(node->inputs())收集 Shape。
调用 node->get_operator()->infer_shape(...)。
Engine 得到 Conv1 的输出 Shape。
Engine 立即 为 Conv1 创建(core::Tensor::create(shape, dtype))所有 output_tensors。
进入下一个节点(如 ReLU1),ReLU1 的 infer_shape 会使用 Conv1 刚计算出的 Shape。
循环结束时,图中所有中间 Tensor 都已被预先分配好 。
infer_shape 的实现,正式达成了我们在 Blog 5 中定下的**“静态内存规划”**目标。
2. “Hello, World!”:实现第一个算子 ReLU
有了 infer_shape,我们还需要 forward(真正的计算)。ReLU 是一个完美的“Hello, World”算子。
ReLU::forward 的实现展示了 Operator 的标准执行流程:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 core::Status ReLU::forward ( const std::vector<std::shared_ptr<core::Tensor>>& inputs, std::vector<std::shared_ptr<core::Tensor>>& outputs) { if (inputs.size () != 1 || !inputs[0 ]) { return core::Status::ERROR_INVALID_ARGUMENT; } const auto & input = inputs[0 ]; auto output = core::Tensor::create (input->shape (), input->dtype ()); size_t total_elements = static_cast <size_t >(input->shape ().numel ()); if (input->dtype () == core::DataType::FLOAT32) { const float * input_data = static_cast <const float *>(input->data ()); float * output_data = static_cast <float *>(output->data ()); for (size_t i = 0 ; i < total_elements; ++i) { output_data[i] = std::max (0.0f , input_data[i]); } } else { return core::Status::ERROR_INVALID_ARGUMENT; } outputs.clear (); outputs.push_back (output); return core::Status::SUCCESS; }
最后,我们使用在 Blog 3 中设计的“自注册”宏,来“激活”这个算子:
1 2 REGISTER_OPERATOR (ReLU, ReLU);
在程序启动时,REGISTER_OPERATOR 宏会自动创建一个全局变量,其构造函数会调用 OperatorFactory::register_operator("ReLU", ...)。
至此,我们的 Engine 终于能真正地 创建并执行一个 ReLU 节点了!
3. 新的抽象层:Kernel 与 Linear 算子
ReLU 是“内存带宽”密集型算子,一个 for 循环就够了。但 Linear(全连接层)是“计算”密集型算子,它的核心是GEMM (通用矩阵乘法) :output = input @ weight^T。
我们不应该 在 Linear::forward 中直接写一个三层 for 循环的 GEMM。
为什么? 因为 GEMM 是整个框架的性能瓶颈 。我们未来需要用 AVX2、OpenBLAS、cuBLAS (CUDA) 来替换它。如果把实现写死在 Linear::forward 中,我们将无法进行这种优化。
因此,我们引入一个新的、更底层的抽象:Kernel 。
gemm.h:Kernel 抽象层
Kernel 层不是 一个 Operator。它是一个具体的“计算函数”的调度器。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 namespace mini_infer {namespace kernels {enum class KernelBackend { CPU, CPU_AVX2, CPU_BLAS, CUDA_CUBLAS }; class GEMMKernel {public : template <typename T> static void gemm_nt ( const T* A, const T* B, T* C, int M, int N, int K, KernelBackend backend = KernelBackend::CPU ) ; static KernelBackend get_best_backend () ; }; } }
这个设计将 Linear 算子(负责“业务逻辑”)与 GEMMKernel(负责“高性能计算”)彻底解耦 。
gemm_cpu.cpp:第一个 CPU Kernel 实现
我们提供了 GEMM 的一个“朴素” CPU 实现。注意 gemm_nt_impl,它实现了 C = A @ B^T。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 namespace mini_infer {namespace kernels {namespace cpu {template <typename T>void gemm_nt_impl ( const T* A, const T* B, T* C, int M, int N, int K) { std::memset (C, 0 , sizeof (T) * M * N); for (int m = 0 ; m < M; ++m) { const T* a_row = A + m * K; T* c_row = C + m * N; for (int n = 0 ; n < N; ++n) { const T* b_row = B + n * K; T sum = 0 ; for (int k = 0 ; k < K; ++k) { sum += a_row[k] * b_row[k]; } c_row[n] = sum; } } } } template <typename T>void GEMMKernel::gemm_nt (...) { cpu::gemm_nt_impl (A, B, C, M, N, K); } } }
Linear::forward:调用 Kernel
现在,Linear::forward 的实现变得非常清晰。它的职责不是计算 ,而是**“编排”**:
验证输入 (input, weight, bias)。
从 Shape 中计算出 M, N, K。
创建输出 Tensor。
调用 GEMMKernel 。
(可选)调用 add_bias 辅助函数。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 core::Status Linear::forward ( const std::vector<std::shared_ptr<core::Tensor>>& inputs, std::vector<std::shared_ptr<core::Tensor>>& outputs) { int batch_size = ...; int in_features = ...; int out_features = ...; auto output = core::Tensor::create (output_shape, input->dtype ()); if (input->dtype () == core::DataType::FLOAT32) { const float * input_data = ...; const float * weight_data = ...; float * output_data = ...; kernels::GEMMKernel::gemm_nt <float >( input_data, weight_data, output_data, batch_size, out_features, in_features ); if (param_.use_bias) { const float * bias_data = ...; add_bias (output_data, bias_data, batch_size, out_features); } } outputs.clear (); outputs.push_back (output); return core::Status::SUCCESS; }
总结与展望
Mini-Infer 活了 !
我们终于填补了 Engine 的最后两个 //TODO。
infer_shape 被实现,Engine::allocate_tensors()(静态内存规划)现在成为可能。
forward 被实现,我们通过 REGISTER_OPERATOR 宏向工厂提供了 ReLU 和 Linear 两个算子。
我们还引入了一个全新的**Kernel 抽象层**,将“算子逻辑”与“底层优化”解耦,为未来的 AVX2 和 CUDA 优化铺平了道路。