Mini-Infer (33): 插件架构 (中) — CRTP 基类与静态多态

1. CRTP 模式回顾

CRTP(Curiously Recurring Template Pattern,奇异递归模板模式)是 C++ 中一种强大的设计模式,用于实现静态多态

A. 基本形式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
template <typename Derived>
class Base {
public:
void interface() {
static_cast<Derived*>(this)->implementation();
}
};

class Derived : public Base<Derived> {
public:
void implementation() {
// 具体实现
}
};

B. 静态多态 vs 动态多态

特性 动态多态 (virtual) 静态多态 (CRTP)
绑定时机 运行时 编译时
虚函数表 需要 不需要
性能开销 有(间接调用) 无(内联)
灵活性 高(运行时决定) 低(编译时决定)
代码膨胀 有(模板实例化)

在 Mini-Infer 中,我们结合使用两种多态:

  • IPlugin 接口:使用虚函数,支持运行时多态(注册表查找)。
  • CPUPlugin/CUDAPlugin 基类:使用 CRTP,消除样板代码。

2. CPUPlugin<Derived, ParamType> 设计

A. 类定义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// mini_infer/operators/cpu_plugin.h

template <typename Derived, typename ParamType = PluginParam>
class CPUPlugin : public IPlugin {
public:
CPUPlugin() = default;
~CPUPlugin() override = default;

// 固定返回 CPU
core::DeviceType get_device_type() const noexcept override {
return core::DeviceType::CPU;
}

// CRTP 魔法:克隆
std::unique_ptr<IPlugin> clone() const override {
auto cloned = std::make_unique<Derived>(static_cast<const Derived&>(*this));
return cloned;
}

// 参数管理
const void* get_param_ptr() const noexcept override {
return param_ ? param_.get() : nullptr;
}

void set_param(std::shared_ptr<PluginParam> param) override {
param_ = std::dynamic_pointer_cast<ParamType>(param);
}

const ParamType* param() const {
return param_.get();
}

void set_typed_param(const ParamType& param) {
param_ = std::make_shared<ParamType>(param);
}

protected:
std::shared_ptr<ParamType> param_;
};

B. get_device_type() 的固定实现

1
2
3
core::DeviceType get_device_type() const noexcept override {
return core::DeviceType::CPU;
}

所有继承 CPUPlugin 的类都自动返回 CPU,无需重复实现。

C. clone() 的 CRTP 魔法

1
2
3
4
std::unique_ptr<IPlugin> clone() const override {
auto cloned = std::make_unique<Derived>(static_cast<const Derived&>(*this));
return cloned;
}

这是 CRTP 的核心价值

  1. Derived 是模板参数,在编译时已知。
  2. static_cast<const Derived&>(*this) 将基类引用转换为派生类引用。
  3. std::make_unique<Derived>(...) 调用派生类的拷贝构造函数。

没有 CRTP 的话,每个派生类都需要手动实现 clone()

1
2
3
4
5
6
7
8
9
10
11
12
// 没有 CRTP,需要手动实现
class ReLUCPUPlugin : public IPlugin {
std::unique_ptr<IPlugin> clone() const override {
return std::make_unique<ReLUCPUPlugin>(*this); // 样板代码!
}
};

class Conv2DCPUPlugin : public IPlugin {
std::unique_ptr<IPlugin> clone() const override {
return std::make_unique<Conv2DCPUPlugin>(*this); // 又是样板代码!
}
};

D. 参数管理

1
2
3
4
5
6
7
8
9
10
11
void set_param(std::shared_ptr<PluginParam> param) override {
param_ = std::dynamic_pointer_cast<ParamType>(param);
}

const ParamType* param() const {
return param_.get();
}

void set_typed_param(const ParamType& param) {
param_ = std::make_shared<ParamType>(param);
}
  • set_param:接受基类指针,动态转换为具体类型。
  • param():返回类型安全的参数指针。
  • set_typed_param:直接设置具体类型的参数。

