# 前言

本系列文章将从代码和流程图入手,详细介绍 TVM AI 编译器的编译流程。本文章为第六篇,对应的 TVM 版本为当前最新版本 1.7。

网络上有不少 TVM 工程的教程资源,如果本博客也是其他教程的简单重复的话,则网络的角落里又多了一份纯粹的空间占用者。所以我在想,本文章有什么特点才值得一看呢?我觉得有两个优点: 1、本文从代码出发,不会泛泛而谈,能够从细节了解 TVM;2、自认为结构流程图画的不错,能够从整体上把握 TVM 的脉络。所以,也许值得一看呢。

本篇文章介绍 TVM CodeGen 函数。文章 《【TVM】通过代码学习编译流程【4】BuildRelay》 已经介绍了 BuildRelay 总体流程和其子函数 OptimizeImpl 。本篇文章将介绍后续的 CodeGen 流程的部分内容。 Codegen(func_module, func, mod_name) —— 将 Relay IRModule 降级为 TIR Module。

因为代码量巨大,模型编译会分成若干篇文章进行解析。接下来的若干篇都会介绍 BuildRelay 函数 及其调用的子函数。

作为初学者,错误在所难免,还望不吝赐教。

# Python 脚本

这里提供一个简单的 Python 脚本,调用 TVM Python 前端,实现 onnx 模型的编译过程。tvm 通过代码学习编程流程系列文章将基本采用这个脚本帮助追踪代码。

import onnx
from PIL import Image
import numpy as np
import tvm.relay as relay
import tvm
from tvm.contrib import graph_executor
######################################    路径信息    ##########################################
model_path = "/home/xianmu/module/resnet18.onnx"
save_path = "/home/xianmu/module/pythonSave/"
onnx_model = onnx.load(model_path)
##################################    图片信息    ##############################################
img_path = "/home/xianmu/.tvm_test_data/data/imagenet_cat.png"
# Resize it to 224x224
resized_image = Image.open(img_path).resize((224, 224))
img_data = np.asarray(resized_image).astype("float32")
# Our input image is in HWC layout while ONNX expects CHW input, so convert the array
img_data = np.transpose(img_data, (2, 0, 1))
# Normalize according to the ImageNet input specification
imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
imagenet_stddev = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
norm_img_data = (img_data / 255 - imagenet_mean) / imagenet_stddev
# Add the batch dimension, as we are expecting 4-dimensional input: NCHW.
img_data = np.expand_dims(norm_img_data, axis=0)
####################################     模型编译     ###########################################
input_name = "data"
target = tvm.target.Target(target="llvm", host="llvm")
shape_dict = {input_name: img_data.shape}
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, export_node_renamed_model_path=save_path)  # 创建 IRModule  高级 Relay IR
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)   # 创建 GraphExecutorFactoryModule  
#######################################     模型保存      ########################################
# save
# 保存编译后的库文件(.so)
lib_fname = save_path + "mod.so"
lib.get_lib().export_library(lib_fname)
# 保存模型参数(.params)
params_fname = save_path + "mod.params"
with open(params_fname, "wb") as param_file:
    param_file.write(relay.save_param_dict(lib.get_params()))
# 保存 JSON 格式的计算图(.json)
json_fname = save_path + "mod.json"
with open(json_fname, "w") as json_file:
    json_file.write(lib.get_executor_config())
dev = tvm.device(str(target), 0)
module = graph_executor.GraphModule(lib["default"](dev))    # graph_executor.GraphModule
############################       运行    ##########################################
module.set_input(input_name, img_data)
module.run()
output_shape = (1, 1000)
tvm_output = module.get_output(0, tvm.nd.empty(output_shape)).numpy()
print(tvm_output)

再回顾一下 BuildRelay 函数,文章 《【TVM】通过代码学习编译流程【4】BuildRelay》 已经介绍了 BuildRelay 总体流程和其子函数 OptimizeImpl 。下面将介绍其中的: MakeExecutorCodegenInitCodegen

