Mini-Infer (27): 运行时架构重构 (下) — ExecutionContext 与零拷贝执行

1. ExecutionContext 的职责定位

在上一篇中,我们介绍了 InferencePlan 作为不可变的构建产物。本篇我们来看它的"运行时伙伴"——ExecutionContext

ExecutionContext每次推理请求的可变状态容器,它负责:

  1. 内存池管理:持有实际的内存缓冲区。
  2. 中间张量存储:存储每个节点的输出激活值。
  3. 设备上下文:管理 CPU/CUDA 执行环境。
  4. 动态形状推理:在输入形状变化时重新推导。

核心设计原则

  • InferencePlan共享的,多个 Context 可以引用同一个 Plan。
  • ExecutionContext独占的,每个推理请求一个 Context。
  • 这种分离使得并发推理成为可能。

2. 初始化流程 (initialize)

A. 类定义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
// mini_infer/runtime/execution_context.h

class ExecutionContext {
public:
explicit ExecutionContext(std::shared_ptr<const InferencePlan> plan);

core::Status set_inputs(const std::unordered_map<std::string,
std::shared_ptr<core::Tensor>>& inputs);
core::Status set_inputs(const std::vector<std::shared_ptr<core::Tensor>>& inputs);

const std::vector<std::shared_ptr<core::Tensor>>& outputs() const;
const std::unordered_map<std::string, std::shared_ptr<core::Tensor>>& named_outputs() const;

private:
friend class InferencePlan;

std::shared_ptr<const InferencePlan> plan_; // 引用 Plan (不可变)
std::shared_ptr<void> shared_buffer_; // 共享内存池
size_t shared_buffer_size_{0};
std::vector<std::vector<std::shared_ptr<core::Tensor>>> node_outputs_; // 每个节点的输出
std::unique_ptr<ShapeInferenceEngine> shape_inference_engine_; // 动态形状引擎
std::unordered_map<core::DeviceType, std::shared_ptr<backends::DeviceContext>> contexts_;
bool initialized_{false};
};

B. 初始化流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// mini_infer/runtime/execution_context.cpp

core::Status ExecutionContext::initialize() {
if (initialized_) {
return core::Status::SUCCESS;
}

// 1. 创建设备上下文
core::DeviceType device_type = plan_->config().device_type;
#ifdef MINI_INFER_USE_CUDA
if (device_type == e::DeviceType::CUDA) {
contexts_.emplace(core::DeviceType::CUDA,
std::make_shared<backends::cuda::CUDADeviceContext>(plan_->config().device_id));
} else
#endif
{
contexts_.emplace(core::DeviceType::CPU,
std::make_shared<backends::CPUDeviceContext>());
}

// 2. 预分配节点输出容器
node_outputs_.clear();
node_outputs_.resize(plan_->graph()->node_capacity());

// 3. 分配张量内存
auto status = allocate_tensors();
if (status != core::Status::SUCCESS) {
return status;
}

// 4. 初始化动态形状引擎 (如果启用)
if (plan_->config().enable_dynamic_shapes) {
shape_inference_engine_ = std::make_unique<ShapeInferenceEngine>(plan_->graph());
shape_inference_engine_->set_verbose(plan_->config().enable_profiling);
}

initialized_ = true;
return core::Status::SUCCESS;
}

关键点

  1. DeviceContext 创建:根据配置创建 CPU 或 CUDA 上下文。
  2. node_outputs_ 预分配:按图的节点容量预分配,避免运行时扩容。
  3. 张量内存分配:这是最关键的步骤,下面详细讲解。

3. 内存绑定机制

