Mini-Infer 架构深潜 (4): Graph 与 Node - 编织计算“神经网”
引言:从“算子”到“网络”
在上一篇文章中,我们构建了 Operator 抽象和一个“自注册”工厂。我们现在有能力创建独立的计算单元(如 “ReLU”, “Convolution”)。
但一个神经网络不是一堆孤立的算子,它是一个有向无环图 (DAG)。数据必须从 Input 流向 Convolution,再流向 ReLU,最终到达 Output。
我们如何描述这种连接关系和执行顺序?
本篇,我们将构建 Mini-Infer 的“骨架”:Graph(图)和 Node(节点)。
1. Node:图的基本单元 (node.h)
Node 是我们图结构中最基本的“原子”。它是一个轻量级的数据容器,其职责是“连接”。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
| #pragma once
#include "mini_infer/operators/operator.h" #include "mini_infer/core/tensor.h" #include <vector> #include <memory>
namespace mini_infer { namespace graph {
class Node { public: explicit Node(const std::string& name); ~Node() = default;
void set_operator(std::shared_ptr<operators::Operator> op) { op_ = op; } std::shared_ptr<operators::Operator> get_operator() const { return op_; } void add_input(std::shared_ptr<Node> input_node); void add_output(std::shared_ptr<Node> output_node); const std::vector<std::shared_ptr<Node>>& inputs() const; const std::vector<std::shared_ptr<Node>>& outputs() const; void set_input_tensors(const std::vector<std::shared_ptr<core::Tensor>>& tensors); void set_output_tensors(const std::vector<std::shared_ptr<core::Tensor>>& tensors); const std::vector<std::shared_ptr<core::Tensor>>& input_tensors() const; std::vector<std::shared_ptr<core::Tensor>>& output_tensors();
private: std::string name_; std::shared_ptr<operators::Operator> op_; std::vector<std::shared_ptr<Node>> input_nodes_; std::vector<std::shared_ptr<Node>> output_nodes_; std::vector<std::shared_ptr<core::Tensor>> input_tensors_; std::vector<std::shared_ptr<core::Tensor>> output_tensors_; }; } }
|
Node 的设计非常清晰:
- 它持有
Operator:这是 Node 与计算逻辑的“灵魂”链接。Node 负责“结构”,Operator 负责“计算”。
- 它持有
Node 指针:input_nodes_ 和 output_nodes_ 定义了图的拓扑结构(边)。
- 它持有
Tensor 指针:input_tensors_ 和 output_tensors_ 是运行时的数据缓冲区。forward 期间,一个 Node 会从 input_tensors_ 中读取数据,调用 op_ 进行计算,然后将结果填入 output_tensors_。
- 无需 Pimpl:和
Shape、Tensor 一样,Node 是一个高性能、轻量级的数据结构。它的公共接口已经暴露了 vector 和 shared_ptr,使用 Pimpl 毫无收益,反而会带来不必要的性能开销。
2. Graph:网络“管理器” (graph.h)
Graph 类的职责是持有并管理所有的 Node,并提供构建、验证和排序图的核心算法。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
| #pragma once
namespace mini_infer { namespace graph {
class Graph { public: Graph() = default; Graph(const Graph&) = delete; Graph& operator=(const Graph&) = delete; Graph(Graph&&) noexcept = default; Graph& operator=(Graph&&) noexcept = default;
std::shared_ptr<Node> create_node(const std::string& name); std::shared_ptr<Node> get_node(const std::string& name) const; [[nodiscard]] core::Status connect(const std::string& src_name, const std::string& dst_name); void set_inputs(const std::vector<std::string>& input_names); void set_outputs(const std::vector<std::string>& output_names);
[[nodiscard]] core::Status topological_sort(std::vector<std::shared_ptr<Node>>& sorted_nodes) const; [[nodiscard]] core::Status optimize(); [[nodiscard]] core::Status validate() const; const std::unordered_map<std::string, std::shared_ptr<Node>>& nodes() const noexcept { return nodes_; }
private: std::unordered_map<std::string, std::shared_ptr<Node>> nodes_; std::vector<std::string> input_names_; std::vector<std::string> output_names_; }; } }
|
这个 Graph 类的设计是健壮的:
- 所有权与生命周期:
Graph 通过 nodes_ (一个 unordered_map) 持有所有 Node 的 std::shared_ptr,从而管理它们的生命周期。
- 禁止拷贝:
Graph 是一个重量级的资源管理器,拷贝一个 Graph 几乎总是一个错误。将其显式 delete 避免了歧义。
[[nodiscard]]:在 connect, topological_sort 等返回 Status 的函数上使用 C++17 的 [[nodiscard]] 是一个极好的实践。它会强制调用者必须检查返回值,防止忽略 ERROR_RUNTIME 等严重错误。
3. 核心算法:图的构建与排序 (graph.cpp)
Graph 类的灵魂在于它的 .cpp 实现,尤其是 connect 和 topological_sort。
3.1. connect (鲁棒的边构建)
connect 函数是图构建的基石,它必须非常鲁棒:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
| core::Status Graph::connect(const std::string& src_name, const std::string& dst_name) { if (src_name == dst_name) { return core::Status::ERROR_INVALID_ARGUMENT; } auto src = get_node(src_name); auto dst = get_node(dst_name); if (!src || !dst) { return core::Status::ERROR_INVALID_ARGUMENT; }
for (const auto& out : src->outputs()) { if (out && out->name() == dst_name) { return core::Status::SUCCESS; } }
src->add_output(dst); dst->add_input(src);
return core::Status::SUCCESS; }
|
这种对自循环和重复边的检查,保证了图的构建过程是幂等的、健壮的。
3.2. topological_sort (Kahn 算法)
这是整个推理框架的“执行引擎”。forward 必须(也只能)按照拓扑排序的顺序执行。我们不能在 Convolution 之前执行 ReLU。
Mini-Infer 采用了经典的 Kahn 算法(基于入度 “in-degree” 的算法)来实现。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
| core::Status Graph::topological_sort(std::vector<std::shared_ptr<Node>>& sorted_nodes) const { sorted_nodes.clear(); const size_t num_nodes = nodes_.size(); if (num_nodes == 0) return core::Status::SUCCESS;
std::unordered_map<std::string, int> in_degree; in_degree.reserve(num_nodes); for (const auto& kv : nodes_) { in_degree.emplace(kv.first, 0); }
for (const auto& kv : nodes_) { for (const auto& out : kv.second->outputs()) { if (out) { ++(in_degree[out->name()]); } } }
std::queue<std::shared_ptr<Node>> q; for (const auto& kv : nodes_) { if (in_degree[kv.first] == 0) { q.push(kv.second); } }
while (!q.empty()) { auto node = q.front(); q.pop();
sorted_nodes.push_back(node);
for (const auto& out : node->outputs()) { --(in_degree[out->name()]); if (in_degree[out->name()] == 0) { q.push(out); } } }
if (sorted_nodes.size() != num_nodes) { return core::Status::ERROR_RUNTIME; }
return core::Status::SUCCESS; }
|
这个算法不仅为我们提供了唯一的、正确的执行顺序,还免费附赠了一个强大的功能:validate()(图验证)。validate 函数可以通过调用 topological_sort 来立即检测图是否存在“环路”(Cycle),这是推理图的“绝症”。
4. 架构决策:算法 vs. 数据
在 Mini-Infer 的这个版本中,topological_sort 和 validate 被实现为 Graph 类的成员函数。
- 这是一种经典、有效的面向对象设计。
Graph 封装了它的数据(nodes_)和操作这些数据的方法(topological_sort)。
- 另一种设计是将
Graph 保持为一个纯粹的“数据容器”,并将 topological_sort 等算法移到 graph_algorithms.h 这样的“算法”文件中,作为自由函数 topological_sort(Graph& graph) 来调用。这被称为“数据与算法分离”,它更强调“算法”的独立性。
两种设计都是专业且正确的。Mini-Infer 目前的选择(成员函数)具有很强的封装性,非常适合当前项目的演进。
总结与展望
我们成功地构建了 Mini-Infer 的“骨架”。我们现在有了一个功能齐全的图管理器,它能够:
- 用
Node 链接 Operator。
- 用
connect 构建图的拓扑结构。
- 用
topological_sort 获得一个安全、正确的执行序列。
我们已经拥有了 Tensor(数据), Backend(场地), Operator(计算单元), 和 Graph(蓝图)。