# 前言

文章 《【TVM】通过代码学习编译流程》系列 主要介绍 TVM 在模型编译过程的流程,有时候感觉缺少了对类及其属性和方法的介绍。所以决定在系列文章的中间插入一些 “类的结构及其属性方法” 的介绍。

本篇文章主要介绍 Pass 及其相关类。

作为初学者,错误在所难免,还望不吝赐教。

# Pass

可以再回顾一下在《【TVM】通过代码学习编译流程【4】》中讲到的本体、桥梁、指针的关系。

先看一看 Pass 的基类, 位于 include/tvm/ir/transform.h 。 Pass 本体 PassNode 。内容很少,主要就是 Pass 的执行函数: IRModule operator()(IRModule mod) 函数重载了 “()” 运算符。里面调用自身含有两个参数的 "()" 重载函数。

含有两个参数的 "()" 重载函数 virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0; 是个虚函数,这意味着 PassNode 的派生类需要重写该函数,实现 Pass 的实际功能。

class PassNode : public Object {
 public:
  virtual ~PassNode() {}
  /*!
   * \brief Get the pass information/meta data. */
  virtual PassInfo Info() const = 0;
  IRModule operator()(IRModule mod) const {  // 重载了 “()” 运算符
    return this->operator()(std::move(mod), PassContext::Current());  // 调用含有两个参数的 "()" 重载函数
  }
  virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0; // 虚函数 由派生类重写
  void VisitAttrs(AttrVisitor* v) {}
  static constexpr const char* _type_key = "transform.Pass";
  TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object);
};

Pass 指针 Pass ,指向 PassNode 本体。相当于给本体套了个壳子。
壳子中的 IRModule operator()(IRModule mod) const; 函数同样是调用自身含有两个参数的 "()" 重载函数。
含有两个参数的 "()" 重载函数 IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) 调用的是本体 PassNode 的功能。

class Pass : public ObjectRef {
 public:
  
  IRModule operator()(IRModule mod) const;
  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const;
  TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode);
 private:
  IRModule static AssertImmutableModule(const IRModule& mod, const PassNode* node,
                                        const PassContext& pass_ctx);
};
 
IRModule Pass::operator()(IRModule mod) const {  // 调用自身含有两个参数的 "()" 重载函数
  return this->operator()(std::move(mod), PassContext::Current());
}
IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const {  // 调用的是本体 `PassNode` 的功能
  const PassNode* node = operator->();
  ICHECK(node != nullptr);
  const PassInfo& pass_info = node->Info();
  if (!pass_ctx.InstrumentBeforePass(mod, pass_info)) {
    DLOG(INFO) << "Skipping pass : " << pass_info->name
               << " with opt level: " << pass_info->opt_level;
    return mod;
  }
  IRModule ret;
  if (pass_ctx->GetConfig<Bool>("testing.immutable_module", Bool(false)).value()) {
    ret = Pass::AssertImmutableModule(mod, node, pass_ctx);
  } else {
    ret = node->operator()(std::move(mod), pass_ctx);
  }
  pass_ctx.InstrumentAfterPass(ret, pass_info);
  return std::move(ret);
}

所以总结来说,Pass 修改模型的功能由 Pass 的派生类重载的 Pass::operator()(IRModule mod, const PassContext& pass_ctx) 函数实现。那么它有哪些派生类呢?后文提供了三个派生类,分别是 FunctionPassSequentialModulePass 。他们有不同的功能作用。

# FunctionPass

FunctionPassNode :: PassNode

Function-level Pass 的实现类,该类是 Pass 的派生类。接收 Module 中函数表达式列表中的一个 function 进行优化。
pass_func 具体实现 function 优化的函数:由外部提供,以 function 为输入,如 Pass DefuseOps()FoldConstant() 等函数提供他们各自的 pass_func ,以实现不同的功能。

Pass::operator()(IRModule mod, const PassContext& pass_ctx) 函数, FunctionPass 对该函数的实现也在下方。

  • 先遍历模型中的 function
  • AsOptimizableFunctionNode() 函数 :过滤掉不能被优化的 function,如 kCompiler (指定编译器的),kExtern (外部编译器的),kSkipOptimization (指明跳过的)
  • 调用 pass_func 优化 function
class FunctionPassNode : public PassNode {
 public:
  PassInfo pass_info;
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func;  // 具体实现 function 优化的函数:由外部提供,以 function 为输入
  FunctionPassNode() = default;
  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); }
  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;  // 不再被派生类重写  具体实现在下方
  PassInfo Info() const override { return pass_info; }
  static constexpr const char* _type_key = "relay.FunctionPass";
  TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode);
};
// Perform Module -> Module optimizations at the Function level.
IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {  // 具体实现
  DiagnosticContext previous = DiagnosticContext::Default(mod);
  IRModule updated_mod = mod->ShallowCopy();
  std::vector<std::pair<GlobalVar, Function>> updates;
  for (const auto& kv : mod->functions) {  //  遍历模型中的 function
    // only process optimizable Relay Functions
    if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {  // 过滤掉不能被优化的 function,如 kCompiler (指定编译器的),kExtern (外部编译器的),kSkipOptimization (指明跳过的)
      Function updated_func = pass_func(GetRef<Function>(function_node), updated_mod, pass_ctx);  // 调用 pass_func 优化 function
      updates.push_back({kv.first, std::move(updated_func)});
    }
  }
  return transform::InferType()(updated_mod);
}

