Mini-Infer (32): 插件架构 (上) — IPlugin 接口设计

1. 问题背景:为什么需要插件架构?

在之前的 Mini-Infer 实现中,我们使用了 Operator + Kernel 双层抽象:

  • Operator:定义算子的元数据(输入/输出数量、形状推理)。
  • Kernel:实现具体的计算逻辑(CPU/CUDA)。

这种设计在早期工作良好,但随着功能增加,问题逐渐暴露:

A. 双层抽象的问题

1
2
3
4
5
Operator (元数据)

KernelRegistry (查找)

Kernel (计算)
  1. 维护成本高:添加新算子需要修改两个地方。
  2. 一致性问题:Operator 和 Kernel 的参数可能不同步。
  3. 查找开销:每次执行都需要从 Registry 查找 Kernel。

B. 多后端支持的复杂性

1
2
3
// 旧架构:需要为每个后端注册 Kernel
REGISTER_KERNEL(Conv2D, CPU, Conv2DCPUKernel);
REGISTER_KERNEL(Conv2D, CUDA, Conv2DCUDAKernel);

当后端增多时,注册代码会爆炸式增长。

C. TensorRT 的启示

TensorRT 使用 IPluginV2 接口,将元数据和计算逻辑统一在一个类中:

1
2
3
4
5
6
7
8
9
10
11
class IPluginV2 {
// 元数据
virtual const char* getPluginType() const = 0;
virtual int getNbOutputs() const = 0;

// 形状推理
virtual Dims getOutputDimensions(...) = 0;

// 执行
virtual int enqueue(...) = 0;
};

Mini-Infer 借鉴了这一设计,实现了自己的 IPlugin 接口。


2. IPlugin 接口设计

A. 完整接口定义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
// mini_infer/operators/plugin_base.h

class IPlugin {
public:
virtual ~IPlugin() = default;

// ========== Identity ==========
virtual const char* get_plugin_type() const noexcept = 0;
virtual core::OpType get_op_type() const noexcept = 0;
virtual core::DeviceType get_device_type() const noexcept = 0;

// ========== I/O Configuration ==========
virtual int32_t get_nb_outputs() const noexcept { return 1; }
virtual int32_t get_nb_inputs() const noexcept { return 1; }

// ========== Shape Inference ==========
virtual core::Status infer_output_shapes(
const std::vector<core::Shape>& input_shapes,
std::vector<core::Shape>& output_shapes) const = 0;

virtual core::Status infer_output_metadata(
const std::vector<core::Shape>& input_shapes,
const std::vector<core::DataType>& input_dtypes,
std::vector<core::Shape>& output_shapes,
std::vector<core::DataType>& output_dtypes) const;

// ========== Workspace ==========
virtual size_t get_workspace_size(
const std::vector<core::Shape>& input_shapes) const noexcept { return 0; }

// ========== Lifecycle ==========
virtual core::Status initialize() { return core::Status::SUCCESS; }
virtual void terminate() noexcept {}

// ========== Execution ==========
virtual core::Status enqueue(
const std::vector<std::shared_ptr<core::Tensor>>& inputs,
std::vector<std::shared_ptr<core::Tensor>>& outputs,
const PluginContext& context) = 0;

// ========== Clone ==========
virtual std::unique_ptr<IPlugin> clone() const = 0;

// ========== Parameters ==========
virtual const void* get_param_ptr() const noexcept { return nullptr; }
virtual void set_param(std::shared_ptr<PluginParam> param) { (void)param; }
};

B. Identity 方法组

1
2
3
virtual const char* get_plugin_type() const noexcept = 0;
virtual core::OpType get_op_type() const noexcept = 0;
virtual core::DeviceType get_device_type() const noexcept = 0;

这三个方法唯一标识一个 Plugin:

  • get_plugin_type():返回字符串名称,如 "Conv2D""ReLU"
  • get_op_type():返回枚举值,用于快速比较。
  • get_device_type():返回 CPUCUDA

注册表的 Key(OpType, DeviceType) 二元组。

C. I/O Configuration 方法组

1
2
virtual int32_t get_nb_outputs() const noexcept { return 1; }
virtual int32_t get_nb_inputs() const noexcept { return 1; }

大多数算子只有一个输出,所以提供默认实现。特殊算子(如 Split)可以覆盖。

D. Shape Inference 方法组

1
2
3
virtual core::Status infer_output_shapes(
const std::vector<core::Shape>& input_shapes,
std::vector<core::Shape>& output_shapes) const = 0;

这是纯虚函数,每个 Plugin 必须实现。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
virtual core::Status infer_output_metadata(
const std::vector<core::Shape>& input_shapes,
const std::vector<core::DataType>& input_dtypes,
std::vector<core::Shape>& output_shapes,
std::vector<core::DataType>& output_dtypes) const {
// 默认实现:推理形状 + 传播第一个输入的 dtype
auto status = infer_output_shapes(input_shapes, output_shapes);
if (status != core::Status::SUCCESS) return status;

const core::D inferred =
input_dtypes.empty() ? core::DataType::FLOAT32 : input_dtypes[0];
output_dtypes.assign(output_shapes.size(), inferred);
return core::Status::SUCCESS;
}

infer_output_metadata 提供默认实现,大多数算子不需要覆盖。

E. Workspace 方法

1
2
virtual size_t get_workspace_size(
const std::vector<core::Shape>& input_shapes) const noexcept { return 0; }

返回执行时需要的临时内存大小。默认为 0(不需要 Workspace)。

F. Lifecycle 方法组

