# 前言

本系列文章将从代码和流程图入手,详细介绍 TVM AI 编译器的编译流程。本文章为第四篇,对应的 TVM 版本为当前最新版本 1.7。

网络上有不少 TVM 工程的教程资源,如果本博客也是其他教程的简单重复的话,则网络的角落里又多了一份纯粹的空间占用者。所以我在想,本文章有什么特点才值得一看呢?我觉得有两个优点: 1、本文从代码出发,不会泛泛而谈,能够从细节了解 TVM;2、自认为结构流程图画的不错,能够从整体上把握 TVM 的脉络。所以,也许值得一看呢。

本篇文章介绍 TVM BuildRelay 函数。文章 《【TVM】通过代码学习编译流程【3】模型编译》 已经介绍了 Relay IRModule 转换为 GraphExecutorFactory 的过程。其中中间有个 bld_mod.build() 函数调用 C++ 类 RelayBuildModule 的函数 buildbuild 流程包含了将 Relay IRModule 降级为低级中间表示 TIR,然后再转换为 Runtime::Module 的过程。本篇文章将介绍 BuildRelay 和其子函数 OptimizeImpl

因为代码量巨大,模型编译会分成若干篇文章进行解析。接下来的若干篇都会介绍 BuildRelay 函数 及其调用的子函数。

作为初学者,错误在所难免,还望不吝赐教。

# Python 脚本

这里提供一个简单的 Python 脚本,调用 TVM Python 前端,实现 onnx 模型的编译过程。tvm 通过代码学习编程流程系列文章将基本采用这个脚本帮助追踪代码。

import onnx
from PIL import Image
import numpy as np
import tvm.relay as relay
import tvm
from tvm.contrib import graph_executor
######################################    路径信息    ##########################################
model_path = "/home/xianmu/module/resnet18.onnx"
save_path = "/home/xianmu/module/pythonSave/"
onnx_model = onnx.load(model_path)
##################################    图片信息    ##############################################
img_path = "/home/xianmu/.tvm_test_data/data/imagenet_cat.png"
# Resize it to 224x224
resized_image = Image.open(img_path).resize((224, 224))
img_data = np.asarray(resized_image).astype("float32")
# Our input image is in HWC layout while ONNX expects CHW input, so convert the array
img_data = np.transpose(img_data, (2, 0, 1))
# Normalize according to the ImageNet input specification
imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
imagenet_stddev = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
norm_img_data = (img_data / 255 - imagenet_mean) / imagenet_stddev
# Add the batch dimension, as we are expecting 4-dimensional input: NCHW.
img_data = np.expand_dims(norm_img_data, axis=0)
####################################     模型编译     ###########################################
input_name = "data"
target = tvm.target.Target(target="llvm", host="llvm")
shape_dict = {input_name: img_data.shape}
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, export_node_renamed_model_path=save_path)  # 创建 IRModule  高级 Relay IR
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)   # 创建 GraphExecutorFactoryModule  
#######################################     模型保存      ########################################
# save
# 保存编译后的库文件(.so)
lib_fname = save_path + "mod.so"
lib.get_lib().export_library(lib_fname)
# 保存模型参数(.params)
params_fname = save_path + "mod.params"
with open(params_fname, "wb") as param_file:
    param_file.write(relay.save_param_dict(lib.get_params()))
# 保存 JSON 格式的计算图(.json)
json_fname = save_path + "mod.json"
with open(json_fname, "w") as json_file:
    json_file.write(lib.get_executor_config())
dev = tvm.device(str(target), 0)
module = graph_executor.GraphModule(lib["default"](dev))    # graph_executor.GraphModule
############################       运行    ##########################################
module.set_input(input_name, img_data)
module.run()
output_shape = (1, 1000)
tvm_output = module.get_output(0, tvm.nd.empty(output_shape)).numpy()
print(tvm_output)

# BuildRelay

文章 《【TVM】通过代码学习编译流程【3】模型编译》 已经介绍了 Relay IRModule 转换为 GraphExecutorFactory 的过程。其中中间有个 bld_mod.build() 函数调用 C++ 类 RelayBuildModule 的函数 buildbuild 流程包含了将 Relay IRModule 降级为低级中间表示 TIR,然后再转换为 Runtime::Module 的过程。本篇讲解 bld_mod.build() 函数调用到的 BuildRelay 函数,即下图中的红色节点。

总体流程图

前述流程中讲到 bld_mod.build() 函数调用 C++ 类 RelayBuildModule 的函数 build 。 下面是 RelayBuildModule 类根据名字 “build” 调用的函数,它调用了自身的 Build() 函数。

