Mini-Infer (21): 图优化实战 — TensorRT 风格的 FusionPass 与延迟删除

1. 设计哲学:为什么是“TensorRT 风格”?

在实现算子融合时,通常有两种流派:

  1. 替换流 (Replacement):发现 Conv + ReLU,删掉两个节点,创建一个新的 ConvReLU 算子节点。
    • 缺点:会导致算子数量爆炸(ConvReLU, ConvSigmoid, ConvTanh…)。
  2. 属性流 (Attribute/TensorRT-style):发现 Conv + ReLU保留 Conv 节点,只是给它设置一个 activation 属性,然后删掉 ReLU 节点
    • 优点:保持了算子库的简洁。Conv2D 算子内部根据 activation 属性决定是否在输出前执行激活逻辑。

Mini-Infer 坚定地选择了 TensorRT 风格。我们的 FusionPass 不会创建新节点,而是对现有节点进行“微创手术”。

2. 核心逻辑:图的“外科手术”与安全性

在对图进行修改(特别是删除节点)时,最容易犯的错误就是迭代器失效 (Iterator Invalidation)

❌ 错误的写法 (必崩)

1
2
3
4
5
for (const auto& [name, node] : graph->nodes()) { // 遍历 Map
if (match) {
graph->remove_node(target); // 删除元素 -> Map 结构改变 -> 迭代器失效 -> Segfault!
}
}

✅ 正确的写法 (Deferred Deletion / 延迟删除)

我们采用了 两阶段 (Two-Phase) 策略:

  1. Phase 1 (Mark): 遍历图,执行融合逻辑(修改算子属性、重连边),但不删除节点,只是将待删除的节点名字放入一个 std::unordered_set 中。
  2. Phase 2 (Sweep): 遍历结束后,统一删除所有标记的节点。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
core::Status FusionPass::apply(Graph* graph, int& num_modifications) {
// 延迟删除集合
std::unordered_set<std::string> nodes_to_delete;

// Phase 1: Mark & Fuse
for (const auto& [name, node] : graph->nodes()) {
// 如果节点已经被标记为删除,跳过
if (nodes_to_delete.count(name)) continue;

if (op_type == "Conv2D") {
// 在 try_fuse 中,我们将 activation 节点加入 nodes_to_delete
if (try_fuse_conv_activation(graph, node, nodes_to_delete)) {
num_modifications++;
}
}
}

// Phase 2: Sweep (安全删除)
for (const auto& node_name : nodes_to_delete) {
graph->remove_node(node_name);
}

return core::Status::SUCCESS;
}

3. try_fuse_conv_activation:融合实战

这是融合的核心逻辑。它需要精准地执行“图重写”。

步骤 A:校验

必须确认 Conv 只有一个输出,且输出节点是一个支持的激活函数。

步骤 B:设置属性

1
2
3
// 设置 Conv2D 内部的激活属性
auto conv_op = std::dynamic_pointer_cast<operators::Conv2D>(conv_node->get_operator());
conv_op->set_activation(act_type); // 需要 Conv2D 支持此接口

步骤 C:重连边 (Rewiring)

我们需要把原本连在 Activation 后面的所有节点,直接连到 Conv 后面。

1
2
3
4
5
6
7
// 1. 从 Conv 的 output 列表中移除 Activation
// 2. 遍历 Activation 的所有下游节点
for (const auto& output_node : activation_outputs) {
// 从下游节点的 input 列表中移除 Activation
// 将 Conv 连接到下游节点
graph->connect(conv_node->name(), output_node->name());
}

步骤 D:标记删除

1
2
// 关键:不要在这里调用 remove_node!
nodes_to_delete.insert(activation_node->name());

4. FusionPattern:为了未来的扩展

虽然我们目前硬编码了 Conv+Activation 的逻辑,但为了支持更复杂的模式(如 Conv+BN+ReLU),我们设计了 FusionPattern 结构体。

1
2
3
4
5
struct FusionPattern {
std::vector<std::string> operator_sequence; // {"Conv2D", "ReLU"}
std::string fused_operator_type; // "ConvActivation"
ValidatorFunc validator;
};

FusionPass::find_and_fuse 方法提供了一个通用的模式匹配框架。对于简单的序列模式,我们可以直接配置 FusionPattern 而无需编写专门的 C++ 代码。

5. 总结

我们成功实现了一个工业级、内存安全FusionPass

它不仅仅实现了 Conv+ReLU 的融合,更重要的是,它确立了图优化的黄金法则

  1. TensorRT 风格:属性融合优于节点替换。
  2. 延迟删除:两阶段处理优于边遍历边修改。

至此,我们的 Mini-Infer 引擎在功能和性能架构上都已经达到了一个新的高度:

  • 加载:支持 ONNX 模型导入。
  • 计算:拥有高性能的 im2col+GEMM 卷积。
  • 优化:具备自动、安全的算子融合能力。