Mini-Infer (33): 插件架构 (中) — CRTP 基类与静态多态
1. CRTP 模式回顾
CRTP(Curiously Recurring Template Pattern,奇异递归模板模式)是 C++ 中一种强大的设计模式,用于实现静态多态。
A. 基本形式
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| template <typename Derived> class Base { public: void interface() { static_cast<Derived*>(this)->implementation(); } };
class Derived : public Base<Derived> { public: void implementation() { } };
|
B. 静态多态 vs 动态多态
| 特性 |
动态多态 (virtual) |
静态多态 (CRTP) |
| 绑定时机 |
运行时 |
编译时 |
| 虚函数表 |
需要 |
不需要 |
| 性能开销 |
有(间接调用) |
无(内联) |
| 灵活性 |
高(运行时决定) |
低(编译时决定) |
| 代码膨胀 |
无 |
有(模板实例化) |
在 Mini-Infer 中,我们结合使用两种多态:
- IPlugin 接口:使用虚函数,支持运行时多态(注册表查找)。
- CPUPlugin/CUDAPlugin 基类:使用 CRTP,消除样板代码。
2. CPUPlugin<Derived, ParamType> 设计
A. 类定义
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
|
template <typename Derived, typename ParamType = PluginParam> class CPUPlugin : public IPlugin { public: CPUPlugin() = default; ~CPUPlugin() override = default;
core::DeviceType get_device_type() const noexcept override { return core::DeviceType::CPU; }
std::unique_ptr<IPlugin> clone() const override { auto cloned = std::make_unique<Derived>(static_cast<const Derived&>(*this)); return cloned; }
const void* get_param_ptr() const noexcept override { return param_ ? param_.get() : nullptr; }
void set_param(std::shared_ptr<PluginParam> param) override { param_ = std::dynamic_pointer_cast<ParamType>(param); }
const ParamType* param() const { return param_.get(); }
void set_typed_param(const ParamType& param) { param_ = std::make_shared<ParamType>(param); }
protected: std::shared_ptr<ParamType> param_; };
|
B. get_device_type() 的固定实现
1 2 3
| core::DeviceType get_device_type() const noexcept override { return core::DeviceType::CPU; }
|
所有继承 CPUPlugin 的类都自动返回 CPU,无需重复实现。
C. clone() 的 CRTP 魔法
1 2 3 4
| std::unique_ptr<IPlugin> clone() const override { auto cloned = std::make_unique<Derived>(static_cast<const Derived&>(*this)); return cloned; }
|
这是 CRTP 的核心价值:
Derived 是模板参数,在编译时已知。
static_cast<const Derived&>(*this) 将基类引用转换为派生类引用。
std::make_unique<Derived>(...) 调用派生类的拷贝构造函数。
没有 CRTP 的话,每个派生类都需要手动实现 clone():
1 2 3 4 5 6 7 8 9 10 11 12
| class ReLUCPUPlugin : public IPlugin { std::unique_ptr<IPlugin> clone() const override { return std::make_unique<ReLUCPUPlugin>(*this); } };
class Conv2DCPUPlugin : public IPlugin { std::unique_ptr<IPlugin> clone() const override { return std::make_unique<Conv2DCPUPlugin>(*this); } };
|
D. 参数管理
1 2 3 4 5 6 7 8 9 10 11
| void set_param(std::shared_ptr<PluginParam> param) override { param_ = std::dynamic_pointer_cast<ParamType>(param); }
const ParamType* param() const { return param_.get(); }
void set_typed_param(const ParamType& param) { param_ = std::make_shared<ParamType>(param); }
|
- set_param:接受基类指针,动态转换为具体类型。
- param():返回类型安全的参数指针。
- set_typed_param:直接设置具体类型的参数。
3. SimpleCPUPlugin 简化版
对于无参数的算子(如 ReLU),我们提供更简单的基类:
1 2 3 4 5 6
| template <typename Derived> class SimpleCPUPlugin : public CPUPlugin<Derived, PluginParam> { public: SimpleCPUPlugin() = default; ~SimpleCPUPlugin() override = default; };
|
使用示例:
1 2 3
| class ReLUCPUPlugin : public SimpleCPUPlugin<ReLUCPUPlugin> { };
|
4. CUDAPlugin<Derived, ParamType> 设计
A. 类定义
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
|
template <typename Derived, typename ParamType = PluginParam> class CUDAPlugin : public IPlugin { public: CUDAPlugin() = default; ~CUDAPlugin() override = default;
core::DeviceType get_device_type() const noexcept override { return core::DeviceType::CUDA; }
std::unique_ptr<IPlugin> clone() const override { auto cloned = std::make_unique<Derived>(static_cast<const Derived&>(*this)); return cloned; }
const void* get_param_ptr() const noexcept ov ... } void set_param(std::shared_ptr<PluginParam> param) override { ... } const ParamType* param() const { ... } void set_typed_param(const ParamType& param) { ... }
#ifdef MINI_INFER_USE_CUDA static cudaStream_t get_cuda_stream(const PluginContext& context); #endif
protected: std::shared_ptr<ParamType> param_; };
|
B. get_cuda_stream() 辅助方法
1 2 3 4 5 6 7 8 9 10
| #ifdef MINI_INFER_USE_CUDA static cudaStream_t get_cuda_stream(const PluginContext& context) { if (!context.device_context) { return nullptr; } auto* cuda_ctx = dynamic_cast<backends::cuda::CUDADeviceContext*>( context.device_context); return cuda_ctx ? cuda_ctx->stream() : nullptr; } #endif
|
这个辅助方法让派生类可以方便地获取 CUDA Stream。
C. cuDNN/cuBLAS Handle 访问
类似地,可以添加获取 cuDNN 和 cuBLAS Handle 的辅助方法:
1 2 3 4 5 6 7 8 9 10 11
| static cudnnHandle_t get_cudnn_handle(const PluginContext& context) { auto* cuda_ctx = dynamic_cast<backends::cuda::CUDADeviceContext*>( context.device_context); return cuda_ctx ? cuda_ctx->cudnn_handle() : nullptr; }
static cublasHandle_t get_cublas_handle(const PluginContext& context) { auto* cuda_ctx = dynamic_cast<backends::cuda::CUDADeviceContext*>( context.device_context); return cuda_ctx ? cuda_ctx->cublas_handle() : nullptr; }
|
5. 实战:ReLUCPUPlugin 实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
|
class ReLUCPUPlugin : public SimpleCPUPlugin<ReLUCPUPlugin> { public: const char* get_plugin_type() const noexcept override { return "Relu"; } core::OpType get_op_type() const noexcept override { return core::OpType::kRELU; }
core::Status infer_output_shapes( const std::vector<core::Shape>& input_shapes, std::vector<core::Shape>& output_shapes) const override { if (input_shapes.empty()) { return core::Status::ERROR_INVALID_ARGUMENT; } output_shapes = {input_shapes[0]}; return core::Status::SUCCESS; }
re::Status enqueue( const std::vector<std::shared_ptr<core::Tensor>>& inputs, std::vector<std::shared_ptr<core::Tensor>>& outputs, const PluginContext& context) override {
const auto& input = inputs[0]; auto& output = outputs[0];
const float* in_data = static_cast<const float*>(input->data()); float* out_data = static_cast<float*>(output->data()); const size_t numel = input->shape().numel();
for (size_t i = 0; i < numel; ++i) { out_data[i] = std::max(0.0f, in_data[i]); }
return core::Status::SUCCESS; } };
REGISTER_PLUGIN_SIMPLE(ReLUCPUPlugin, "Relu", kRELU, CPU)
|
注意:
- 继承
SimpleCPUPlugin<ReLUCPUPlugin>(CRTP)。
- 只需实现
get_plugin_type、get_op_type、infer_output_shapes、enqueue。
clone()、get_device_type() 等由基类自动提供。
6. 实战:Conv2DCPUPlugin 实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
|
class Conv2DCPUPlugin : public CPUPlugin<Conv2DCPUPlugin, Conv2DParam> { public: const char* get_plugin_type() const noexcept override { return "Conv"; } core::OpType get_op_type() const noexcept override { return core::OpType::kCONV2D; } int32_t get_nb_inputs() const noexcept override { return param_->use_bias ? 3 }
core::Statuutput_shapes( const std::vector<core::Shape>& input_shapes, std::vector<core::Shape>& output_shapes) const override {
const auto& input_shape = input_shapes[0]; const int64_t N = input_shape[0]; const int64_t C_in = input_shape[1]; const int64_t H_in = input_shape[2]; const int64_t W_in = input_shape[3];
const auto& weight_shape = input_shapes[1]; const int64_t C_out = weight_shape[0];
const int64_t H_out = (H_in + 2 * param_->padding_h - param_->kernel_h) / param_->stride_h + 1; const int64_t W_out = (W_in + 2 * param_->padding_w - param_->kernel_w) / param_->stride_w + 1;
output_shapes = {core::Shape({N, C_out, H_out, W_out})}; return core::Status::SUCCESS; }
core::Status enqueue( const std::vector<std::shared_ptr<core::Tensor>>& inputs, std::vector<std::shared_ptr<core::Tensor>>& outputs, const PluginContext& context) override {
const auto& input = inputs[0]; const auto& weight = inputs[1]; const auto* bias = param_->use_bias ? inputs[2].get() : nullptr; auto& output = outputs[0];
if (param_->activation == ActivationType::RELU) { apply_relu(output); }
return core::Status::SUCCESS; } };
REGISTER_PLUGIN_SIMPLE(Conv2DCPUPlugin, "Conv", kCONV2D, CPU)
|
注意:
- 继承
CPUPlugin<Conv2DCPUPlugin, Conv2DParam>(带参数类型)。
- 通过
param_-> 访问卷积参数。
- 支持激活融合(Conv + ReLU)。
7. CRTP 的优势总结
| 优势 |
说明 |
| 消除样板代码 |
clone()、get_device_type() 等自动生成 |
| 类型安全 |
|
| 零运行时开销 |
没有虚函数调用开销(对于 CRTP 部分) |
| 代码复用 |
公共逻辑集中在基类 |
| 易于扩展 |
添加新算子只需继承基类 |
8. 总结
本篇我们实现了 Mini-Infer 插件架构的 CRTP 基类:
- CPUPlugin<Derived, ParamType>:CPU 插件的 CRTP 基类。
- CUDAPlugin<Derived, ParamType>:CUDA 插件的 CRTP 基类。
- SimpleCPUPlugin / SimpleCUDAPlugin:无参数算子的简化版。
- clone() 的 CRTP 魔法:自动生成深拷贝代码。
- 参数管理:类型安全的参数访问。
下一篇,我们将实现 PluginRegistry,看看如何通过静态注册机制自动发现和创建 Plugin。