void BuildRelay(IRModule relay_module, const String& mod_name) {
    // Relay IRModule -> IRModule optimizations.
    IRModule module = WithAttrs(  // 为 Relay IRModule 添加 Executor 和 Runtime 属性
        relay_module, <!--swig0-->);
    relay_module = OptimizeImpl(std::move(module));  // 执行多个针对 Relay IRModule 的优化 Pass
    // Get the updated function and new IRModule to build.
    Function func = Downcast<Function>(relay_module->Lookup("main"));  // 获取 Relay IRModule 中的 main 函数表达式
    IRModule func_module = WithAttrs(IRModule::FromExpr(func),  // 为 main 函数表达式添加属性信息
                                     <!--swig1-->);
    // Generate code for the updated function.
    executor_codegen_ = MakeExecutorCodegen(executor_->name);   // 构建代码生成 GraphCodegen
    executor_codegen_->Init(nullptr, config_->primitive_targets);  // Codegen 初始化
    executor_codegen_->Codegen(func_module, func, mod_name);   // 将 Relay IRModule 降级为 TIR Module    
    executor_codegen_->UpdateOutput(&ret_);  // 更新降级后的 json 图结构到 BuildOutput 结构体
    ret_.params = executor_codegen_->GetParams();   // 更新降级后的 params 到 BuildOutput 结构体
    auto lowered_funcs = executor_codegen_->GetIRModule(); // 获取降级后的 TIR Module
    // No need to build for external functions.
    Target ext_dev("ext_dev");
    if (lowered_funcs.find(ext_dev) != lowered_funcs.end()) {
      lowered_funcs.Set(ext_dev, IRModule());
    }
    const Target& host_target = config_->host_virtual_device->target;
    const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate");
    // When there is no lowered_funcs due to reasons such as optimization.
    if (lowered_funcs.size() == 0) {
      if (host_target->kind->name == "llvm") {
        CHECK(pf != nullptr) << "Unable to create empty module for llvm without llvm codegen.";
        // If we can decide the target is LLVM, we then create an empty LLVM module.
        ret_.mod = (*pf)(host_target->str(), "empty_module");
      } else {
        // If we cannot decide the target is LLVM, we create an empty CSourceModule.
        // The code content is initialized with ";" to prevent complaining
        // from CSourceModuleNode::SaveToFile.
        ret_.mod = tvm::codegen::CSourceModuleCreate(";", "", Array<String>{});
      }
    } else {
      ret_.mod = tvm::TIRToRuntime(lowered_funcs, host_target);  // TIR Module 转换为 runtime::Module
    }
    auto ext_mods = executor_codegen_->GetExternalModules();
    ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, host_target,
                                                  runtime_, executor_,
                                                  executor_codegen_->GetExecutorCodegenMetadata());
    // Remove external params which were stored in metadata module.
    for (tvm::runtime::Module mod : ext_mods) {
      auto pf_var = mod.GetFunction("get_const_vars");
      if (pf_var != nullptr) {
        Array<String> variables = pf_var();
        for (size_t i = 0; i < variables.size(); i++) {
          auto it = ret_.params.find(variables[i].operator std::string());
          if (it != ret_.params.end()) {
            VLOG(1) << "constant '" << variables[i] << "' has been captured in external module";
            ret_.params.erase(it);
          }
        }
      }
    }
  }

# MakeExecutorCodegen

BuildRelay 函数的总体过程如下图:

buildRelay流程图

函数 MakeExecutorCodegen() 用于创建 GraphCodegen 对象 executor_codegen_GraphCodegen 继承于 ExecutorCodegen .

ExecutorCodegen 有个成员 tvm::runtime::Module mod; ,因此 GraphCodegen 也包含这个成员 mod

如下代码所示, GraphCodegen 在初始化的时候,将成员 mod 赋值为 GetPackedFunc("relay.build_module._GraphExecutorCodegen") 获得的构建函数。

/*!
 * \brief GraphCodegen module wrapper
 
 */
struct GraphCodegen : ExecutorCodegen {  // 继承于 `ExecutorCodegen`
  GraphCodegen() {
    auto pf = GetPackedFunc("relay.build_module._GraphExecutorCodegen");  // 通过名字获取 mod 的构建函数
    mod = (*pf)();  // 将 mod 赋值为 GraphExecutorCodegenModule 对象
  }
  void UpdateOutput(BuildOutput* ret) override { ret->graph_json = GetGraphJSON(); }
  std::string GetGraphJSON() { return CallFunc<std::string>("get_graph_json", nullptr); }
  ~GraphCodegen() {}
};

GetPackedFunc("relay.build_module._GraphExecutorCodegen") 函数通过名字获取 TVM_REGISTER_GLOBAL 注册的全局函数 CreateGraphCodegenMod()

TVM_REGISTER_GLOBAL("relay.build_module._GraphExecutorCodegen")
    .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateGraphCodegenMod(); });
     }

CreateGraphCodegenMod() 函数创建对象 GraphExecutorCodegenModule 并返回。因此 GraphCodegen 对象的成员 mod 赋值为对象 GraphExecutorCodegenModule

