Mini-Infer (8): Im2Col算法完全讲解

🎯 核心概念:为什么需要Im2Col?

问题:卷积计算很慢

朴素卷积需要7层嵌套循环:

1
2
3
4
5
6
7
8
9
// 超级慢!缓存不友好
for (batch)
for (out_channel)
for (in_channel)
for (kernel_h)
for (kernel_w)
for (out_h)
for (out_w)
output += input * weight

解决方案:转换为矩阵乘法

Im2Col的魔法

1
2
3
4
5
6
7
8
卷积运算 = 矩阵乘法

Output = Conv(Input, Weight)
↓ 转换
Output = Weight × col_buffer

然后用高度优化的GEMM库(如MKL)计算
→ 速度提升5-10倍!

📊 具体例子:一步步理解

输入参数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
输入图像(灰度图):
channels = 1
height = 4
width = 4

数据(线性存储):
[1, 2, 3, 4,
5, 6, 7, 8,
9, 10, 11, 12,
13, 14, 15, 16]

可视化为矩阵:
┌────────────────┐
│ 1 2 3 4 │
│ 5 6 7 8 │
│ 9 10 11 12 │
│ 13 14 15 16 │
└────────────────┘

卷积核参数:
kernel_h = 2, kernel_w = 2
stride_h = 1, stride_w = 1
padding_h = 0, padding_w = 0
dilation_h = 1, dilation_w = 1

输出尺寸:
out_height = (4 - 2) / 1 + 1 = 3
out_width = (4 - 2) / 1 + 1 = 3
→ 共9个输出位置(3×3网格)

🔍 滑动窗口可视化

卷积的本质是滑动窗口,Im2Col就是把这些窗口"展平":

位置0(左上角,oh=0, ow=0)

1
2
3
4
5
6
7
8
9
10
输入图像:
┌─────┐────────
│ 1 2│ 3 4
│ 5 6│ 7 8
└─────┘────────
9 10 11 12
13 14 15 16

提取的2×2窗口: [1, 2, 5, 6]
这4个数变成col_buffer的第1列

位置1(第1行第2列,oh=0, ow=1)

1
2
3
4
5
6
7
8
9
10
输入图像:
1 ┌─────┐────
│ 2 3│ 4
│ 6 7│ 8
──└─────┘────
9 10 11 12
13 14 15 16

提取的2×2窗口: [2, 3, 6, 7]
这4个数变成col_buffer的第2列

位置2(第1行第3列,oh=0, ow=2)

1
2
3
4
5
6
7
8
9
10
输入图像:
1 2 ┌─────┐
│ 3 4│
│ 7 8│
──────└─────┘
9 10 11 12
13 14 15 16

提取的2×2窗口: [3, 4, 7, 8]
这4个数变成col_buffer的第3列

位置3-8(类似地…)

1
2
3
4
5
位置3 (oh=1, ow=0):     位置4 (oh=1, ow=1):     位置5 (oh=1, ow=2):
[5, 6, 9, 10] [6, 7, 10, 11] [7, 8, 11, 12]

位置6 (oh=2, ow=0): 位置7 (oh=2, ow=1): 位置8 (oh=2, ow=2):
[9, 10, 13, 14] [10, 11, 14, 15] [11, 12, 15, 16]

📐 col_buffer的布局

关键理解:col_buffer是一个矩阵,形状为 [卷积核元素数, 输出位置数]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
形状: [kernel_h × kernel_w, out_h × out_w]
= [2 × 2, 3 × 3]
= [4, 9]

布局(每列是一个输出位置):

位置0 位置1 位置2 位置3 位置4 位置5 位置6 位置7 位置8
┌────────────────────────────────────────────────────────────┐
kh=0, │ 1 2 3 5 6 7 9 10 11 │ ← 卷积核[0,0]
kw=0 └────────────────────────────────────────────────────────────┘
┌────────────────────────────────────────────────────────────┐
kh=0, │ 2 3 4 6 7 8 10 11 12 │ ← 卷积核[0,1]
kw=1 └────────────────────────────────────────────────────────────┘
┌────────────────────────────────────────────────────────────┐
kh=1, │ 5 6 7 9 10 11 13 14 15 │ ← 卷积核[1,0]
kw=0 └────────────────────────────────────────────────────────────┘
┌────────────────────────────────────────────────────────────┐
kh=1, │ 6 7 8 10 11 12 14 15 16 │ ← 卷积核[1,1]
kw=1 └────────────────────────────────────────────────────────────┘

理解方式

  • 每一列 = 一个输出位置需要的所有输入数据
  • 每一行 = 卷积核的一个元素在所有位置的输入值

