Mini-Infer (7.6): 架构重构 - 用“模板元编程”消除内核注册的“样板戏”

1. 问题的本质:一个“模板”的“模板”

我们的问题是:KernelRegistry (注册表) 的类型,依赖于函数指针的类型,而函数指针的类型又依赖于数据类型 (float, int)

  • GEMM_NT for float -> void(*)(const float*, ...)
  • GEMM_NT for int32 -> void(*)(const int32_t*, ...)

这是一个清晰的模板模式。我们可以把 GEMM_NT 的函数签名定义为一个“函数类型模板”:

1
2
template<typename T>
using GEMMFunc_NT = void(*)(const T* A, const T* B, T* C, int M, int N, int K);

现在,我们的问题演变为:如何创建一个通用KernelRegistry,它接受 GEMMFunc_NT 这样的**“模板”**作为参数,然后再由用户指定 T(如 float)?

2. 解决方案:template<template<typename> class FuncType>

这就是 C++ 模板元编程中的“模板模板参数”。我们来构建一个通用的 KernelRegistry,它不再是一个具体的类,而是一个“注册表的工厂”。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
// mini_infer/kernels/kernel_registry_v2.h (一个假想的新文件名)
#pragma once
#include "mini_infer/kernels/kernel_base.h" // 包含 KernelRegistryBase
#include <type_traits>

namespace mini_infer {
namespace kernels {

/**
* @brief 通用内核注册表 (现代 C++ 模板方法)
*
* FuncType 不是一个类型,而是一个"类型模板",例如 GEMMFunc_NT
*/
template<template<typename> class FuncType>
class KernelRegistry {
public:
/**
* @brief 内部类,用于特定数据类型 (T)
*
* 这是真正的“注册表单例”
*/
template<typename T>
class ForType : public KernelRegistryBase<FuncType<T>> {
public:
// 关键:为每种类型 T (float, int32) 提供一个独立的单例
static ForType& instance() {
static ForType reg;
return reg;
}

private:
ForType() = default;
ForType(const ForType&) = delete;
ForType& operator=(const ForType&) = delete;
};
};

我们来拆解这个“魔法”:

  1. template<template<typename> class FuncType>
    • FuncType 不是一个类型(像 int),也不是一个类型模板(像 std::vector)。
    • 它是一个接受一个类型参数模板(像 GEMMFunc_NT)。
  2. class ForType : public KernelRegistryBase<FuncType<T>>
    • FuncType<T>:这里是“魔法”发生的地方。C++ 会把 GEMMFunc_NTfloat “组装”起来,得到 GEMMFunc_NT<float>,这解析为我们需要的函数指针类型 void(*)(const float*, ...)
    • KernelRegistryBase<GEMMFunc_NT<float>>:我们的 ForType 类继承自正确的 KernelRegistryBase 实例化版本。
    • ForType<T>::instance():为每一种 Tfloat, int32…)都提供了一个独立的、静态的单例。

现在,我们可以像这样获取 GEMM_NTfloat 注册表: KernelRegistry<GEMMFunc_NT>::template ForType<float>::instance()


3. 告别丑陋:用“宏”和“别名”实现“语法糖”

KernelRegistry<...>::template ForType<...>::instance() 这种语法简直是“反人类”的。我们必须用 C++ 的“语法糖”把它藏起来。

第1步:使用 using 别名

1
2
3
// 别名,用于隐藏内部的 ForType
template<typename T, template<typename> class FuncType>
using KernelRegistryFor = typename KernelRegistry<FuncType>::template ForType<T>;

第2步:使用“宏”来定义别名(DRY 原则) 我们为每种内核类型(GEMM_NN, GEMM_NT…)定义一个易读的别名:

1
2
3
4
// 定义注册表别名的"宏"
#define DEFINE_REGISTRY_ALIAS(Name, FuncType) \
template<typename T> \
using Name = KernelRegistryFor<T, FuncType>

现在,我们可以在 gemm.h 中这样定义我们的注册表:

1
2
3
4
5
6
// in gemm.h
// ... (定义 GEMMFunc_NT 和 GEMMFunc_NN) ...

// 一行代码,定义了 GEMMRegistry_NT<T> 和 GEMMRegistry_NN<T>
DEFINE_REGISTRY_ALIAS(GEMMRegistry_NT, GEMMFunc_NT);
DEFINE_REGISTRY_ALIAS(GEMMRegistry_NN, GEMMFunc_NN);

现在,“噩梦”变成了“梦想”:

  • 旧代码(丑陋): KernelRegistry<GEMMFunc_NT>::template ForType<float>::instance()
  • 新代码(优雅): GEMMRegistry_NT<float>::instance()

4. 用“宏”生成辅助函数

我们还需要 is_backend_availableget_best_backend 这样的辅助函数。我们同样不希望为 GEMM_NT, GEMM_NN… 手动各写一套。

我们可以用宏来生成这些“样板戏”代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
// 宏:定义一个“后端检查器”函数
#define DEFINE_BACKEND_CHECKER(FuncName, RegistryType) \
template<typename T> \
static bool FuncName(KernelBackend backend) { \
return RegistryType<T>::instance().is_backend_available(backend); \
}

// 宏:定义一个“最佳后端获取器”函数
#define DEFINE_BEST_BACKEND_GETTER(FuncName, RegistryType) \
template<typename T> \
static KernelBackend FuncName() { \
return RegistryType<T>::instance().get_best_backend(); \
}

使用:

1
2
3
4
5
6
// in gemm.h
// ... (定义别名) ...

// 自动生成 is_backend_available_nn<T> 和 is_backend_available_nt<T>
DEFINE_BACKEND_CHECKER(is_backend_available_nn, GEMMRegistry_NN)
DEFINE_BACKEND_CHECKER(is_backend_available_nt, GEMMRegistry_NT)

总结与展望

通过 C++ 模板元编程和宏,我们支付了一次性、高昂的“复杂度税”,将所有的复杂性都封装在了 kernel_registry.h 这一个文件中。

我们换来的是:

  1. 零代码复用GEMM, im2col 等所有内核现在都可以免费类型安全地重用这个注册表。
  2. 优雅的 APIGEMMRegistry_NT<float>::instance()
  3. 易于维护:我们不再需要手写几十个单例类。

我们的 KernelRegistry 架构现在已经彻底稳定。我们为 Blog 7.5 中那个“丑陋”的 KernelRegistryInitializer 提供了完美的“配套设施”。