else if (name == "build") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        ICHECK_EQ(args.num_args, 8);
        this->Build(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]);
      });

RelayBuildModule 类的 Build() 函数以 Relay IRModule 模型为输入,调用关键函数 BuildRelay(std::move(mod), mod_name);

void Build(IRModule mod, const Array<Target>& raw_targets, const tvm::Target& target_host,
             const Executor& executor, const Runtime& runtime,
             const WorkspaceMemoryPools& workspace_memory_pools,
             const ConstantMemoryPools& constant_memory_pools, const String mod_name) {
    VLOG_CONTEXT << "Build";
    executor_ = executor;
    runtime_ = runtime;
    workspace_memory_pools_ = workspace_memory_pools;
    constant_memory_pools_ = constant_memory_pools;
    config_ = CompilationConfig(PassContext::Current(), raw_targets);
    VLOG(1) << "Using compilation config:" << std::endl << config_;
    BuildRelay(std::move(mod), mod_name);
  }

RelayBuildModule 类的 BuildRelay() 函数是 TVM 编译的关键函数。基本的编译流程都在该函数当中。代码如下所示,在代码中添加了部分注释。简略介绍一下流程;
WithAttrs() —— 为 Relay IRModule 添加 Executor 和 Runtime 属性
OptimizeImpl(std::move(module)) —— 收集并执行大量针对高级中间表示 Relay IRModule 的优化 Pass,包含算子融合,常量折叠等。
MakeExecutorCodegen(executor_->name) —— 构建代码生成 GraphCodegen,用于将高层次的计算描述转换为特定硬件平台上的低层次、可执行代码。
Init(nullptr, config_->primitive_targets) ——Codegen 初始化。
Codegen(func_module, func, mod_name) —— 将 Relay IRModule 降级为 TIR Module。
UpdateOutput(&ret_) —— 更新降级后的 json 图结构到 BuildOutput 结构体
executor_codegen_->GetParams() —— 更新降级后的 params 到 BuildOutput 结构体
TIRToRuntime(lowered_funcs, host_target) ——TIR Module 转换为 runtime::Module

void BuildRelay(IRModule relay_module, const String& mod_name) {
    // Relay IRModule -> IRModule optimizations.
    IRModule module = WithAttrs(  // 为 Relay IRModule 添加 Executor 和 Runtime 属性
        relay_module, <!--swig0-->);
    relay_module = OptimizeImpl(std::move(module));  // 执行多个针对 Relay IRModule 的优化 Pass
    // Get the updated function and new IRModule to build.
    Function func = Downcast<Function>(relay_module->Lookup("main"));  // 获取 Relay IRModule 中的 main 函数表达式
    IRModule func_module = WithAttrs(IRModule::FromExpr(func),  // 为 main 函数表达式添加属性信息
                                     <!--swig1-->);
    // Generate code for the updated function.
    executor_codegen_ = MakeExecutorCodegen(executor_->name);   // 构建代码生成 GraphCodegen
    executor_codegen_->Init(nullptr, config_->primitive_targets);  // Codegen 初始化
    executor_codegen_->Codegen(func_module, func, mod_name);   // 将 Relay IRModule 降级为 TIR Module    
    executor_codegen_->UpdateOutput(&ret_);  // 更新降级后的 json 图结构到 BuildOutput 结构体
    ret_.params = executor_codegen_->GetParams();   // 更新降级后的 params 到 BuildOutput 结构体
    auto lowered_funcs = executor_codegen_->GetIRModule(); // 获取降级后的 TIR Module
    // No need to build for external functions.
    Target ext_dev("ext_dev");
    if (lowered_funcs.find(ext_dev) != lowered_funcs.end()) {
      lowered_funcs.Set(ext_dev, IRModule());
    }
    const Target& host_target = config_->host_virtual_device->target;
    const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate");
    // When there is no lowered_funcs due to reasons such as optimization.
    if (lowered_funcs.size() == 0) {
      if (host_target->kind->name == "llvm") {
        CHECK(pf != nullptr) << "Unable to create empty module for llvm without llvm codegen.";
        // If we can decide the target is LLVM, we then create an empty LLVM module.
        ret_.mod = (*pf)(host_target->str(), "empty_module");
      } else {
        // If we cannot decide the target is LLVM, we create an empty CSourceModule.
        // The code content is initialized with ";" to prevent complaining
        // from CSourceModuleNode::SaveToFile.
        ret_.mod = tvm::codegen::CSourceModuleCreate(";", "", Array<String>{});
      }
    } else {
      ret_.mod = tvm::TIRToRuntime(lowered_funcs, host_target);  // TIR Module 转换为 runtime::Module
    }
    auto ext_mods = executor_codegen_->GetExternalModules();
    ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, host_target,
                                                  runtime_, executor_,
                                                  executor_codegen_->GetExecutorCodegenMetadata());
    // Remove external params which were stored in metadata module.
    for (tvm::runtime::Module mod : ext_mods) {
      auto pf_var = mod.GetFunction("get_const_vars");
      if (pf_var != nullptr) {
        Array<String> variables = pf_var();
        for (size_t i = 0; i < variables.size(); i++) {
          auto it = ret_.params.find(variables[i].operator std::string());
          if (it != ret_.params.end()) {
            VLOG(1) << "constant '" << variables[i] << "' has been captured in external module";
            ret_.params.erase(it);
          }
        }
      }
    }
  }