🔢 代码逐行解析

循环结构总览

1
2
3
4
5
6
7
for (c = 0; c < channels; ++c)           // 外层:遍历输入通道
for (kh = 0; kh < kernel_h; ++kh) // 中层:遍历卷积核高度
for (kw = 0; kw < kernel_w; ++kw) // 中层:遍历卷积核宽度
【计算这个卷积核元素对应的行号】
for (oh = 0; oh < out_height; ++oh) // 内层:遍历输出高度
for (ow = 0; ow < out_width; ++ow) // 内层:遍历输出宽度
【提取并填充数据到col_buffer】

关键变量详解

1. channel_size

1
2
int channel_size = height * width;
// 例如: 4×4图像 → channel_size = 16

作用:计算输入索引时,跳过多个通道的数据


2. input_row_startinput_col_start

1
2
int input_row_start = -padding_h + kh * dilation_h;
int input_col_start = -padding_w + kw * dilation_w;

含义:当前卷积核元素在输入图像中的起始位置(考虑padding和dilation)

示例1(无padding,无dilation):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
kh=0, kw=0:
input_row_start = 0 + 0×1 = 0
input_col_start = 0 + 0×1 = 0
→ 卷积核左上角对应输入图像的(0,0)位置

kh=0, kw=1:
input_row_start = 0
input_col_start = 0 + 1×1 = 1
→ 卷积核右上角对应输入图像的(0,1)位置

kh=1, kw=0:
input_row_start = 0 + 1×1 = 1
input_col_start = 0
→ 卷积核左下角对应输入图像的(1,0)位置

示例2(padding=1,无dilation):

1
2
3
4
kh=0, kw=0:
input_row_start = -1 + 0×1 = -1
input_col_start = -1 + 0×1 = -1
→ 起始位置在padding区域(负坐标)

示例3(无padding,dilation=2):

1
2
3
4
5
6
7
8
kh=0, kw=0:
input_row_start = 0
input_col_start = 0

kh=1, kw=1:
input_row_start = 0 + 1×2 = 2
input_col_start = 0 + 1×2 = 2
→ 膨胀卷积,卷积核元素间隔变大

3. col_idx(行索引)

1
int col_idx = c * kernel_h * kernel_w + kh * kernel_w + kw;

含义:当前卷积核元素在col_buffer中的行号

计算逻辑:将三维坐标 (c, kh, kw) 转换为一维索引

示例(channels=2, kernel=2×2):

1
2
3
4
5
6
7
8
c=0, kh=0, kw=0 → col_idx = 0×4 + 0×2 + 0 = 0  (第1行)
c=0, kh=0, kw=1 → col_idx = 0×4 + 0×2 + 1 = 1 (第2行)
c=0, kh=1, kw=0 → col_idx = 0×4 + 1×2 + 0 = 2 (第3行)
c=0, kh=1, kw=1 → col_idx = 0×4 + 1×2 + 1 = 3 (第4行)
c=1, kh=0, kw=0 → col_idx = 1×4 + 0×2 + 0 = 4 (第5行)
c=1, kh=0, kw=1 → col_idx = 1×4 + 0×2 + 1 = 5 (第6行)
c=1, kh=1, kw=0 → col_idx = 1×4 + 1×2 + 0 = 6 (第7行)
c=1, kh=1, kw=1 → col_idx = 1×4 + 1×2 + 1 = 7 (第8行)

规律:每个通道占 kernel_h × kernel_w


4. input_rowinput_col(实际输入位置)

1
2
int input_row = input_row_start + oh * stride_h;
int input_col = input_col_start + ow * stride_w;

含义:对于当前输出位置 (oh, ow),卷积核元素 (kh, kw) 对应的输入坐标

示例(stride=1):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
输出位置(0,0):
kh=0, kw=0: input_row_start=0, input_col_start=0
input_row = 0+0×1 = 0, input_col = 0+0×1 = 0
→ 输入坐标(0,0)

输出位置(0,1):
kh=0, kw=0: input_row = 0+0×1 = 0, input_col = 0+1×1 = 1
→ 输入坐标(0,1)

输出位置(1,1):
kh=0, kw=0: input_row = 0+1×1 = 1, input_col = 0+1×1 = 1
→ 输入坐标(1,1)

输出位置(1,1):
kh=1, kw=1: input_row_start=1, input_col_start=1
input_row = 1+1×1 = 2, input_col = 1+1×1 = 2
→ 输入坐标(2,2)

可视化(输出位置(1,1),stride=1):

