深入理解TVM:Python/C++互调(中)
一、概述
二、先看几个注册实例
注册的函数可以是普通函数,也可以是labda表达式,注册接口有三个:set_body、set_body_typed、set_body_method,第一个使用的是PackedFunc,后面两个使用的是TypedPackedFunc,PackedFunc在这个系列的上篇讲过了,TypedPackedFunc是PackedFunc的一个wrapper,实现比较复杂,以后有时间再细说这部分,下面举三个简单示例来展示下这三个注册接口的使用。
使用set_body接口注册lambda表达式:
// src/topi/nn.cc
TVM_REGISTER_GLOBAL("topi.nn.relu")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = relu<float>(args[0]);
});
使用set_body_typed接口注册lambda表达式:
// src/te/schedule/graph.cc
TVM_REGISTER_GLOBAL("schedule.PostDFSOrder")
.set_body_typed([](
const Array<Operation>& roots,
const ReadGraph& g) {
return PostDFSOrder(roots, g);
});
使用set_body_method接口注册类内函数:
// src/ir/module.cc
TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar")
.set_body_method<IRModule>(&IRModuleNode::GetGlobalVar);
三、TVM_REGISTER_GLOBAL宏定义
这个宏定义的本质就是在注册文件定义了一个static的引用变量,引用到注册机内部new出来的一个新的Registry对象:
// include/tvm/runtime/registry.h
#define TVM_REGISTER_GLOBAL(OpName) \
static ::tvm::runtime::Registry& __mk_TVMxxx = \
::tvm::runtime::Registry::Register(OpName)
上面的xxx其实是__COUNTER__这个编译器拓展宏生成的一个唯一标识符,GCC文档里对这个宏有详细的描述(https://gcc.gnu.org/onlinedocs/cpp/Common-Predefined-Macros.html):
This macro expands to sequential integral values starting from 0. In conjunction with the ## operator, this provides a convenient means to generate unique identifiers. Care must be taken to ensure that __COUNTER__ is not expanded prior to inclusion of precompiled headers which use it. Otherwise, the precompiled headers will not be used.
四、Registry::Manager
先来看最核心的Manager类,它是Registry的内部类,用来存储注册的对象,先看下代码:
// src/runtime/registry.cc
struct Registry::Manager {
static Manager* Global() {
static Manager* inst = new Manager();
return inst;
}
std::mutex mutex;
unordered_map<std::string, Registry*> fmap;
};
这个数据结构很简单,从上面代码能得到下面几点信息:
数据结构里面带锁,可以保证线程安全
Manager是个单例,限制类的实例化对象个数是一种技术,可以限制实例化对象个数为0个、1个、N个,具体可参照《More Effective C++:35个改善编程与设计的有效方法》的条款26:限制某个 class 所能产生的对象数量这一章节
使用unordered_map来存储注册信息,注册对象是Registry指针
五、Registry
这才是注册机的核心数据结构,简化过的代码如下(只保留了关键的数据结构和接口,原文使用了大量的模板、泛型等c++用法):
// include/tvm/runtime/registry.h
class Registry {
public:
Registry& set_body(PackedFunc f);
Registry& set_body_typed(FLambda f);
Registry& set_body_method(R (T::*f)(Args...));
static Registry& Register(const std::string& name);
static const PackedFunc* Get(const std::string& name);
static std::vector ListNames();
protected:
std::string name_;
PackedFunc func_;
friend struct Manager;
};
Registry的功能可以为三部分,相关的实现代码也比较简单,总结如下:
设置注册函数的set_body系列接口,使用Registry的一系列set_body方法,可以把PackedFunc类型的函数对象设置到Registry对象中
创建Registry对象的Register静态接口,参照下面代码:
Registry& Registry::Register(const std::string& name) {
Manager* m = Manager::Global();
std::lock_guard<std::mutex> lock(m->mutex);
Registry* r = new Registry();
r->name_ = name;
m->fmap[name] = r;
return *r;
}
获取注册函数的Get静态接口,代码如下:
const PackedFunc* Registry::Get(const std::string& name) {
Manager* m = Manager::Global();
std::lock_guard<std::mutex> lock(m->mutex);
auto it = m->fmap.find(name);
if (it == m->fmap.end()) return nullptr;
return &(it->second->func_);
}
六、最后
这一篇的内容相对简单,但是对于python/c++的互调却至关重要,而且注册机也是一个被所有深度学习框架、编译器都会用到的技术,很有必要了解清楚,对于注册机我也准备单独写一个系列,前面已经写了一篇注册机调研:从Caffe开始,后面还会继续完善这个系列。