1
2
virtual core::Status initialize() { return core::Status::SUCCESS; }
virtual void terminate() noexcept {}
  • initialize():在首次执行前调用,用于分配资源、创建 cuDNN 描述符等。
  • terminate():在 Plugin 销毁前调用,用于释放资源。

G. Execution 方法

1
2
3
4
virtual core::Status enqueue(
const std::vector<std::shared_ptr<core::Tensor>>& inputs,
std::vector<std::shared_ptr<core::Tensor>>& outputs,
const PluginContext& context) = 0;

这是核心计算方法,执行实际的算子逻辑。

H. Clone 方法

1
virtual std::unique_ptr<IPlugin> clone() const = 0;

创建 Plugin 的深拷贝。用于:

  • 多个节点使用同一类型算子时,每个节点需要独立的 Plugin 实例。
  • 支持参数不同的同类型算子。

I. Parameters 方法组

1
2
virtual const void* get_param_ptr() const noexcept { return nullptr; }
virtual void set_param(std::shared_ptr<PluginParam> param) { (void)param; }

用于获取和设置算子参数。无参数算子(如 ReLU)使用默认实现。


3. PluginContext 结构

1
2
3
4
5
struct PluginContext {
backends::DeviceContext* device_context{nullptr};
std::shared_ptr<void> workspace;
size_t workspace_size{0};
};

PluginContext 包含执行时需要的运行时信息:

  • device_context:设备上下文(CPU 或 CUDA)。
  • workspace:临时内存
  • workspace_size:Workspace 大小。

4. PluginParam 参数体系

A. 基类定义

1
2
3
struct PluginParam {
virtual ~PluginParam() = default;
};

所有参数结构体都继承自 PluginParam

B. Conv2DParam

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
struct Conv2DParam : public PluginParam {
int kernel_h{1};
int kernel_w{1};
int stride_h{1};
int stride_w{1};
int padding_h{0};
int padding_w{0};
int dilation_h{1};
int dilation_w{1};
int groups{1};
bool use_bias{true};
ActivationType activation{ActivationType::NONE};

Conv2DParam() = default;
Conv2DParam(int kh, int kw, int sh, int sw, int ph, int pw, int g, bool bias)
: kernel_h(kh), kernel_w(kw), stride_h(sh), stride_w(sw),
padding_h(ph), padding_w(pw), groups(g), use_bias(bias) {}
};

C. LinearParam

1
2
3
4
5
6
7
8
9
struct LinearParam : public PluginParam {
int in_features{0};
int out_features{0};
bool use_bias{true};

LinearParam() = default;
LinearParam(int in_f, int out_f, bool bias)
: in_features(in_f), out_features(out_f), use_bias(bias) {}
};

D. PoolingParam

1
2
3
4
5
6
7
8
9
10
11
12
13
struct PoolingParam : public PluginParaolingType type{PoolingType::MAX};
int kernel_h{2};
int kernel_w{2};
int stride_h{2};
int stride_w{2};
int padding_h{0};
int padding_w{0};

PoolingParam() = default;
PoolingParam(PoolingType t, int kh, int kw, int sh, int sw, int ph, int pw)
: type(t), kernel_h(kh), kernel_w(kw), stride_h(sh), stride_w(sw),
padding_h(ph), padding_w(pw) {}
};

E. 继承 vs 组合的设计选择

我们选择继承而不是组合,原因是:

  1. 类型安全:可以用 dynamic_cast 检查参数类型。
  2. 多态存储:可以用 shared_ptr<PluginParam> 统一存储。
  3. 简单直观:参数结构体本身就是数据容器,继承开销很小。

5. IPluginCreator 工厂接口

1
2
3
4
5
6
7
8
9
10
class IPluginCreator {
public:
virtual ~IPluginCreator() = default;

virtual const char* get_plugin_type() const noexcept = 0;
virtual core::OpType get_op_type() const noexcept = 0;
virtual core::DeviceType get_device_type() const noexcept = 0;

virtual std::unique_ptr<IPlugin> create_plugin() const = 0;
};

每个 Plugin 类型都有一个对应的 Creator,用于:

  1. 延迟创建:只在需要时创建 Plugin 实例。
  2. 注册机制:Creator 注册到 Registry,按需创建 Plugin。
  3. 参数传递:创建后通过 set_param 设置参数。

6. 与 TensorRT IPluginV2 的对比

特性 TensorRT IPluginV2 Mini-Infer IPlugin
形状推理 getOutputDimensions infer_output_shapes
执行 enqueue enqueue
克隆 clone clone
序列化 serialize / deserialize 不支持(暂时)
格式支持 supportsFormat 不支持(暂时)
动态形状 supportsFormatCombination 通过 Profile 支持

Mini-Infer 的 IPlugin 是 TensorRT IPluginV2 的简化版,保留了核心功能,去掉了序列化和格式协商等复杂特性。


7. 总结

本篇我们设计了 Mini-Infer 的插件架构核心——IPlugin 接口:

  • 统一抽象:将元数据和计算逻辑合并到一个类中。
  • Identity 方法get_plugin_typeget_op_typeget_device_type
  • Shape Inferenceinfer_output_shapesinfer_output_metadata
  • Executionenqueue 是核心计算方法。
  • Clone:支持深拷贝,每个节点独立实例。
  • PluginParam 体系:类型安全的参数传递。
  • IPluginCreator:工厂接口,延迟创建。

下一篇,我们将看看如何使用 CRTP(Curiously Recurring Template Pattern)实现 CPUPluginCUDAPlugin 基类,消除样板代码。