而后续 GraphCodegen 的很多功能都会调用 mod 的功能,也就是 GraphExecutorCodegenModule 对象的功能。

runtime::Module CreateGraphCodegenMod() {
  auto ptr = make_object<GraphExecutorCodegenModule>();  // 创建对象 `GraphExecutorCodegenModule`
  return runtime::Module(ptr);  // 封装成 ptr
}

# Init

buildRelay流程图

Init() 函数完成 CodeGen 的初始化。 GraphCodegen 对象 executor_codegen_Init() 函数,首先调用父类 ExecutorCodegenInit() 函数,该函数又调用成员 mod 的初始化函数,即 GraphExecutorCodegenModule 的初始化函数,如下所示:

该函数将 GraphExecutorCodegenModule 的成员 std::shared_ptr<GraphExecutorCodegen> codegen_; 赋值为 GraphExecutorCodegen 对象。

virtual PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) {
    if (name == "init") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: "
                                    << "runtime::Module mod and Array<Target> targets";
        void* mod = args[0];
        Array<Target> targets = args[1];
        codegen_ = std::make_shared<GraphExecutorCodegen>(reinterpret_cast<runtime::Module*>(mod),
                                                          std::move(targets));
      });
    }

# CodeGen

buildRelay流程图

Codegen(func_module, func, mod_name) —— 将 Relay IRModule 降级为 TIR Module。

这是一个很复杂的过程,本篇只讲解部分。

Codegen() 函数调用的也是成员 mod ,也就是 GraphExecutorCodegenModuleCodegen() 函数,这里不在赘述。下面是 GraphExecutorCodegenModuleCodegen() 函数。

显然, codegen 又调用了 GraphExecutorCodegenModule 成员 codegen_Codegen 函数,即 GraphExecutorCodegen 类的 Codegen 函数。

else if (name == "codegen") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        IRModule mod = args[0];
        Function func = args[1];
        String mod_name = args[2];
        this->output_ = this->codegen_->Codegen(mod, func, mod_name); // 调用了成员 codegen_的 Codegen 函数
      });

GraphExecutorCodegen 类的 Codegen 函数如下所示:

该函数实现 Relay IRModule 降级为 TIR Module,并返回包含所有信息的结构体 LoweredOutput

其中关键函数 tec::LowerTE() 完成降级过程。下文代码中做了简单的注释。

LoweredOutput Codegen(IRModule mod, relay::Function func, String mod_name) {
    mod_name_ = mod_name;
    VLOG_CONTEXT << "GraphExecutorCodegen";
    VLOG(1) << "compiling:" << std::endl << PrettyPrint(func);
    // TODO(mbs): Why plan memory and update workspace sizes before lowering?
    memory_plan_ = GraphPlanMemory(func);  // 为模型生成内存分配策略,复用内存,但似乎不需要在模型降级之前进行这一步
    backend::FunctionInfo func_info;
    if (memory_plan_.defined()) {
      // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
      func_info =
          relay::tec::UpdateMainWorkspaceSize(mod, config_, memory_plan_->expr_to_storage_info);
      mod = WithAttr(mod, "main_func_info", func_info);
    }
    IRModule lowered_mod = tec::LowerTE(mod_name_, config_, [this](BaseFunc func) {  
      // We need to maintain the constant map for external
      // functions so we pass this processing function which
      // allows us to process each function as we lower it.
      if (func->GetAttr<String>(attr::kCompiler).defined()) {
        UpdateConstants(func, &params_);
      }
      tec::UpdateFunctionMetadata(func, this->function_metadata_);
    })(mod);  // 模型降级 relay IRmodule 转为 TIR Module
    Optional<backend::FunctionInfo> main_func_info =
        lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");  
    function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());
    Function lowered_main_func = Downcast<Function>(lowered_mod->Lookup("main"));
    // Now that we have lowered all operators to TIR code, we can proceed with compilation.
    //
    // We need to unfortunately re-plan as the previous results have been invalidated by lowering
    // we will fix this in future refactors.
    memory_plan_ = GraphPlanMemory(lowered_main_func);  // 再次生成内存分配策略  未来可能不会重复进行内存分配
    // The graph planner also can not handle planning calls to global variables to we must remap
    // First we convert all the parameters into input nodes.
    for (auto param : lowered_main_func->params) {
      auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs());
      var_map_[param.get()] = AddNode(node_ptr, param);
    }
    heads_ = VisitExpr(lowered_main_func->body);
    std::ostringstream os;
    dmlc::JSONWriter writer(&os);
    GetJSON(&writer);    // 将图结构写为 json
    LoweredOutput ret;  // LoweredOutput 用于收集降级后的所有信息
    ret.graph_json = os.str();  //json 图结构赋值给 ret
    // Collect any runtime modules generated by external codegen.
    ret.external_mods =  // 外部编译器模型赋值给 ret
        lowered_mod->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods).value_or({});
    // Collect any constants extracted by external codegen.
    ret.params = std::unordered_map<std::string, tvm::runtime::NDArray>();
    Map<String, runtime::NDArray> const_name_to_constant =
        lowered_mod->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToConstant)
            .value_or({});
    for (const auto& kv : const_name_to_constant) {
      VLOG(1) << "constant '" << kv.first << "' contributed by external codegen";
      ICHECK(ret.params.emplace(kv.first, kv.second).second);
    }
    // Collect any constants extracted during lowering.
    for (const auto& kv : params_) {
      VLOG(1) << "constant '" << kv.first << "' contributed by TECompiler";
      ICHECK(ret.params.emplace(kv.first, kv.second).second);
    }
    ret.function_metadata = std::move(function_metadata_);  // 函数元数据赋值给 ret
    // This is the point where we separate the functions in the module by target
    ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
    ret.metadata =  // 模型元数据赋值给 ret
        ExecutorCodegenMetadata({} /* inputs */, {} /* input_tensor_types */, {} /* outputs */,
                                {} /* output_tensor_types */, {} /* pools */, {} /* devices */,
                                runtime::kTvmExecutorGraph /* executor */, mod_name_ /* mod_name */,
                                "packed" /* interface_api */, Bool(false) /* unpacked_api */);
    return ret;  // 返回包含所有信息的结构体 `LoweredOutput`
  }

