查看原文
其他

深入理解TVM:Python/C++互调(中)

月踏 知知爸爸是码农 2022-06-13

一、概述

前文已经讲过python/c++互调的c++端的底层核心数据结构:PackedFunc,详情请见:深入理解TVM:Python/C++互调(上)本文是python/c++互调这个系列的第二篇,主要来讲c++端的函数注册,python端对c++端的函数调用都来源于c++端的注册函数,最主要的一个函数注册宏是TVM_REGISTER_GLOBAL,code base里大概用了1300多次,除了这个注册宏,TVM里还有许多其它的注册宏,这里不一一细说,以后捡有代表性的放到注册机调研这个系列里来说。

二、先看几个注册实例

注册的函数可以是普通函数,也可以是labda表达式,注册接口有三个:set_body、set_body_typed、set_body_method,第一个使用的是PackedFunc,后面两个使用的是TypedPackedFunc,PackedFunc在这个系列的上篇讲过了,TypedPackedFunc是PackedFunc的一个wrapper,实现比较复杂,以后有时间再细说这部分,下面举三个简单示例来展示下这三个注册接口的使用。

使用set_body接口注册lambda表达式:

// src/topi/nn.ccTVM_REGISTER_GLOBAL("topi.nn.relu") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = relu<float>(args[0]);});

使用set_body_typed接口注册lambda表达式:

// src/te/schedule/graph.ccTVM_REGISTER_GLOBAL("schedule.PostDFSOrder")    .set_body_typed([](     const Array<Operation>& roots,      const ReadGraph& g) { return PostDFSOrder(roots, g); });

使用set_body_method接口注册类内函数:

// src/ir/module.ccTVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar") .set_body_method<IRModule>(&IRModuleNode::GetGlobalVar);

三、TVM_REGISTER_GLOBAL宏定义

这个宏定义的本质就是在注册文件定义了一个static的引用变量,引用到注册机内部new出来的一个新的Registry对象:

// include/tvm/runtime/registry.h#define TVM_REGISTER_GLOBAL(OpName)               \  static ::tvm::runtime::Registry& __mk_TVMxxx =  \ ::tvm::runtime::Registry::Register(OpName)

上面的xxx其实是__COUNTER__这个编译器拓展宏生成的一个唯一标识符,GCC文档里对这个宏有详细的描述(https://gcc.gnu.org/onlinedocs/cpp/Common-Predefined-Macros.html):

This macro expands to sequential integral values starting from 0. In conjunction with the ## operator, this provides a convenient means to generate unique identifiers. Care must be taken to ensure that __COUNTER__ is not expanded prior to inclusion of precompiled headers which use it. Otherwise, the precompiled headers will not be used.  

四、Registry::Manager

先来看最核心的Manager类,它是Registry的内部类,用来存储注册的对象,先看下代码:

// src/runtime/registry.ccstruct Registry::Manager { static Manager* Global() { static Manager* inst = new Manager(); return inst;  }  std::mutex mutex; unordered_map<std::string, Registry*> fmap;};

这个数据结构很简单,从上面代码能得到下面几点信息:

  • 数据结构里面带锁,可以保证线程安全

  • Manager是个单例,限制类的实例化对象个数是一种技术,可以限制实例化对象个数为0个、1个、N个,具体可参照《More Effective C++:35个改善编程与设计的有效方法》的条款26:限制某个 class 所能产生的对象数量这一章节

  • 使用unordered_map来存储注册信息,注册对象是Registry指针

五、Registry

这才是注册机的核心数据结构,简化过的代码如下(只保留了关键的数据结构和接口,原文使用了大量的模板、泛型等c++用法):

// include/tvm/runtime/registry.hclass Registry { public:  Registry& set_body(PackedFunc f);  Registry& set_body_typed(FLambda f);  Registry& set_body_method(R (T::*f)(Args...));
  static Registry& Register(const std::string& name);  static const PackedFunc* Get(const std::string& name); static std::vector ListNames();
protected: std::string name_; PackedFunc func_; friend struct Manager;};

Registry的功能可以为三部分,相关的实现代码也比较简单,总结如下:

  • 设置注册函数的set_body系列接口,使用Registry的一系列set_body方法,可以把PackedFunc类型的函数对象设置到Registry对象中

  • 创建Registry对象的Register静态接口,参照下面代码:

Registry& Registry::Register(const std::string& name) { Manager* m = Manager::Global();  std::lock_guard<std::mutex> lock(m->mutex);
Registry* r = new Registry(); r->name_ = name; m->fmap[name] = r; return *r;}
  • 获取注册函数的Get静态接口,代码如下:

const PackedFunc* Registry::Get(const std::string& name) { Manager* m = Manager::Global(); std::lock_guard<std::mutex> lock(m->mutex); auto it = m->fmap.find(name); if (it == m->fmap.end()) return nullptr; return &(it->second->func_);}

六、最后

这一篇的内容相对简单,但是对于python/c++的互调却至关重要,而且注册机也是一个被所有深度学习框架、编译器都会用到的技术,很有必要了解清楚,对于注册机我也准备单独写一个系列,前面已经写了一篇注册机调研:从Caffe开始,后面还会继续完善这个系列。

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存