1
2
3
4
5
6
7
8
9
10
11
12
输入图像:           卷积核:
1 2 3 4 ┌─────┐
5 6 7 8 │ x x│ kh=0,kw=0 → (1,1)
9 10 11 12 │ x x│ kh=0,kw=1 → (1,2)
13 14 15 16 └─────┘ kh=1,kw=0 → (2,1)
kh=1,kw=1 → (2,2)
放置在(1,1):
1 2 3 4
5 ┌─────┐
│ 6 7│
│10 11│
└─────┘

5. col_buffer_idx(线性索引)

1
2
int col_buffer_idx = col_idx * out_height * out_width 
+ oh * out_width + ow;

含义:二维矩阵 col_buffer[col_idx][列号] 的一维存储位置

计算逻辑

  • col_idx = 行号
  • oh * out_width + ow = 列号(二维坐标转一维)
  • col_idx * out_height * out_width = 跳过前面的行

示例(out_height=3, out_width=3):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
col_idx=0, oh=0, ow=0:
col_buffer_idx = 0×9 + 0×3 + 0 = 0 (第0行第0列)

col_idx=0, oh=0, ow=1:
col_buffer_idx = 0×9 + 0×3 + 1 = 1 (第0行第1列)

col_idx=0, oh=2, ow=2:
col_buffer_idx = 0×9 + 2×3 + 2 = 8 (第0行第8列)

col_idx=1, oh=0, ow=0:
col_buffer_idx = 1×9 + 0×3 + 0 = 9 (第1行第0列)

col_idx=3, oh=2, ow=2:
col_buffer_idx = 3×9 + 2×3 + 2 = 35 (第3行第8列)

可视化

1
2
3
col_buffer布局(线性存储):
[0][1][2][3][4][5][6][7][8]│[9][10][11]...[17]│[18]...│[27]...[35]
└─────── 第0行 ─────────────┘└──── 第1行 ──────┘ 第2行 └── 第3行 ──┘

6. input_idx(输入线性索引)

1
int input_idx = c * channel_size + input_row * width + input_col;

含义:三维输入 input[c][input_row][input_col] 的一维存储位置

计算逻辑

  • c * channel_size = 跳过前面的通道
  • input_row * width = 当前行的起始位置
  • + input_col = 列偏移

示例(channels=1, height=4, width=4):

1
2
3
4
5
6
7
8
9
10
11
c=0, input_row=0, input_col=0:
input_idx = 0×16 + 0×4 + 0 = 0 (值=1)

c=0, input_row=0, input_col=1:
input_idx = 0×16 + 0×4 + 1 = 1 (值=2)

c=0, input_row=1, input_col=2:
input_idx = 0×16 + 1×4 + 2 = 6 (值=7)

c=0, input_row=3, input_col=3:
input_idx = 0×16 + 3×4 + 3 = 15 (值=16)

如果channels=2

1
2
c=1, input_row=0, input_col=0:
input_idx = 1×16 + 0×4 + 0 = 16 (第2个通道的第1个元素)

7. 边界检查和数据复制

1
2
3
4
5
6
7
8
9
if (input_row >= 0 && input_row < height &&
input_col >= 0 && input_col < width) {
// 在有效范围内 → 复制数据
int input_idx = c * channel_size + input_row * width + input_col;
col_buffer[col_buffer_idx] = input[input_idx];
} else {
// 超出边界(padding区域)→ 填充0
col_buffer[col_buffer_idx] = T(0);
}

作用

  1. 有效区域:从input复制实际数据
  2. Padding区域:填充0(模拟zero-padding)

示例(padding=1):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
输入(4×4):          加padding后(6×6):
┌────────────┐ ┌──────────────────┐
│ 1 2 3 4 │ │ 0 0 0 0 0 0 │ ← padding行
│ 5 6 7 8 │ │ 0 1 2 3 4 0 │
│ 9 10 11 12 │ → │ 0 5 6 7 8 0 │
│13 14 15 16 │ │ 0 9 10 11 12 0 │
└────────────┘ │ 0 13 14 15 16 0 │
│ 0 0 0 0 0 0 │ ← padding行
└──────────────────┘
↑ ↑
padding列 padding列

当 input_row=-1 或 input_col=-1 时 → 填充0
当 input_row=6 或 input_col=6 时 → 填充0

🎬 完整执行流程示例

让我们跟踪第1个输出位置的数据填充:

输出位置(0,0),oh=0, ow=0

1
2
3
4
5
6
7
输入图像:
┌─────┐────────
│ 1 2│ 3 4 ← 需要提取这个2×2窗口
│ 5 6│ 7 8
└─────┘────────
9 10 11 12
13 14 15 16