A. 内存池准备 (prepare_memory_pools)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
core::Status ExecutionContext::prepare_memory_pools(bool use_memory_pools) {
if (!use_memory_pools) {
return core::Status::SUCCESS;
}

const auto& plan = plan_->get_memory_plan();
shared_buffer_size_ = plan.shared_buffer_size;

// 根据设备类型分配内存
core::DeviceType device_type = plan_->config().device_type;
void* raw = nullptr;

#ifdef MINI_INFER_USE_CUDA
if (device_type == core::DeviceType::CUDA) {
auto cuda_allocator = std::make_shared<backends::cuda::CUDAAllocator>(
plan_->config().device_id);
raw = cuda_allocator->allocate(shared_buffer_size_,
plan_->config().memory_alignment);
cudaMemset(raw, 0, shared_buffer_size_);

// 使用自定义删除器确保正确释放
cuda_allocator_ = cuda_allocator;
shared_buffer_.reset(raw, [allocator = cuda_allocator](void* p) {
allocator->deallocate(p);
});
} else
#endif
{
raw = backends::cpu::CPUAllocator::instance()->allocate(
shared_buffer_size_, plan_->config().memory_alignment);
std::memset(raw, 0, shared_buffer_size_);
shared_buffer_.reset(raw, [](void* p) {
backends::cpu::CPUAllocator::instance()->deallocate(p);
});
}

return core::Status::SUCCESS;
}

设计要点

  1. 单块连续内存shared_buffer_ 是一块大的连续内存,所有中间张量都从这里分配。
  2. 自定义删除器:使用 shared_ptr 的自定义删除器确保内存正确释放。
  3. 零初始化:使用 memsetcudaMemset 清零,避免未初始化数据。

B. try_bind_tensor_to_pool 的三种结果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
enum class PoolBindResult {
kNotTried, // 未尝试绑定 (内存规划未启用)
kBound, // 成功绑定到内存池
kFailed // 绑定失败
};

PoolBindResult ExecutionContext::try_bind_tensor_to_pool(
size_t node_id, size_t output_index,
std::shared_ptr<core::Tensor>& tensor,
bool use_memory_pools, int& allocated_count, int& failed_count) {

if (!use_memory_pools) {
return PoolBindResult::kNotTried;
}

const auto& plan = plan_->get_memory_plan();

// 检查是否有该节点的内存规划
if (node_id < plan.tensor_offsets.size() &&
plan.tensor_offsets[node_id] != MemoryPlan::kInvalidOffset) {

const size_t required = tensor->size_in_bytes();
const size_t offset = plan.tensor_offsets[node_id];

// 边界检查
if (offset + required > shared_buffer_size_) {
MI_LOG_ERROR("Tensor exceeds shared buffer size");
return PoolBindResult::kFailed;
}

// 绑定到共享缓冲区的指定偏移
core::DeviceType device_type = ->config().device_type;
if (!tensor->bind_external_data_with_offset(
shared_buffer_, shared_buffer_size_, offset, device_type)) {
return PoolBindResult::kFailed;
}

allocated_count++;
return PoolBindResult::kBound;
}

return PoolBindResult::kNotTried;
}

零拷贝的关键bind_external_data_with_offset 方法让 Tensor 直接指向共享缓冲区的某个偏移位置,而不是分配新内存。这就是"零拷贝"的含义——数据不需要在不同缓冲区之间复制。

C. 严格内存分配策略

当内存规划启用时,我们采用严格策略:如果某个张量没有对应的内存规划条目,直接报错而不是回退到动态分配。