BuildRelay 过程很长,这是总体的流程结构的一部分。

总体

# OptimizeImpl

OptimizeImpl(std::move(module)) —— 收集并执行大量针对高级中间表示 Relay IRModule 的优化 Pass,包含算子融合,常量折叠等。 OptimizeImpl() 函数如下所示。代码中包含了主要函数的简单注释。

IRModule OptimizeImpl(IRModule relay_module) {
    ICHECK(relay_module.defined()) << "The IRModule must be defined for the Relay compiler.";
    backend::BindParamsInModule(relay_module, params_);  // 根据参数名字绑定参数
    Array<Pass> pass_seqs =  // 获取一系列 pass,组成包含 Pass 的数组
        GetPassPrefix(/*is_homogenous=*/config_->primitive_targets.size() == 1, /*is_vm=*/false);
    transform::PassContext pass_ctx = PassContext::Current();
    if (config_->optional_homogeneous_target.defined()) {
      // This pass currently only supports the homogeneous case.
      pass_seqs.push_back(transform::SplitArgs(
          config_->optional_homogeneous_target->GetAttr<Integer>("max_function_args", 0)
              .value()
              .IntValue()));
    }
    // Always plan devices so the remaining passes don't need to distinguish homogeneous vs hetrogenous execution.
    pass_seqs.push_back(transform::PlanDevices(config_));
    // Fuse the operations if it is needed.
    pass_seqs.push_back(transform::FuseOps());  // 添加算子融合 Pass
    // Create a sequential pass and perform optimizations.
    transform::Pass seq = transform::Sequential(pass_seqs);  // 将多个 Pass 封装成 Sequential 类。Sequential 类包含多个按照顺序执行的 Pass,类似于 pytorch 里面的 nn.Sequential
    if (config_->optional_homogeneous_target.defined()) {
      With<Target> tctx(config_->optional_homogeneous_target);
      relay_module = seq(relay_module);
    } else {
      relay_module = seq(relay_module);  // 执行 Sequential seq 中的所有 Pass,修改 relay_module 结构
    }
    // Do layout rewrite for auto-scheduler.  
    if (backend::IsAutoSchedulerEnabled() && config_->optional_homogeneous_target.defined()) {  // 使用 auto-schedule 优化调度的情况下
      Pass major_pass = transform::AutoSchedulerLayoutRewrite();  // 内存排布重写 Pass
      bool enable_layout_rewrite_targets =
          config_->optional_homogeneous_target->GetTargetDeviceType() == kDLCPU ||
          config_->optional_homogeneous_target->GetAttr<String>("device", "") == "mali";
      if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) {
        With<Target> tctx(config_->optional_homogeneous_target);
        relay_module = major_pass(relay_module);
        // Defuse ops to fold constants, then fuse them again
        relay_module = transform::DefuseOps()(relay_module);  // 执行 单个 Pass
        relay_module = transform::FoldConstant()(relay_module);
        relay_module = transform::FuseOps()(relay_module);
      }
    }
    if (backend::IsMetaScheduleEnabled() && config_->optional_homogeneous_target.defined()) { // 使用 meta-schedule 优化调度的情况下
      Pass major_pass = transform::MetaScheduleLayoutRewrite();
      bool enable_layout_rewrite_targets =
          config_->optional_homogeneous_target->GetTargetDeviceType() == kDLCPU ||
          config_->optional_homogeneous_target->GetAttr<String>("device", "") == "mali";
      if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) {
        With<Target> tctx(config_->optional_homogeneous_target);
        relay_module = major_pass(relay_module);
        // Defuse ops to fold constants, then fuse them again
        relay_module = transform::DefuseOps()(relay_module);
     // 执行反算子融合 Pass
        relay_module = transform::FoldConstant()(relay_module);  // 执行常量折叠 Pass
        relay_module = transform::FuseOps()(relay_module);
       // 执行算子融合 Pass
      }
    }
    relay_module = transform::InferType()(relay_module);
    relay_module = transform::Inline()(relay_module);
    relay_module = transform::InferType()(relay_module);
    relay_module = transform::LabelOps()(relay_module);
    relay_module = transform::AnnotateMemoryScope()(relay_module);
    return relay_module;
  }

