Mini-Infer (21): 图优化实战 — TensorRT 风格的 `FusionPass` 与延迟删除
Mini-Infer (21): 图优化实战 — TensorRT 风格的 FusionPass 与延迟删除
1. 设计哲学:为什么是“TensorRT 风格”?
在实现算子融合时,通常有两种流派:
- 替换流 (Replacement):发现
Conv + ReLU,删掉两个节点,创建一个新的ConvReLU算子节点。- 缺点:会导致算子数量爆炸(
ConvReLU,ConvSigmoid,ConvTanh…)。
- 缺点:会导致算子数量爆炸(
- 属性流 (Attribute/TensorRT-style):发现
Conv + ReLU,保留Conv节点,只是给它设置一个activation属性,然后删掉ReLU节点。- 优点:保持了算子库的简洁。
Conv2D算子内部根据activation属性决定是否在输出前执行激活逻辑。
- 优点:保持了算子库的简洁。
Mini-Infer 坚定地选择了 TensorRT 风格。我们的 FusionPass 不会创建新节点,而是对现有节点进行“微创手术”。
2. 核心逻辑:图的“外科手术”与安全性
在对图进行修改(特别是删除节点)时,最容易犯的错误就是迭代器失效 (Iterator Invalidation)。
❌ 错误的写法 (必崩)
1 | for (const auto& [name, node] : graph->nodes()) { // 遍历 Map |
✅ 正确的写法 (Deferred Deletion / 延迟删除)
我们采用了 两阶段 (Two-Phase) 策略:
- Phase 1 (Mark): 遍历图,执行融合逻辑(修改算子属性、重连边),但不删除节点,只是将待删除的节点名字放入一个
std::unordered_set中。 - Phase 2 (Sweep): 遍历结束后,统一删除所有标记的节点。
1 | core::Status FusionPass::apply(Graph* graph, int& num_modifications) { |
3. try_fuse_conv_activation:融合实战
这是融合的核心逻辑。它需要精准地执行“图重写”。
步骤 A:校验
必须确认 Conv 只有一个输出,且输出节点是一个支持的激活函数。
步骤 B:设置属性
1 | // 设置 Conv2D 内部的激活属性 |
步骤 C:重连边 (Rewiring)
我们需要把原本连在 Activation 后面的所有节点,直接连到 Conv 后面。
1 | // 1. 从 Conv 的 output 列表中移除 Activation |
步骤 D:标记删除
1 | // 关键:不要在这里调用 remove_node! |
4. FusionPattern:为了未来的扩展
虽然我们目前硬编码了 Conv+Activation 的逻辑,但为了支持更复杂的模式(如 Conv+BN+ReLU),我们设计了 FusionPattern 结构体。
1 | struct FusionPattern { |
FusionPass::find_and_fuse 方法提供了一个通用的模式匹配框架。对于简单的序列模式,我们可以直接配置 FusionPattern 而无需编写专门的 C++ 代码。
5. 总结
我们成功实现了一个工业级、内存安全的 FusionPass。
它不仅仅实现了 Conv+ReLU 的融合,更重要的是,它确立了图优化的黄金法则:
- TensorRT 风格:属性融合优于节点替换。
- 延迟删除:两阶段处理优于边遍历边修改。
至此,我们的 Mini-Infer 引擎在功能和性能架构上都已经达到了一个新的高度:
- 加载:支持 ONNX 模型导入。
- 计算:拥有高性能的 im2col+GEMM 卷积。
- 优化:具备自动、安全的算子融合能力。
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 James的成长之路!
评论