3. SimpleCPUPlugin 简化版

对于无参数的算子(如 ReLU),我们提供更简单的基类:

1
2
3
4
5
6
template <typename Derived>
class SimpleCPUPlugin : public CPUPlugin<Derived, PluginParam> {
public:
SimpleCPUPlugin() = default;
~SimpleCPUPlugin() override = default;
};

使用示例

1
2
3
class ReLUCPUPlugin : public SimpleCPUPlugin<ReLUCPUPlugin> {
// 不需要处理参数
};

4. CUDAPlugin<Derived, ParamType> 设计

A. 类定义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// mini_infer/operators/cuda_plugin.h

template <typename Derived, typename ParamType = PluginParam>
class CUDAPlugin : public IPlugin {
public:
CUDAPlugin() = default;
~CUDAPlugin() override = default;

core::DeviceType get_device_type() const noexcept override {
return core::DeviceType::CUDA;
}

std::unique_ptr<IPlugin> clone() const override {
auto cloned = std::make_unique<Derived>(static_cast<const Derived&>(*this));
return cloned;
}

// 参数管理(与 CPUPlugin 相同)
const void* get_param_ptr() const noexcept ov ... }
void set_param(std::shared_ptr<PluginParam> param) override { ... }
const ParamType* param() const { ... }
void set_typed_param(const ParamType& param) { ... }

#ifdef MINI_INFER_USE_CUDA
// CUDA 特有辅助方法
static cudaStream_t get_cuda_stream(const PluginContext& context);
#endif

protected:
std::shared_ptr<ParamType> param_;
};

B. get_cuda_stream() 辅助方法

1
2
3
4
5
6
7
8
9
10
#ifdef MINI_INFER_USE_CUDA
static cudaStream_t get_cuda_stream(const PluginContext& context) {
if (!context.device_context) {
return nullptr;
}
auto* cuda_ctx = dynamic_cast<backends::cuda::CUDADeviceContext*>(
context.device_context);
return cuda_ctx ? cuda_ctx->stream() : nullptr;
}
#endif

这个辅助方法让派生类可以方便地获取 CUDA Stream。

C. cuDNN/cuBLAS Handle 访问

类似地,可以添加获取 cuDNN 和 cuBLAS Handle 的辅助方法:

1
2
3
4
5
6
7
8
9
10
11
static cudnnHandle_t get_cudnn_handle(const PluginContext& context) {
auto* cuda_ctx = dynamic_cast<backends::cuda::CUDADeviceContext*>(
context.device_context);
return cuda_ctx ? cuda_ctx->cudnn_handle() : nullptr;
}

static cublasHandle_t get_cublas_handle(const PluginContext& context) {
auto* cuda_ctx = dynamic_cast<backends::cuda::CUDADeviceContext*>(
context.device_context);
return cuda_ctx ? cuda_ctx->cublas_handle() : nullptr;
}

5. 实战:ReLUCPUPlugin 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
// mini_infer/operators/cpu/relu_cpu.cpp

class ReLUCPUPlugin : public SimpleCPUPlugin<ReLUCPUPlugin> {
public:
// Identity
const char* get_plugin_type() const noexcept override { return "Relu"; }
core::OpType get_op_type() const noexcept override { return core::OpType::kRELU; }

// Shape Inference
core::Status infer_output_shapes(
const std::vector<core::Shape>& input_shapes,
std::vector<core::Shape>& output_shapes) const override {
if (input_shapes.empty()) {
return core::Status::ERROR_INVALID_ARGUMENT;
}
output_shapes = {input_shapes[0]}; // ReLU 不改变形状
return core::Status::SUCCESS;
}

// Execution
re::Status enqueue(
const std::vector<std::shared_ptr<core::Tensor>>& inputs,
std::vector<std::shared_ptr<core::Tensor>>& outputs,
const PluginContext& context) override {

const auto& input = inputs[0];
auto& output = outputs[0];

const float* in_data = static_cast<const float*>(input->data());
float* out_data = static_cast<float*>(output->data());
const size_t numel = input->shape().numel();

for (size_t i = 0; i < numel; ++i) {
out_data[i] = std::max(0.0f, in_data[i]);
}

return core::Status::SUCCESS;
}
};