backend::BindParamsInModule(relay_module, params_) —— 根据参数名字绑定参数

Meta Schedule 和 AutoSchedule 都是用于自动优化调度的工具。AutoSchedule 是 TVM 较早期引入的一种自动调度机制,它的主要特点是基于成本模型的搜索算法来探索可能的调度选项。AutoSchedule 通常依赖于一个预定义的成本模型来估计不同调度策略下的性能,并使用搜索算法(如进化算法或随机搜索)来找到最佳的调度。MetaSchedule 是一个更为现代化且灵活的自动调度框架,它旨在解决日益增长的硬件多样性和复杂的深度学习工作负载带来的挑战。与 AutoSchedule 相比,MetaSchedule 引入了更多先进的技术和设计理念,如自适应性,机器学习搜索测量,有反馈的成本模型。使用方法:例如可以在 PassContext 中配置 “relay.backend.use_meta_schedule” 参数设置为 TRUE,则 TVM 使用 Meta Schedule 帮助完成自动调度优化。

GetPassPrefix() —— 获取一系列 Pass,组成包含 Pass 的数组。代码如下。

函数中还获取了很多其他 Pass,下一篇文章将选择其中的 DefuseOps Pass 进行讲解。

Array<Pass> GetPassPrefix(bool is_homogeneous, bool is_vm) {
  Array<Pass> pass_seqs;
  // TODO(mbs): Would be nice to get spans on all diagnostics, but since they arg forgotton
  // by most passes there's little utility in including this now. Plus we'd need to only do
  // this if there's no existing spans to work from.
  // pass_seqs.push_back(parser::AnnotateSpans());
  Array<runtime::String> entry_functions{"main"};
  pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
  pass_seqs.push_back(transform::ToBasicBlockNormalForm());
  // Run all dialect legalization passes.
  pass_seqs.push_back(relay::qnn::transform::Legalize());
  // Legalize pass is restricted to homogeneous execution for now.
  if (is_homogeneous) {
    pass_seqs.push_back(transform::Legalize());
  }
  pass_seqs.push_back(transform::SimplifyInference());
  if (is_vm) {
    // eta expand to support constructors in argument position
    pass_seqs.push_back(transform::EtaExpand(
        /* expand_constructor */ true, /* expand_global_var */ false));
  }
  PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
    Expr expr = args[0];
    if (auto* call_node = expr.as<CallNode>()) {
      auto op_node = call_node->op.as<OpNode>();
      if (op_node->name == "cast") {
        auto attrs = call_node->attrs.as<CastAttrs>();
        if (attrs->dtype == DataType::Int(32)) {
          *rv = true;
        }
      }
    }
    *rv = false;
  });
  pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
  pass_seqs.push_back(transform::CombineParallelConv2D(3));
  pass_seqs.push_back(transform::CombineParallelDense(3));
  pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
  pass_seqs.push_back(transform::FoldConstant());
  pass_seqs.push_back(transform::FoldScaleAxis());
  pass_seqs.push_back(transform::SimplifyExpr());
  pass_seqs.push_back(transform::CanonicalizeCast());
  pass_seqs.push_back(transform::CanonicalizeOps());
  pass_seqs.push_back(transform::FlattenAtrousConv());
  // Alter layout transformation is currently only applied to homogeneous execution.
  if (is_homogeneous) {
    if (!is_vm) {
      pass_seqs.push_back(transform::InferType());
    }
    pass_seqs.push_back(transform::AlterOpLayout());
    pass_seqs.push_back(transform::SimplifyExprPostAlterOp());
  }
  // Fast math optimizations.
  pass_seqs.push_back(transform::FastMath());
  pass_seqs.push_back(transform::FoldConstant());
  return pass_seqs;
}

# 后记

本博客目前以及可预期的将来都不会支持评论功能。各位大侠如若有指教和问题,可以在我的 github 项目 或随便一个项目下提出 issue,或者知乎 私信,并指明哪一篇博客,我看到一定及时回复。

Edited on

Give me a cup of [coffee]~( ̄▽ ̄)~*

XianMu WeChat Pay

WeChat Pay

XianMu Alipay

Alipay