轮次1:c=0, kh=0, kw=0

1
2
3
4
5
6
7
8
9
10
col_idx = 0×4 + 0×2 + 0 = 0
input_row_start = 0, input_col_start = 0

oh=0, ow=0:
input_row = 0 + 0×1 = 0
input_col = 0 + 0×1 = 0
col_buffer_idx = 0×9 + 0×3 + 0 = 0
input_idx = 0×16 + 0×4 + 0 = 0

col_buffer[0] = input[0] = 1

轮次2:c=0, kh=0, kw=1

1
2
3
4
5
6
7
8
9
10
col_idx = 0×4 + 0×2 + 1 = 1
input_row_start = 0, input_col_start = 1

oh=0, ow=0:
input_row = 0 + 0×1 = 0
input_col = 1 + 0×1 = 1
col_buffer_idx = 1×9 + 0×3 + 0 = 9
input_idx = 0×16 + 0×4 + 1 = 1

col_buffer[9] = input[1] = 2

轮次3:c=0, kh=1, kw=0

1
2
3
4
5
6
7
8
9
10
col_idx = 0×4 + 1×2 + 0 = 2
input_row_start = 1, input_col_start = 0

oh=0, ow=0:
input_row = 1 + 0×1 = 1
input_col = 0 + 0×1 = 0
col_buffer_idx = 2×9 + 0×3 + 0 = 18
input_idx = 0×16 + 1×4 + 0 = 4

col_buffer[18] = input[4] = 5

轮次4:c=0, kh=1, kw=1

1
2
3
4
5
6
7
8
9
10
col_idx = 0×4 + 1×2 + 1 = 3
input_row_start = 1, input_col_start = 1

oh=0, ow=0:
input_row = 1 + 0×1 = 1
input_col = 1 + 0×1 = 1
col_buffer_idx = 3×9 + 0×3 + 0 = 27
input_idx = 0×16 + 1×4 + 1 = 5

col_buffer[27] = input[5] = 6

结果:输出位置(0,0)的col_buffer第0列 = [1, 2, 5, 6]


🚀 为什么这样做有效?

卷积 → 矩阵乘法的转换

假设我们有:

  • Weight矩阵 (卷积核): [out_channels, in_channels × kh × kw]
  • col_buffer矩阵: [in_channels × kh × kw, out_h × out_w]

矩阵乘法

1
2
3
Output = Weight × col_buffer

Output形状: [out_channels, out_h × out_w]

每个输出元素的计算

1
2
3
4
Output[i][j] = Weight[i] · col_buffer[:][j]
= Weight的第i行 点乘 col_buffer的第j列
= Σ(卷积核元素 × 对应输入值)
= 这正是卷积的定义!

示例(1个输出通道,1个输入通道):

1
2
3
4
5
6
7
8
9
Weight (卷积核展平):
[w00, w01, w10, w11] (1×4)

col_buffer的第0列 (位置0的输入):
[1, 2, 5, 6]^T (4×1)

点积:
w00×1 + w01×2 + w10×5 + w11×6
= 卷积核在位置(0,0)的输出值

💡 总结

Im2Col的6层循环做了什么?

1
2
3
4
5
6
7
8
1. for (c):        遍历所有输入通道
2. for (kh): 遍历卷积核的高度维度
3. for (kw): 遍历卷积核的宽度维度
→ 确定col_buffer的"行"(哪个卷积核元素)
4. for (oh): 遍历输出的高度维度
5. for (ow): 遍历输出的宽度维度
→ 确定col_buffer的"列"(哪个输出位置)
6. 【提取输入值,填充col_buffer】

关键公式记忆

变量 公式 含义
col_idx c×Kh×Kw + kh×Kw + kw col_buffer的行号
output_pos oh×out_w + ow 输出位置的一维索引
col_buffer_idx col_idx×(out_h×out_w) + output_pos col_buffer线性索引
input_row input_row_start + oh×stride_h 实际输入行坐标
input_col input_col_start + ow×stride_w 实际输入列坐标
input_idx c×h×w + input_row×w + input_col 输入线性索引

为什么Im2Col快?

  1. 转换为GEMM - 利用高度优化的矩阵乘法库(MKL, cuBLAS)
  2. 缓存友好 - 连续内存访问模式
  3. 并行友好 - GEMM天然支持多线程和SIMD

代价:需要额外内存存储col_buffer

收益:计算速度提升5-10倍!

🎉 这就是所有深度学习框架(PyTorch, TensorFlow, TensorRT)加速卷积的核心技巧!