// 注册
REGISTER_PLUGIN_SIMPLE(ReLUCPUPlugin, "Relu", kRELU, CPU)

注意

  1. 继承 SimpleCPUPlugin<ReLUCPUPlugin>(CRTP)。
  2. 只需实现 get_plugin_typeget_op_typeinfer_output_shapesenqueue
  3. clone()get_device_type() 等由基类自动提供。

6. 实战:Conv2DCPUPlugin 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
// mini_infer/operators/cpu/conv2d_cpu.cpp

class Conv2DCPUPlugin : public CPUPlugin<Conv2DCPUPlugin, Conv2DParam> {
public:
// Identity
const char* get_plugin_type() const noexcept override { return "Conv"; }
core::OpType get_op_type() const noexcept override { return core::OpType::kCONV2D; }
int32_t get_nb_inputs() const noexcept override { return param_->use_bias ? 3 }

// Shape Inference
core::Statuutput_shapes(
const std::vector<core::Shape>& input_shapes,
std::vector<core::Shape>& output_shapes) const override {

const auto& input_shape = input_shapes[0];
const int64_t N = input_shape[0];
const int64_t C_in = input_shape[1];
const int64_t H_in = input_shape[2];
const int64_t W_in = input_shape[3];

const auto& weight_shape = input_shapes[1];
const int64_t C_out = weight_shape[0];

// 计算输出尺寸
const int64_t H_out = (H_in + 2 * param_->padding_h - param_->kernel_h) /
param_->stride_h + 1;
const int64_t W_out = (W_in + 2 * param_->padding_w - param_->kernel_w) /
param_->stride_w + 1;

output_shapes = {core::Shape({N, C_out, H_out, W_out})};
return core::Status::SUCCESS;
}

// Execution
core::Status enqueue(
const std::vector<std::shared_ptr<core::Tensor>>& inputs,
std::vector<std::shared_ptr<core::Tensor>>& outputs,
const PluginContext& context) override {

const auto& input = inputs[0];
const auto& weight = inputs[1];
const auto* bias = param_->use_bias ? inputs[2].get() : nullptr;
auto& output = outputs[0];

// im2col + GEMM 实现
// ... 调用底层 Kernel ...

// 激活融合
if (param_->activation == ActivationType::RELU) {
apply_relu(output);
}

return core::Status::SUCCESS;
}
};

// 注册
REGISTER_PLUGIN_SIMPLE(Conv2DCPUPlugin, "Conv", kCONV2D, CPU)

注意

  1. 继承 CPUPlugin<Conv2DCPUPlugin, Conv2DParam>(带参数类型)。
  2. 通过 param_-> 访问卷积参数。
  3. 支持激活融合(Conv + ReLU)。

7. CRTP 的优势总结

优势 说明
消除样板代码 clone()get_device_type() 等自动生成
类型安全
零运行时开销 没有虚函数调用开销(对于 CRTP 部分)
代码复用 公共逻辑集中在基类
易于扩展 添加新算子只需继承基类

8. 总结

本篇我们实现了 Mini-Infer 插件架构的 CRTP 基类:

  • CPUPlugin<Derived, ParamType>:CPU 插件的 CRTP 基类。
  • CUDAPlugin<Derived, ParamType>:CUDA 插件的 CRTP 基类。
  • SimpleCPUPlugin / SimpleCUDAPlugin:无参数算子的简化版。
  • clone() 的 CRTP 魔法:自动生成深拷贝代码。
  • 参数管理:类型安全的参数访问。

下一篇,我们将实现 PluginRegistry,看看如何通过静态注册机制自动发现和创建 Plugin。