再看一下 tec::LowerTE() 函数。其第三个参数 [this](BaseFunc func) {...} 是一个 lambda 表达式,功能和 “外部函数” 相关,外部函数指的是 “标明使用外部编译器编译的 function”,如 “dnnl,ccompilmer” 等,现在暂不关注。

IRModule lowered_mod = tec::LowerTE(mod_name_, config_, [this](BaseFunc func) {  
      // We need to maintain the constant map for external
      // functions so we pass this processing function which
      // allows us to process each function as we lower it.
      if (func->GetAttr<String>(attr::kCompiler).defined()) {
        UpdateConstants(func, &params_);
      }
      tec::UpdateFunctionMetadata(func, this->function_metadata_);
    })(mod);  // 模型降级 relay IRmodule 转为 TIR Module

tec::LowerTE() 函数返回的是一个 Sequential 类,其包含多个按照顺序执行的 Pass。如果对 Pass 还不了解或者遗忘了,可以再回顾一下《【TVM】通过代码学习类【3.5】Pass》

返回的 Sequential 类不仅包含 RelayToTIRTargetHookExtractPrimFuncConstantsInferType() 三个函数获得的 Pass,还包含 CreateModulePass(pass_func, 0, "LowerTE", {"InferType"}) 封装成的 Pass。

CreateModulePass(pass_func, 0, "LowerTE", {"InferType"})pass_func 封装成 Pass。 pass_func 是 lambda 表达式,其调用了含有四个参数的函数 LowerTE(module, module_name, process_fn, complilation_config); ,完成了降级的主要内容。注意该 LowerTE() 含有四个参数,非前述提到的含有三个参数的 tec::LowerTE()

tec::LowerTE() 函数返回 Sequential 类之后立即执行,完成对 Relay Module 的降级。

Pass LowerTE(String module_name, CompilationConfig complilation_config, ProcessFn process_fn) {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module,  //lambda 表达式
                                                                            PassContext ctx) {
    return LowerTE(module, module_name, process_fn, complilation_config);
  };
  return tvm::transform::Sequential(
      {tvm::relay::transform::RelayToTIRTargetHook(complilation_config),
       tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {"InferType"}), InferType(),
       tvm::tir::transform::ExtractPrimFuncConstants()});
}

后续文章将按照顺序介绍 RelayToTIRTargetHookLowerTE()InferType()ExtractPrimFuncConstants

# 后记

本博客目前以及可预期的将来都不会支持评论功能。各位大侠如若有指教和问题,可以在我的 github 项目 或随便一个项目下提出 issue,或者知乎 私信,并指明哪一篇博客,我看到一定及时回复。

Edited on

Give me a cup of [coffee]~( ̄▽ ̄)~*

XianMu WeChat Pay

WeChat Pay

XianMu Alipay

Alipay