# 前言
文章 《【TVM】通过代码学习编译流程》系列 主要介绍 TVM 在模型编译过程的流程,有时候感觉缺少了对类及其属性和方法的介绍。所以决定在系列文章的中间插入一些 “类的结构及其属性方法” 的介绍。
本篇文章主要介绍 Pass 及其相关类。
作为初学者,错误在所难免,还望不吝赐教。
# Pass
可以再回顾一下在《【TVM】通过代码学习编译流程【4】》中讲到的本体、桥梁、指针的关系。
先看一看 Pass 的基类, 位于 include/tvm/ir/transform.h
。 Pass 本体 PassNode
。内容很少,主要就是 Pass 的执行函数: IRModule operator()(IRModule mod)
函数重载了 “()” 运算符。里面调用自身含有两个参数的 "()" 重载函数。
含有两个参数的 "()" 重载函数 virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0;
是个虚函数,这意味着 PassNode
的派生类需要重写该函数,实现 Pass 的实际功能。
class PassNode : public Object { | |
public: | |
virtual ~PassNode() {} | |
/*! | |
* \brief Get the pass information/meta data. */ | |
virtual PassInfo Info() const = 0; | |
IRModule operator()(IRModule mod) const { // 重载了 “()” 运算符 | |
return this->operator()(std::move(mod), PassContext::Current()); // 调用含有两个参数的 "()" 重载函数 | |
} | |
virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0; // 虚函数 由派生类重写 | |
void VisitAttrs(AttrVisitor* v) {} | |
static constexpr const char* _type_key = "transform.Pass"; | |
TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object); | |
}; |
Pass 指针 Pass
,指向 PassNode
本体。相当于给本体套了个壳子。
壳子中的 IRModule operator()(IRModule mod) const;
函数同样是调用自身含有两个参数的 "()" 重载函数。
含有两个参数的 "()" 重载函数 IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx)
调用的是本体 PassNode
的功能。
class Pass : public ObjectRef { | |
public: | |
IRModule operator()(IRModule mod) const; | |
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const; | |
TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode); | |
private: | |
IRModule static AssertImmutableModule(const IRModule& mod, const PassNode* node, | |
const PassContext& pass_ctx); | |
}; | |
IRModule Pass::operator()(IRModule mod) const { // 调用自身含有两个参数的 "()" 重载函数 | |
return this->operator()(std::move(mod), PassContext::Current()); | |
} | |
IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const { // 调用的是本体 `PassNode` 的功能 | |
const PassNode* node = operator->(); | |
ICHECK(node != nullptr); | |
const PassInfo& pass_info = node->Info(); | |
if (!pass_ctx.InstrumentBeforePass(mod, pass_info)) { | |
DLOG(INFO) << "Skipping pass : " << pass_info->name | |
<< " with opt level: " << pass_info->opt_level; | |
return mod; | |
} | |
IRModule ret; | |
if (pass_ctx->GetConfig<Bool>("testing.immutable_module", Bool(false)).value()) { | |
ret = Pass::AssertImmutableModule(mod, node, pass_ctx); | |
} else { | |
ret = node->operator()(std::move(mod), pass_ctx); | |
} | |
pass_ctx.InstrumentAfterPass(ret, pass_info); | |
return std::move(ret); | |
} |
所以总结来说,Pass 修改模型的功能由 Pass 的派生类重载的 Pass::operator()(IRModule mod, const PassContext& pass_ctx)
函数实现。那么它有哪些派生类呢?后文提供了三个派生类,分别是 FunctionPass
, Sequential
, ModulePass
。他们有不同的功能作用。
# FunctionPass
FunctionPassNode :: PassNode
Function-level Pass 的实现类,该类是 Pass 的派生类。接收 Module 中函数表达式列表中的一个 function 进行优化。
pass_func
具体实现 function 优化的函数:由外部提供,以 function 为输入,如 Pass DefuseOps()
, FoldConstant()
等函数提供他们各自的 pass_func
,以实现不同的功能。
Pass::operator()(IRModule mod, const PassContext& pass_ctx)
函数, FunctionPass
对该函数的实现也在下方。
- 先遍历模型中的 function
AsOptimizableFunctionNode()
函数 :过滤掉不能被优化的 function,如 kCompiler (指定编译器的),kExtern (外部编译器的),kSkipOptimization (指明跳过的)- 调用 pass_func 优化 function
class FunctionPassNode : public PassNode { | |
public: | |
PassInfo pass_info; | |
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func; // 具体实现 function 优化的函数:由外部提供,以 function 为输入 | |
FunctionPassNode() = default; | |
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } | |
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; // 不再被派生类重写 具体实现在下方 | |
PassInfo Info() const override { return pass_info; } | |
static constexpr const char* _type_key = "relay.FunctionPass"; | |
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode); | |
}; |
// Perform Module -> Module optimizations at the Function level. | |
IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { // 具体实现 | |
DiagnosticContext previous = DiagnosticContext::Default(mod); | |
IRModule updated_mod = mod->ShallowCopy(); | |
std::vector<std::pair<GlobalVar, Function>> updates; | |
for (const auto& kv : mod->functions) { // 遍历模型中的 function | |
// only process optimizable Relay Functions | |
if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { // 过滤掉不能被优化的 function,如 kCompiler (指定编译器的),kExtern (外部编译器的),kSkipOptimization (指明跳过的) | |
Function updated_func = pass_func(GetRef<Function>(function_node), updated_mod, pass_ctx); // 调用 pass_func 优化 function | |
updates.push_back({kv.first, std::move(updated_func)}); | |
} | |
} | |
return transform::InferType()(updated_mod); | |
} |
# ModulePass
ModulePassNode :: PassNode
Module-level Pass 的实现类, FunctionPass
优化的是 Relay Module 包含的多个 function
,作用于 function
内部,不能实现 function
增删; ModulePassNode
优化的是整个 Module,能够实现 function
增删等 Module 范围的优化。
pass_func
具体实现 Module 优化的函数:由外部提供,以 Module 为输入
Pass::operator()(IRModule mod, const PassContext& pass_ctx)
函数,实现在下方。
- 调用 pass_func 优化 Module
/*! | |
* \brief Module-level passes are designed to implement global | |
* analysis/optimizations, i.e. interprocedural optimizations (IPO), etc. Passes | |
* at this level have the full control of a given Relay program including | |
* addition and deletion of functions. | |
*/ | |
class ModulePassNode : public PassNode { | |
public: | |
/* \brief The pass meta data.*/ | |
PassInfo pass_info; | |
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func; // 具体实现 Module 优化的函数:由外部提供,以 Module 为输入 | |
ModulePassNode() = default; | |
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } | |
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; // 不再允许派生类修改 | |
/*! | |
* \brief Get the pass information/meta data. | |
*/ | |
PassInfo Info() const override { return pass_info; } | |
static constexpr const char* _type_key = "transform.ModulePass"; | |
TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode); | |
}; |
// Module -> Module optimizations. | |
IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { | |
DiagnosticContext previous = DiagnosticContext::Default(mod); | |
const PassInfo& pass_info = Info(); | |
mod = pass_func(std::move(mod), pass_ctx); // 调用 pass_func 优化 Module | |
pass_ctx->diag_ctx.value().Render(); | |
pass_ctx->diag_ctx = previous; | |
return mod; | |
} |
# Sequential
Sequential :Sequential 类包含多个按照顺序执行的 Pass,类似于 pytorch 里面的 nn.Sequential
tvm::Array<Pass> passes
:数组,包含多个 Pass,如前面提到的FunctionPass
,ModulePass
Pass::operator()(IRModule mod, const PassContext& pass_ctx)
函数,实现在下方。- 遍历所有包含的 pass
- 调用 Pass 执行对模型的优化
/*! | |
* \brief The SequentialNode contains a set of passes that transform Relay/Relax | |
* programs from one AST to another semantically equivalent one. | |
* | |
* One example of this level of pass is that the pass manager needs to correctly | |
* perform a host of optimizations with a given optimization level and disabled | |
* passes. | |
*/ | |
class SequentialNode : public PassNode { | |
public: | |
/* \brief The pass meta data.*/ | |
PassInfo pass_info; | |
/*! \brief A list of passes that used to compose a sequential pass. */ | |
tvm::Array<Pass> passes; // 数组,包含多个 Pass,如前面提到的 `FunctionPass`, `ModulePass` | |
PassInfo Info() const override { return pass_info; } | |
void ResolveDependency(const IRModule& mod); | |
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; | |
static constexpr const char* _type_key = "transform.Sequential"; | |
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode); | |
}; |
IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const { | |
for (const Pass& pass : passes) { // 遍历所有包含的 pass | |
// resolve dependencies | |
for (const auto& it : pass_info->required) { | |
mod = GetPass(it)(std::move(mod), pass_ctx); | |
} | |
if (pass_ctx->trace_stack.size() && !pass_info->traceable && | |
(!pass_ctx->make_traceable.defined() || | |
pass_ctx->make_traceable.value().count(pass_info->name))) { | |
// In the future, we should pass the ffi key for a pass by deducing from its name. | |
String transform_func_key = "relax.tuning_api.Choice.default_transform_func"; | |
String constr_func_key = "relax.tuning_api.Choice.default_constr_func"; | |
relax::Knob knob = relax::Knob( | |
pass_info->name, <!--swig0-->); | |
// Add new decision to the trace at the top of the stack. | |
auto trace = Downcast<relax::Trace>(pass_ctx->trace_stack.back()); | |
trace->Add(knob, "Applied"); | |
mod = pass(std::move(mod), pass_ctx); // 调用 Pass 执行对模型的优化 | |
trace->SetOutMod(mod); | |
} else { | |
mod = pass(std::move(mod), pass_ctx); // 调用 Pass 执行对模型的优化 | |
} | |
} | |
return mod; | |
} |
# 后记
本博客目前以及可预期的将来都不会支持评论功能。各位大侠如若有指教和问题,可以在我的 github 项目 或随便一个项目下提出 issue,或者知乎 私信,并指明哪一篇博客,我看到一定及时回复。