Mini-Infer (32): 插件架构 (上) — IPlugin 接口设计
1. 问题背景:为什么需要插件架构?
在之前的 Mini-Infer 实现中,我们使用了 Operator + Kernel 双层抽象:
Operator :定义算子的元数据(输入/输出数量、形状推理)。
Kernel :实现具体的计算逻辑(CPU/CUDA)。
这种设计在早期工作良好,但随着功能增加,问题逐渐暴露:
A. 双层抽象的问题
1 2 3 4 5 Operator (元数据) ↓ KernelRegistry (查找) ↓ Kernel (计算)
维护成本高 :添加新算子需要修改两个地方。
一致性问题 :Operator 和 Kernel 的参数可能不同步。
查找开销 :每次执行都需要从 Registry 查找 Kernel。
B. 多后端支持的复杂性
1 2 3 REGISTER_KERNEL (Conv2D, CPU, Conv2DCPUKernel);REGISTER_KERNEL (Conv2D, CUDA, Conv2DCUDAKernel);
当后端增多时,注册代码会爆炸式增长。
C. TensorRT 的启示
TensorRT 使用 IPluginV2 接口,将元数据和计算逻辑统一在一个类中:
1 2 3 4 5 6 7 8 9 10 11 class IPluginV2 { virtual const char * getPluginType () const = 0 ; virtual int getNbOutputs () const = 0 ; virtual Dims getOutputDimensions (...) = 0 ; virtual int enqueue (...) = 0 ; };
Mini-Infer 借鉴了这一设计,实现了自己的 IPlugin 接口。
2. IPlugin 接口设计
A. 完整接口定义
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 class IPlugin {public : virtual ~IPlugin () = default ; virtual const char * get_plugin_type () const noexcept = 0 ; virtual core::OpType get_op_type () const noexcept = 0 ; virtual core::DeviceType get_device_type () const noexcept = 0 ; virtual int32_t get_nb_outputs () const noexcept { return 1 ; } virtual int32_t get_nb_inputs () const noexcept { return 1 ; } virtual core::Status infer_output_shapes ( const std::vector<core::Shape>& input_shapes, std::vector<core::Shape>& output_shapes) const = 0 ; virtual core::Status infer_output_metadata ( const std::vector<core::Shape>& input_shapes, const std::vector<core::DataType>& input_dtypes, std::vector<core::Shape>& output_shapes, std::vector<core::DataType>& output_dtypes) const ; virtual size_t get_workspace_size ( const std::vector<core::Shape>& input_shapes) const noexcept { return 0 ; } virtual core::Status initialize () { return core::Status::SUCCESS; } virtual void terminate () noexcept {} virtual core::Status enqueue ( const std::vector<std::shared_ptr<core::Tensor>>& inputs, std::vector<std::shared_ptr<core::Tensor>>& outputs, const PluginContext& context) = 0 ; virtual std::unique_ptr<IPlugin> clone () const = 0 ; virtual const void * get_param_ptr () const noexcept { return nullptr ; } virtual void set_param (std::shared_ptr<PluginParam> param) { (void )param; } };
B. Identity 方法组
1 2 3 virtual const char * get_plugin_type () const noexcept = 0 ;virtual core::OpType get_op_type () const noexcept = 0 ;virtual core::DeviceType get_device_type () const noexcept = 0 ;
这三个方法唯一标识一个 Plugin:
get_plugin_type() :返回字符串名称,如 "Conv2D"、"ReLU"。
get_op_type() :返回枚举值,用于快速比较。
get_device_type() :返回 CPU 或 CUDA。
注册表的 Key :(OpType, DeviceType) 二元组。
C. I/O Configuration 方法组
1 2 virtual int32_t get_nb_outputs () const noexcept { return 1 ; }virtual int32_t get_nb_inputs () const noexcept { return 1 ; }
大多数算子只有一个输出,所以提供默认实现。特殊算子(如 Split)可以覆盖。
D. Shape Inference 方法组
1 2 3 virtual core::Status infer_output_shapes ( const std::vector<core::Shape>& input_shapes, std::vector<core::Shape>& output_shapes) const = 0 ;
这是纯虚函数 ,每个 Plugin 必须实现。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 virtual core::Status infer_output_metadata ( const std::vector<core::Shape>& input_shapes, const std::vector<core::DataType>& input_dtypes, std::vector<core::Shape>& output_shapes, std::vector<core::DataType>& output_dtypes) const { auto status = infer_output_shapes (input_shapes, output_shapes); if (status != core::Status::SUCCESS) return status; const core::D inferred = input_dtypes.empty () ? core::DataType::FLOAT32 : input_dtypes[0 ]; output_dtypes.assign (output_shapes.size (), inferred); return core::Status::SUCCESS; }
infer_output_metadata 提供默认实现,大多数算子不需要覆盖。
E. Workspace 方法
1 2 virtual size_t get_workspace_size ( const std::vector<core::Shape>& input_shapes) const noexcept { return 0 ; }
返回执行时需要的临时内存大小。默认为 0(不需要 Workspace)。
F. Lifecycle 方法组
1 2 virtual core::Status initialize () { return core::Status::SUCCESS; }virtual void terminate () noexcept {}
initialize() :在首次执行前调用,用于分配资源、创建 cuDNN 描述符等。
terminate() :在 Plugin 销毁前调用,用于释放资源。
G. Execution 方法
1 2 3 4 virtual core::Status enqueue ( const std::vector<std::shared_ptr<core::Tensor>>& inputs, std::vector<std::shared_ptr<core::Tensor>>& outputs, const PluginContext& context) = 0 ;
这是核心计算方法 ,执行实际的算子逻辑。
H. Clone 方法
1 virtual std::unique_ptr<IPlugin> clone () const = 0 ;
创建 Plugin 的深拷贝。用于:
多个节点使用同一类型算子时,每个节点需要独立的 Plugin 实例。
支持参数不同的同类型算子。
I. Parameters 方法组
1 2 virtual const void * get_param_ptr () const noexcept { return nullptr ; }virtual void set_param (std::shared_ptr<PluginParam> param) { (void )param; }
用于获取和设置算子参数。无参数算子(如 ReLU)使用默认实现。
3. PluginContext 结构
1 2 3 4 5 struct PluginContext { backends::DeviceContext* device_context{nullptr }; std::shared_ptr<void > workspace; size_t workspace_size{0 }; };
PluginContext 包含执行时需要的运行时信息:
device_context :设备上下文(CPU 或 CUDA)。
workspace :临时内存
workspace_size :Workspace 大小。
4. PluginParam 参数体系
A. 基类定义
1 2 3 struct PluginParam { virtual ~PluginParam () = default ; };
所有参数结构体都继承自 PluginParam。
B. Conv2DParam
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 struct Conv2DParam : public PluginParam { int kernel_h{1 }; int kernel_w{1 }; int stride_h{1 }; int stride_w{1 }; int padding_h{0 }; int padding_w{0 }; int dilation_h{1 }; int dilation_w{1 }; int groups{1 }; bool use_bias{true }; ActivationType activation{ActivationType::NONE}; Conv2DParam () = default ; Conv2DParam (int kh, int kw, int sh, int sw, int ph, int pw, int g, bool bias) : kernel_h (kh), kernel_w (kw), stride_h (sh), stride_w (sw), padding_h (ph), padding_w (pw), groups (g), use_bias (bias) {} };
C. LinearParam
1 2 3 4 5 6 7 8 9 struct LinearParam : public PluginParam { int in_features{0 }; int out_features{0 }; bool use_bias{true }; LinearParam () = default ; LinearParam (int in_f, int out_f, bool bias) : in_features (in_f), out_features (out_f), use_bias (bias) {} };
D. PoolingParam
1 2 3 4 5 6 7 8 9 10 11 12 13 struct PoolingParam : public PluginParaolingType type{PoolingType::MAX}; int kernel_h{2 }; int kernel_w{2 }; int stride_h{2 }; int stride_w{2 }; int padding_h{0 }; int padding_w{0 }; PoolingParam () = default ; PoolingParam (PoolingType t, int kh, int kw, int sh, int sw, int ph, int pw) : type (t), kernel_h (kh), kernel_w (kw), stride_h (sh), stride_w (sw), padding_h (ph), padding_w (pw) {} };
E. 继承 vs 组合的设计选择
我们选择继承 而不是组合,原因是:
类型安全 :可以用 dynamic_cast 检查参数类型。
多态存储 :可以用 shared_ptr<PluginParam> 统一存储。
简单直观 :参数结构体本身就是数据容器,继承开销很小。
5. IPluginCreator 工厂接口
1 2 3 4 5 6 7 8 9 10 class IPluginCreator {public : virtual ~IPluginCreator () = default ; virtual const char * get_plugin_type () const noexcept = 0 ; virtual core::OpType get_op_type () const noexcept = 0 ; virtual core::DeviceType get_device_type () const noexcept = 0 ; virtual std::unique_ptr<IPlugin> create_plugin () const = 0 ; };
每个 Plugin 类型都有一个对应的 Creator,用于:
延迟创建 :只在需要时创建 Plugin 实例。
注册机制 :Creator 注册到 Registry,按需创建 Plugin。
参数传递 :创建后通过 set_param 设置参数。
6. 与 TensorRT IPluginV2 的对比
特性
TensorRT IPluginV2
Mini-Infer IPlugin
形状推理
getOutputDimensions
infer_output_shapes
执行
enqueue
enqueue
克隆
clone
clone
序列化
serialize / deserialize
不支持(暂时)
格式支持
supportsFormat
不支持(暂时)
动态形状
supportsFormatCombination
通过 Profile 支持
Mini-Infer 的 IPlugin 是 TensorRT IPluginV2 的简化版,保留了核心功能,去掉了序列化和格式协商等复杂特性。
7. 总结
本篇我们设计了 Mini-Infer 的插件架构核心——IPlugin 接口:
统一抽象 :将元数据和计算逻辑合并到一个类中。
Identity 方法 :get_plugin_type、get_op_type、get_device_type。
Shape Inference :infer_output_shapes、infer_output_metadata。
Execution :enqueue 是核心计算方法。
Clone :支持深拷贝,每个节点独立实例。
PluginParam 体系 :类型安全的参数传递。
IPluginCreator :工厂接口,延迟创建。
下一篇,我们将看看如何使用 CRTP(Curiously Recurring Template Pattern)实现 CPUPlugin 和 CUDAPlugin 基类,消除样板代码。