Mini-Infer (8): Im2Col算法完全讲解
🎯 核心概念:为什么需要Im2Col?
问题:卷积计算很慢
朴素卷积需要7层嵌套循环:
1 2 3 4 5 6 7 8 9
| for (batch) for (out_channel) for (in_channel) for (kernel_h) for (kernel_w) for (out_h) for (out_w) output += input * weight
|
解决方案:转换为矩阵乘法
Im2Col的魔法:
1 2 3 4 5 6 7 8
| 卷积运算 = 矩阵乘法
Output = Conv(Input, Weight) ↓ 转换 Output = Weight × col_buffer
然后用高度优化的GEMM库(如MKL)计算 → 速度提升5-10倍!
|
📊 具体例子:一步步理解
输入参数
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
| 输入图像(灰度图): channels = 1 height = 4 width = 4 数据(线性存储): [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
可视化为矩阵: ┌────────────────┐ │ 1 2 3 4 │ │ 5 6 7 8 │ │ 9 10 11 12 │ │ 13 14 15 16 │ └────────────────┘
卷积核参数: kernel_h = 2, kernel_w = 2 stride_h = 1, stride_w = 1 padding_h = 0, padding_w = 0 dilation_h = 1, dilation_w = 1
输出尺寸: out_height = (4 - 2) / 1 + 1 = 3 out_width = (4 - 2) / 1 + 1 = 3 → 共9个输出位置(3×3网格)
|
🔍 滑动窗口可视化
卷积的本质是滑动窗口,Im2Col就是把这些窗口"展平":
位置0(左上角,oh=0, ow=0)
1 2 3 4 5 6 7 8 9 10
| 输入图像: ┌─────┐──────── │ 1 2│ 3 4 │ 5 6│ 7 8 └─────┘──────── 9 10 11 12 13 14 15 16
提取的2×2窗口: [1, 2, 5, 6] 这4个数变成col_buffer的第1列
|
位置1(第1行第2列,oh=0, ow=1)
1 2 3 4 5 6 7 8 9 10
| 输入图像: 1 ┌─────┐──── │ 2 3│ 4 │ 6 7│ 8 ──└─────┘──── 9 10 11 12 13 14 15 16
提取的2×2窗口: [2, 3, 6, 7] 这4个数变成col_buffer的第2列
|
位置2(第1行第3列,oh=0, ow=2)
1 2 3 4 5 6 7 8 9 10
| 输入图像: 1 2 ┌─────┐ │ 3 4│ │ 7 8│ ──────└─────┘ 9 10 11 12 13 14 15 16
提取的2×2窗口: [3, 4, 7, 8] 这4个数变成col_buffer的第3列
|
位置3-8(类似地…)
1 2 3 4 5
| 位置3 (oh=1, ow=0): 位置4 (oh=1, ow=1): 位置5 (oh=1, ow=2): [5, 6, 9, 10] [6, 7, 10, 11] [7, 8, 11, 12]
位置6 (oh=2, ow=0): 位置7 (oh=2, ow=1): 位置8 (oh=2, ow=2): [9, 10, 13, 14] [10, 11, 14, 15] [11, 12, 15, 16]
|
📐 col_buffer的布局
关键理解:col_buffer是一个矩阵,形状为 [卷积核元素数, 输出位置数]
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
| 形状: [kernel_h × kernel_w, out_h × out_w] = [2 × 2, 3 × 3] = [4, 9]
布局(每列是一个输出位置):
位置0 位置1 位置2 位置3 位置4 位置5 位置6 位置7 位置8 ┌────────────────────────────────────────────────────────────┐ kh=0, │ 1 2 3 5 6 7 9 10 11 │ ← 卷积核[0,0] kw=0 └────────────────────────────────────────────────────────────┘ ┌────────────────────────────────────────────────────────────┐ kh=0, │ 2 3 4 6 7 8 10 11 12 │ ← 卷积核[0,1] kw=1 └────────────────────────────────────────────────────────────┘ ┌────────────────────────────────────────────────────────────┐ kh=1, │ 5 6 7 9 10 11 13 14 15 │ ← 卷积核[1,0] kw=0 └────────────────────────────────────────────────────────────┘ ┌────────────────────────────────────────────────────────────┐ kh=1, │ 6 7 8 10 11 12 14 15 16 │ ← 卷积核[1,1] kw=1 └────────────────────────────────────────────────────────────┘
|
理解方式:
- 每一列 = 一个输出位置需要的所有输入数据
- 每一行 = 卷积核的一个元素在所有位置的输入值
🔢 代码逐行解析
循环结构总览
1 2 3 4 5 6 7
| for (c = 0; c < channels; ++c) for (kh = 0; kh < kernel_h; ++kh) for (kw = 0; kw < kernel_w; ++kw) 【计算这个卷积核元素对应的行号】 for (oh = 0; oh < out_height; ++oh) for (ow = 0; ow < out_width; ++ow) 【提取并填充数据到col_buffer】
|
关键变量详解
1. channel_size
1 2
| int channel_size = height * width;
|
作用:计算输入索引时,跳过多个通道的数据
1 2
| int input_row_start = -padding_h + kh * dilation_h; int input_col_start = -padding_w + kw * dilation_w;
|
含义:当前卷积核元素在输入图像中的起始位置(考虑padding和dilation)
示例1(无padding,无dilation):
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| kh=0, kw=0: input_row_start = 0 + 0×1 = 0 input_col_start = 0 + 0×1 = 0 → 卷积核左上角对应输入图像的(0,0)位置
kh=0, kw=1: input_row_start = 0 input_col_start = 0 + 1×1 = 1 → 卷积核右上角对应输入图像的(0,1)位置
kh=1, kw=0: input_row_start = 0 + 1×1 = 1 input_col_start = 0 → 卷积核左下角对应输入图像的(1,0)位置
|
示例2(padding=1,无dilation):
1 2 3 4
| kh=0, kw=0: input_row_start = -1 + 0×1 = -1 input_col_start = -1 + 0×1 = -1 → 起始位置在padding区域(负坐标)
|
示例3(无padding,dilation=2):
1 2 3 4 5 6 7 8
| kh=0, kw=0: input_row_start = 0 input_col_start = 0
kh=1, kw=1: input_row_start = 0 + 1×2 = 2 input_col_start = 0 + 1×2 = 2 → 膨胀卷积,卷积核元素间隔变大
|
3. col_idx(行索引)
1
| int col_idx = c * kernel_h * kernel_w + kh * kernel_w + kw;
|
含义:当前卷积核元素在col_buffer中的行号
计算逻辑:将三维坐标 (c, kh, kw) 转换为一维索引
示例(channels=2, kernel=2×2):
1 2 3 4 5 6 7 8
| c=0, kh=0, kw=0 → col_idx = 0×4 + 0×2 + 0 = 0 (第1行) c=0, kh=0, kw=1 → col_idx = 0×4 + 0×2 + 1 = 1 (第2行) c=0, kh=1, kw=0 → col_idx = 0×4 + 1×2 + 0 = 2 (第3行) c=0, kh=1, kw=1 → col_idx = 0×4 + 1×2 + 1 = 3 (第4行) c=1, kh=0, kw=0 → col_idx = 1×4 + 0×2 + 0 = 4 (第5行) c=1, kh=0, kw=1 → col_idx = 1×4 + 0×2 + 1 = 5 (第6行) c=1, kh=1, kw=0 → col_idx = 1×4 + 1×2 + 0 = 6 (第7行) c=1, kh=1, kw=1 → col_idx = 1×4 + 1×2 + 1 = 7 (第8行)
|
规律:每个通道占 kernel_h × kernel_w 行
1 2
| int input_row = input_row_start + oh * stride_h; int input_col = input_col_start + ow * stride_w;
|
含义:对于当前输出位置 (oh, ow),卷积核元素 (kh, kw) 对应的输入坐标
示例(stride=1):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| 输出位置(0,0): kh=0, kw=0: input_row_start=0, input_col_start=0 input_row = 0+0×1 = 0, input_col = 0+0×1 = 0 → 输入坐标(0,0)
输出位置(0,1): kh=0, kw=0: input_row = 0+0×1 = 0, input_col = 0+1×1 = 1 → 输入坐标(0,1)
输出位置(1,1): kh=0, kw=0: input_row = 0+1×1 = 1, input_col = 0+1×1 = 1 → 输入坐标(1,1) 输出位置(1,1): kh=1, kw=1: input_row_start=1, input_col_start=1 input_row = 1+1×1 = 2, input_col = 1+1×1 = 2 → 输入坐标(2,2)
|
可视化(输出位置(1,1),stride=1):
1 2 3 4 5 6 7 8 9 10 11 12
| 输入图像: 卷积核: 1 2 3 4 ┌─────┐ 5 6 7 8 │ x x│ kh=0,kw=0 → (1,1) 9 10 11 12 │ x x│ kh=0,kw=1 → (1,2) 13 14 15 16 └─────┘ kh=1,kw=0 → (2,1) kh=1,kw=1 → (2,2) 放置在(1,1): 1 2 3 4 5 ┌─────┐ │ 6 7│ │10 11│ └─────┘
|
5. col_buffer_idx(线性索引)
1 2
| int col_buffer_idx = col_idx * out_height * out_width + oh * out_width + ow;
|
含义:二维矩阵 col_buffer[col_idx][列号] 的一维存储位置
计算逻辑:
col_idx = 行号
oh * out_width + ow = 列号(二维坐标转一维)
col_idx * out_height * out_width = 跳过前面的行
示例(out_height=3, out_width=3):
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| col_idx=0, oh=0, ow=0: col_buffer_idx = 0×9 + 0×3 + 0 = 0 (第0行第0列)
col_idx=0, oh=0, ow=1: col_buffer_idx = 0×9 + 0×3 + 1 = 1 (第0行第1列)
col_idx=0, oh=2, ow=2: col_buffer_idx = 0×9 + 2×3 + 2 = 8 (第0行第8列)
col_idx=1, oh=0, ow=0: col_buffer_idx = 1×9 + 0×3 + 0 = 9 (第1行第0列)
col_idx=3, oh=2, ow=2: col_buffer_idx = 3×9 + 2×3 + 2 = 35 (第3行第8列)
|
可视化:
1 2 3
| col_buffer布局(线性存储): [0][1][2][3][4][5][6][7][8]│[9][10][11]...[17]│[18]...│[27]...[35] └─────── 第0行 ─────────────┘└──── 第1行 ──────┘ 第2行 └── 第3行 ──┘
|
1
| int input_idx = c * channel_size + input_row * width + input_col;
|
含义:三维输入 input[c][input_row][input_col] 的一维存储位置
计算逻辑:
c * channel_size = 跳过前面的通道
input_row * width = 当前行的起始位置
+ input_col = 列偏移
示例(channels=1, height=4, width=4):
1 2 3 4 5 6 7 8 9 10 11
| c=0, input_row=0, input_col=0: input_idx = 0×16 + 0×4 + 0 = 0 (值=1)
c=0, input_row=0, input_col=1: input_idx = 0×16 + 0×4 + 1 = 1 (值=2)
c=0, input_row=1, input_col=2: input_idx = 0×16 + 1×4 + 2 = 6 (值=7)
c=0, input_row=3, input_col=3: input_idx = 0×16 + 3×4 + 3 = 15 (值=16)
|
如果channels=2:
1 2
| c=1, input_row=0, input_col=0: input_idx = 1×16 + 0×4 + 0 = 16 (第2个通道的第1个元素)
|
7. 边界检查和数据复制
1 2 3 4 5 6 7 8 9
| if (input_row >= 0 && input_row < height && input_col >= 0 && input_col < width) { int input_idx = c * channel_size + input_row * width + input_col; col_buffer[col_buffer_idx] = input[input_idx]; } else { col_buffer[col_buffer_idx] = T(0); }
|
作用:
- 有效区域:从input复制实际数据
- Padding区域:填充0(模拟zero-padding)
示例(padding=1):
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| 输入(4×4): 加padding后(6×6): ┌────────────┐ ┌──────────────────┐ │ 1 2 3 4 │ │ 0 0 0 0 0 0 │ ← padding行 │ 5 6 7 8 │ │ 0 1 2 3 4 0 │ │ 9 10 11 12 │ → │ 0 5 6 7 8 0 │ │13 14 15 16 │ │ 0 9 10 11 12 0 │ └────────────┘ │ 0 13 14 15 16 0 │ │ 0 0 0 0 0 0 │ ← padding行 └──────────────────┘ ↑ ↑ padding列 padding列
当 input_row=-1 或 input_col=-1 时 → 填充0 当 input_row=6 或 input_col=6 时 → 填充0
|
🎬 完整执行流程示例
让我们跟踪第1个输出位置的数据填充:
输出位置(0,0),oh=0, ow=0
1 2 3 4 5 6 7
| 输入图像: ┌─────┐──────── │ 1 2│ 3 4 ← 需要提取这个2×2窗口 │ 5 6│ 7 8 └─────┘──────── 9 10 11 12 13 14 15 16
|
轮次1:c=0, kh=0, kw=0
1 2 3 4 5 6 7 8 9 10
| col_idx = 0×4 + 0×2 + 0 = 0 input_row_start = 0, input_col_start = 0
oh=0, ow=0: input_row = 0 + 0×1 = 0 input_col = 0 + 0×1 = 0 col_buffer_idx = 0×9 + 0×3 + 0 = 0 input_idx = 0×16 + 0×4 + 0 = 0 col_buffer[0] = input[0] = 1 ✓
|
轮次2:c=0, kh=0, kw=1
1 2 3 4 5 6 7 8 9 10
| col_idx = 0×4 + 0×2 + 1 = 1 input_row_start = 0, input_col_start = 1
oh=0, ow=0: input_row = 0 + 0×1 = 0 input_col = 1 + 0×1 = 1 col_buffer_idx = 1×9 + 0×3 + 0 = 9 input_idx = 0×16 + 0×4 + 1 = 1 col_buffer[9] = input[1] = 2 ✓
|
轮次3:c=0, kh=1, kw=0
1 2 3 4 5 6 7 8 9 10
| col_idx = 0×4 + 1×2 + 0 = 2 input_row_start = 1, input_col_start = 0
oh=0, ow=0: input_row = 1 + 0×1 = 1 input_col = 0 + 0×1 = 0 col_buffer_idx = 2×9 + 0×3 + 0 = 18 input_idx = 0×16 + 1×4 + 0 = 4 col_buffer[18] = input[4] = 5 ✓
|
轮次4:c=0, kh=1, kw=1
1 2 3 4 5 6 7 8 9 10
| col_idx = 0×4 + 1×2 + 1 = 3 input_row_start = 1, input_col_start = 1
oh=0, ow=0: input_row = 1 + 0×1 = 1 input_col = 1 + 0×1 = 1 col_buffer_idx = 3×9 + 0×3 + 0 = 27 input_idx = 0×16 + 1×4 + 1 = 5 col_buffer[27] = input[5] = 6 ✓
|
结果:输出位置(0,0)的col_buffer第0列 = [1, 2, 5, 6]
🚀 为什么这样做有效?
卷积 → 矩阵乘法的转换
假设我们有:
- Weight矩阵 (卷积核):
[out_channels, in_channels × kh × kw]
- col_buffer矩阵:
[in_channels × kh × kw, out_h × out_w]
矩阵乘法:
1 2 3
| Output = Weight × col_buffer
Output形状: [out_channels, out_h × out_w]
|
每个输出元素的计算:
1 2 3 4
| Output[i][j] = Weight[i] · col_buffer[:][j] = Weight的第i行 点乘 col_buffer的第j列 = Σ(卷积核元素 × 对应输入值) = 这正是卷积的定义!
|
示例(1个输出通道,1个输入通道):
1 2 3 4 5 6 7 8 9
| Weight (卷积核展平): [w00, w01, w10, w11] (1×4)
col_buffer的第0列 (位置0的输入): [1, 2, 5, 6]^T (4×1)
点积: w00×1 + w01×2 + w10×5 + w11×6 = 卷积核在位置(0,0)的输出值
|
💡 总结
Im2Col的6层循环做了什么?
1 2 3 4 5 6 7 8
| 1. for (c): 遍历所有输入通道 2. for (kh): 遍历卷积核的高度维度 3. for (kw): 遍历卷积核的宽度维度 → 确定col_buffer的"行"(哪个卷积核元素) 4. for (oh): 遍历输出的高度维度 5. for (ow): 遍历输出的宽度维度 → 确定col_buffer的"列"(哪个输出位置) 6. 【提取输入值,填充col_buffer】
|
关键公式记忆
| 变量 |
公式 |
含义 |
col_idx |
c×Kh×Kw + kh×Kw + kw |
col_buffer的行号 |
output_pos |
oh×out_w + ow |
输出位置的一维索引 |
col_buffer_idx |
col_idx×(out_h×out_w) + output_pos |
col_buffer线性索引 |
input_row |
input_row_start + oh×stride_h |
实际输入行坐标 |
input_col |
input_col_start + ow×stride_w |
实际输入列坐标 |
input_idx |
c×h×w + input_row×w + input_col |
输入线性索引 |
为什么Im2Col快?
- 转换为GEMM - 利用高度优化的矩阵乘法库(MKL, cuBLAS)
- 缓存友好 - 连续内存访问模式
- 并行友好 - GEMM天然支持多线程和SIMD
代价:需要额外内存存储col_buffer
收益:计算速度提升5-10倍!
🎉 这就是所有深度学习框架(PyTorch, TensorFlow, TensorRT)加速卷积的核心技巧!