# ModulePass

ModulePassNode :: PassNode

Module-level Pass 的实现类, FunctionPass 优化的是 Relay Module 包含的多个 function ,作用于 function 内部,不能实现 function 增删; ModulePassNode 优化的是整个 Module,能够实现 function 增删等 Module 范围的优化。
pass_func 具体实现 Module 优化的函数:由外部提供,以 Module 为输入
Pass::operator()(IRModule mod, const PassContext& pass_ctx) 函数,实现在下方。

  • 调用 pass_func 优化 Module
/*!
 * \brief Module-level passes are designed to implement global
 * analysis/optimizations, i.e. interprocedural optimizations (IPO), etc. Passes
 * at this level have the full control of a given Relay program including
 * addition and deletion of functions.
 */
class ModulePassNode : public PassNode {
 public:
  /* \brief The pass meta data.*/
  PassInfo pass_info;
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func;  // 具体实现 Module 优化的函数:由外部提供,以 Module 为输入
  ModulePassNode() = default;
  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); }
  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;  // 不再允许派生类修改
  /*!
   * \brief Get the pass information/meta data.
   */
  PassInfo Info() const override { return pass_info; }
  static constexpr const char* _type_key = "transform.ModulePass";
  TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode);
};
// Module -> Module optimizations.
IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
  DiagnosticContext previous = DiagnosticContext::Default(mod);
  const PassInfo& pass_info = Info();
  mod = pass_func(std::move(mod), pass_ctx);  // 调用 pass_func 优化 Module
  pass_ctx->diag_ctx.value().Render();
  pass_ctx->diag_ctx = previous;
  return mod;
}

# Sequential

Sequential :Sequential 类包含多个按照顺序执行的 Pass,类似于 pytorch 里面的 nn.Sequential

  • tvm::Array<Pass> passes :数组,包含多个 Pass,如前面提到的 FunctionPassModulePass
  • Pass::operator()(IRModule mod, const PassContext& pass_ctx) 函数,实现在下方。
    • 遍历所有包含的 pass
    • 调用 Pass 执行对模型的优化
/*!
 * \brief The SequentialNode contains a set of passes that transform Relay/Relax
 * programs from one AST to another semantically equivalent one.
 *
 * One example of this level of pass is that the pass manager needs to correctly
 * perform a host of optimizations with a given optimization level and disabled
 * passes.
 */
class SequentialNode : public PassNode {
 public:
  /* \brief The pass meta data.*/
  PassInfo pass_info;
  /*! \brief A list of passes that used to compose a sequential pass. */
  tvm::Array<Pass> passes;  // 数组,包含多个 Pass,如前面提到的 `FunctionPass`, `ModulePass`
  PassInfo Info() const override { return pass_info; }
  void ResolveDependency(const IRModule& mod);
  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
  static constexpr const char* _type_key = "transform.Sequential";
  TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
};
IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
  for (const Pass& pass : passes) {  // 遍历所有包含的 pass
    
    // resolve dependencies
    for (const auto& it : pass_info->required) {
      mod = GetPass(it)(std::move(mod), pass_ctx);
    }
  
    if (pass_ctx->trace_stack.size() && !pass_info->traceable &&
        (!pass_ctx->make_traceable.defined() ||
         pass_ctx->make_traceable.value().count(pass_info->name))) {
      // In the future, we should pass the ffi key for a pass by deducing from its name.
      String transform_func_key = "relax.tuning_api.Choice.default_transform_func";
      String constr_func_key = "relax.tuning_api.Choice.default_constr_func";
      relax::Knob knob = relax::Knob(
          pass_info->name, <!--swig0-->);
      // Add new decision to the trace at the top of the stack.
      auto trace = Downcast<relax::Trace>(pass_ctx->trace_stack.back());
      trace->Add(knob, "Applied");
      mod = pass(std::move(mod), pass_ctx);  // 调用 Pass 执行对模型的优化
      trace->SetOutMod(mod);
    } else {
      mod = pass(std::move(mod), pass_ctx);  // 调用 Pass 执行对模型的优化
    }
  }
  return mod;
}

# 后记

本博客目前以及可预期的将来都不会支持评论功能。各位大侠如若有指教和问题,可以在我的 github 项目 或随便一个项目下提出 issue,或者知乎 私信,并指明哪一篇博客,我看到一定及时回复。

Edited on

Give me a cup of [coffee]~( ̄▽ ̄)~*

XianMu WeChat Pay

WeChat Pay

XianMu Alipay

Alipay