1
2
3
4
if (use_memory_pools) {
MI_LOG_ERROR("Missing memory plan entry for node " +
std::to_string(node-> return core::Status::ERROR_RUNTIME;
}

这种设计确保了:

  1. 内存使用可预测:所有张量都在规划范围内。
  2. 早期发现问题:如果内存规划有遗漏,构建时就会报错。
  3. 运行时零分配:推理过程中不会触发 malloc

4. 节点执行流程 (execute_node)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
core::Status ExecutionContext::execute_node(const std::shared_ptr<graph::Node>& node) {
// 1. 收集来自上游节点的输入
std::vector<std::shared_ptr<core::Tensor>> input_tensors;
const auto& input_edges = node->inputs();

// 计算图输入的数量
int max_dst_port = -1;
for (const auto& edge : input_edges) {
max_dst_port = std::max(max_dst_port, edge.dst_port);
}
size_t graph_input_count = static_cast<size_t>(max_dst_port + 1);
input_tensors.resize(graph_input_count);

// 从上游节点收集输出
for (const auto& edge : input_edges) {
const auto& src_outputs = node_outputs_[edge.node->id()];
input_tensors[edge.dst_port] = src_outputs[edge.src_port];
}

// 2. 合并图输入和权重输入
std::vector<std::shared_ptr<core::Tensor>> merged_inputs = node->input_tensors();
for (size_t i = 0; i < input_tensors.size(); ++i) {
if (input_tensors[i]) {
merged_inputs[i] = input_tensors[i];
}
}

// 3. 确保输入在正确的设备上 (CUDA 模式)
#ifdef MINI_INFER_USE_CUDA
if (device == core::DeviceType::CUDA) {
for (auto& tensor : merged_inputs) {
auto status = ensure_on_gpu(tensor);
if (status != core::Status::SUCCESS) {
return status;
}
}
}
#endif

// 4. 获取设备上下文
auto context = get_or_create_context(device_type);

// 5. 调用 Plugin 执行
auto* cached_plugin = node->get_operator()->cached_plugin();
operators::PluginContext plugin_ctx;
plugin_ctx.device_context = context.get();
return cached_plugin->enqueue(merged_inputs, output_tensors, plugin_ctx);
}

执行流程图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
┌────────────────────────────────────────────────────────┐
│ execute_node(node) │
├─────────────────────────────────────────────────────────────┤
│ 1. 收集上游输入 │
│ ┌──────────┐ ┌──────────┐ │
│ │ Node A │───►│ input[0] │ │
│ └──────────┘ └──────────┘ │
│ ┌──────────┐ ┌──────────┐ │
│ │ Node B │───►│ input[1] │ │
│ └──────────┘ └──────────┘ ─────────────────────────────────────────────────────────┤
│ 2. 合并权重输入 │
│ ┌──────────┐ ┌──────────┐ │
│ │ Weight │───►│ input[2] │ │
│ └──────────┘ └──────────┘ │
├─────────────────────────────────────────────────────────────┤
│ 3. 确保数据在正确设备 (GPU 模式) │
│ CPU Tensor ──cudaMemcpy──► GPU Tensor │
├─────────────────────────────────────────────────────────────┤
│ 4. 调用 Plugin::enqueue() │
│ merged_inputs ──► Plugin ──► output_tensors │
└─────────────────────────────────────────────────────────────┘

ensure_on_gpu 的三级查找

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
auto ensure_on_gpu = [&](std::shared_ptr<core::Tensor>& tensor) -> core::Status {
if (!tensor || tensor->device() == core::DeviceType::CUDA) {
return core::Status::SUCCESS; // 已在 GPU 或空张量
}

// 第一级:检查 Plan 的预加载缓存 (Build-Time 预加载的权重)
auto preloaded = plan_->get_gpu_tensor(tensor);
if (preloaded) {
tensor = preloaded;
return core::Status::SUCCESS;
}

// 第二级:检查 Context 的本地缓存 (运行时动态创建的张量)
auto cache_it = gpu_constant_cache_.find(tensor);
if (cache_it != gpu_constant_cache_.end()) {
tensor = cache_it->second;
return core::Status::SUCCESS;
}

// 第三级:运行时拷贝 (应该很少发生)
auto gpu_tensor = std::make_shared<core::Tensor>(
tensor->shape(), tensor->dtype(), core::DeviceType::CUDA);
cudaMemcpy(gpu_tensor->data(), tensor->data(),
tensor->size_in_bytes(), cudaMemcpyHostToDevice);
gpu_constant_cache_[tensor] = gpu_tensor;
tensor = gpu_tensor;
return core::Status::SUCCESS;
};

设计思想

  1. Build-Time 预加载优先:大部分权重在构建时已经加载到 GPU。
  2. Context 缓存次之:动态创建的常量张量缓存在 Context 中。
  3. 运行时拷贝兜底:只有极少数情况需要运行时拷贝。

5. 动态形状支持

A. shape_inference_engine_ 集成

enable_dynamic_shapes 为 true 时,Context 会持有一个 ShapeInferenceEngine 实例:

1
2
3
if (plan_->config().enable_dynamic_shapes) {
shape_inference_engine_ = std::make_unique<ShapeInferenceEngine>(plan_->graph());
}

B. handle_shape_change 的容量检查

当输入形状变化时,InferencePlan::execute() 会调用 handle_shape_change

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
// mini_infer/runtime/inference_plan.cpp

core::Status InferencePlan::handle_shape_change(
ExecutionContext* ctx,
const std::vector<ShapeInferenceEngine::RuntimeInputShape>& runtime_shapes) const {

// 1. 验证输入形状是否在 Profile 范围内
if (config_.optimization_profile) {
for (size_t idx = 0; idx < input_bindings_.size(); ++idx) {
const auto* range = config_.optimization_profile->get_shape_range(
input_bindings_[idx].name);
if (range && !range->contains(runtime_shapes[idx].shape)) {
MI_LOG_ERROR("Input shape is outside optimization profile range");
return core::Status::ERROR_INVALID_ARGUMENT;
}
}

// 2. 运行形状推理
auto status = ctx->shape_inference_engine_->infer_shapes(runtime_shapes);

// 3. 更新张量形状 (不重新分配内存)
for (const auto& node : sorted_nodes_) {
auto inferred = ctx->shape_inference_engine_->get_inferred_shape(node->name());
auto& outputs = ctx->node_outputs_[node->id()];

if (outputs[0]->shape() != *inferred) {
// 检查容量是否足够
const size_t required = inferred->numel() * outputs[0]->element_size();
size_t available = outputs[0]->capacity() - outputs[0]->storage_offset();

if (required > available) {
MI_LOG_ERROR("Shape exceeds planned capacity");
return core::Status::ERROR_RUNTIME;
}

// 只更新形状元数据,不重新分配
outputs[0]->set_shape_metadata(*inferred);
}
}

return core::Status::SUCCESS;
}

关键设计

  1. Profile 范围验证:确保输入形状在 min/max 范围内。
  2. 容量检查:新形状不能超过预分配的容量(基于 max shape)。
  3. 只更新元数据set_shape_metadata 只修改形状信息,不触发内存分配。

6. 多 Context 并发推理的可能性

由于 InferencePlan 是不可变的,多个 ExecutionContext 可以安全地共享同一个 Plan:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
// 创建一个 Plan
auto plan = std::make_shared<InferencePlan>(config);->build(graph);

// 创建多个 Context
auto ctx1 = plan->create_execution_context();
auto ctx2 = plan->create_execution_context();

// 并发推理
std::thread t1([&]{
ctx1->set_inputs(inputs1);
plan->execute(ctx1.get());
auto outputs1 = ctx1->outputs();
});

std::thread t2([&]{
ctx2->set_inputs(inputs2);
plan->execute(ctx2.get());
auto outputs2 = ctx2->outputs();
});

t1.join();
t2.join();

并发安全的保证

资源 所有权 并发访问
图结构 Plan 只读,安全
权重数据 Plan 只读,安全
内存规划 Plan 只读,安全
GPU 权重缓存 Plan 只读,安全
中间张量 Context 独占,安全
内存池 Context 独占,安全
设备上下文 Context 独占,安全

7. 总结

本篇我们完成了运行时架构的第二部分:

  • ExecutionContext 作为可变的运行时状态:持有内存池、中间张量、设备上下文。
  • 零拷贝内存绑定:张量直接指向共享缓冲区的偏移位置。
  • 三级 GPU 张量查找:Build-Time 预加载 → Context 缓存 → 运行时拷贝。
  • 动态形状支持:容量检查 + 元数据更新,无需重新分配。
  • 并发推理支持:多个 Context 共享一个 Plan。

至此,Mini-Infer 的运行时架构重构完成。下一篇,我们将深入 Core 层,看看 StorageTensor 的分离如何支撑这一切。