读 ncnn 源码(XXXIX):消除冗余塑形(下)——eliminate_reshape_before_binaryop

ncnnoptimize 的图优化 Pass 中,对 ReshapeFlatten 等塑形层(Shape-manipulation layers)的消除是一个重要主题。上一篇我们看到了 GlobalAveragePoolingInnerProduct 因其输出在语义上已是“扁平”的,从而使其后的 Reshape/Flatten 变得多余。

本篇分析的 eliminate_reshape_before_binaryop 则利用了 BinaryOp 层自身实现的健壮性(Robustness),来消除其输入端的多余 Reshape 操作。

TL;DR

  1. 目标: 识别并消除 ... -> Reshape(to 1x1xC) -> BinaryOp 这样的模式。
  2. 核心原理 (BinaryOp 的“形状无关性”): ncnn::BinaryOp(如 Add, Mul 等)在执行两个张量的逐元素操作时,其 forward 实现在根本上是“形状无关”的。它主要关心两个张量的总元素数量 total() 是否相等
  3. 冗余分析: 只要两个输入张量的总元素数相同,BinaryOp 就能正确地逐个处理它们(通过各自的 w, h, c, cstep 遍历)。因此,一个 Reshape 层如果仅仅是将一个张量从 [W, H, C] 拍平为 [1, 1, W*H*C](总元素数不变),而改变其内存中的数据顺序,那么这个 Reshape 操作对于 BinaryOp 来说就是完全冗余的。
  4. 模式匹配: 遍历 layers,查找 Reshape 层,并检查其是否满足“拍平”条件(reshape->w == 1reshape->h == 1)。然后查找紧随其后的、消费该 blobBinaryOp
  5. 图结构修改 (短路): 执行标准的“短路”操作。将 Reshape 层的输入 blob (bottom_blob_index_final) 直接连接为 BinaryOp 层的输入,并标记 Reshape 层为 "ncnnfused"
  6. 效果: 移除了一个不产生任何计算、仅改变元数据的冗余层,简化了计算图,减少了层调度开销。

1. 动机:BinaryOp 的“形状无关性”

ncnn 中的 BinaryOp 层(当 with_scalar=0 时)被设计用来处理两个张量的逐元素运算。虽然它支持复杂的广播(Broadcasting)规则,但在最常见的情况下——两个输入张量具有相同的总元素数——它的 forward 实现本质上是一个遍历 total() 次的循环。

例如,一个 BinaryOp(Add) 操作,它需要计算 C = A + B

  • 如果 A 的形状是 [224, 224, 3] (total=150528)
  • B 的形状是 [150528, 1, 1] (total=150528)

BinaryOpforward 内核完全有能力处理这种情况。它会各自使用 ABw, h, cstep 来迭代各自的数据指针,执行 total() 次加法。

因此,如果一个 Reshape 层的唯一作用是将 A[224, 224, 3] 变为 [150528, 1, 1](或 [1, 1, 150528]),那么这个 Reshape 层对于后续的 BinaryOp 来说是完全多余的。ncnnoptimize 正是抓住了这一点来进行优化。


2. 源码解析:eliminate_reshape_before_binaryop

该函数的实现逻辑清晰地展现了这一思路:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
int NetOptimize::eliminate_reshape_before_binaryop()
{
const size_t layer_count = layers.size();
for (size_t i = 0; i < layer_count; i++)
{
// 1. 模式匹配:找到 Reshape 层
if (layers[i]->type != "Reshape")
continue;

ncnn::Reshape* reshape = (ncnn::Reshape*)layers[i];

// 2. 核心条件:检查是否为“拍平”操作 (输出为 1x1xC_total)
if (reshape->w != 1 || reshape->h != 1)
continue; // 不是拍平为 1x1x... 的操作,跳过
if (!reshape->shape_expr.empty())
continue; // 暂不支持动态 shape

// 3. 模式匹配:查找紧随其后的 BinaryOp
int top_blob_index = layers[i]->tops[0];
size_t j = i + 1;
for (; j < layer_count; j++)
{
if (layers[j]->type != "BinaryOp") continue;
if (layers[j]->bottoms.size() != 2) continue;
// 确认 BinaryOp 消费了 Reshape 的输出
if (layers[j]->bottoms[0] == top_blob_index || layers[j]->bottoms[1] == top_blob_index)
break;
}
if (j == layer_count) continue; // 未找到

ncnn::BinaryOp* binaryop = (ncnn::BinaryOp*)layers[j];
fprintf(stderr, "eliminate_reshape_before_binaryop %s %s\n", reshape->name.c_str(), binaryop->name.c_str());

// 4. 图结构修改 (短路)

// a) 获取 Reshape 层的输入 blob
int bottom_blob_index_final = reshape->bottoms[0];

// b) 将 BinaryOp 的输入直接指向 Reshape 层的输入 blob
if (binaryop->bottoms[0] == top_blob_index)
binaryop->bottoms[0] = bottom_blob_index_final;
if (binaryop->bottoms[1] == top_blob_index)
binaryop->bottoms[1] = bottom_blob_index_final;

// c) 更新 blob 的消费者信息
blobs[bottom_blob_index_final].consumer = j; // 生产者是 j (BinaryOp)

// d) 标记 Reshape 层为无效
reshape->type = "ncnnfused";
}

return 0;
}

3. 结语

eliminate_reshape_before_binaryopncnnoptimize 中又一个基于语义等价的算子消除 Pass。它与 eliminate..._after_global_pooling 系列 Pass 互为补充,共同清理了计算图中因形状操作而引入的冗余节点。

此优化利用了 BinaryOpforward 实现的健壮性——即在处理逐元素操作时,对输入的 (W, H, C) 维度不敏感,只关心 total() 元素总数。通过消除这些不必要的“视图转换”层,ncnnoptimize 进一步简化了计算图,减少了层调度开销,使网络结构更加精简。