查看原文
其他

深入理解TVM:Python/C++互调(上)

月踏 知知爸爸是码农 2022-06-13
一、概述

TVM已经是一个很庞大的系统,包含了很多的功能模块,其中python和c++的互相调用这个功能模块,没有使用第三方的开源库(boost.python、pybind11等),而是自己实现了一套复杂但精致高效强大的机制,值得好好研究学习。这部分内容很多,一篇文章很难说清楚,我准备把这部分分成上、中、下三篇来说,尽可能的把实现原理讲清楚:

  1. 上篇:最底层的c++数据结构支撑(围绕c++端PackedFunc)
  2. 中篇:基于PackedFunc的函数注册(围绕TVM_REGISTER_GLOBAL)
  3. 下篇:偏上层的python的调用细节(围绕ctypes内置库和python端PackedFunc)

本文讲第一部分,也就是围绕PackedFunc这个类来说,它是python和c++互调的桥梁,此类实现代码在include/tvm/runtime/packed_func.h文件中,这里面还有一个TypedPackedFunc类,它只是PackedFunc的一个wrapper,主要增加了类型检查的功能,开发TVM的c++代码要尽可能的使用这个类,但是我们为了把问题尽可能的简化,只关注PackedFunc这个最底层类,其中用到了下面这几个关键的数据结构:

  • TVMValue

  • TVMArgs

  • TVMPODValue_

  • TVMArgValue

  • TVMRetValue

  • TVMArgsSetter

下面结合代码,逐个来说(注:本文基于fe25b9e7c这个commit,下面所有列出的代码都做了相当大量的精简和修改,一来只为讲清楚原理,二来限于篇幅,大家如果有兴趣了解更多的细节,需要再去github上看实际的实现)

二、TVMValue

这是最基本的一个数据结构,是一个union,主要是为了储存c++和其它语言交互时所支持的几种类型的数据,代码很简单(其中DLDataType和DLDevice是两个复合数据类型,限于篇幅,这里不能全部列出来,大家需要自己到github追下细节):

// include/tvm/runtime/c_runtime_api.htypedef union { int64_t v_int64; double v_float64; void* v_handle; const char* v_str; DLDataType v_type; DLDevice v_device;} TVMValue;

三、TVMArgs

这个类主要是为了封装传给PackedFunc的所有参数,这个类也比较简单原始,主要基于TVMValue、参数类型编码、参数个数来实现,代码如下:

class TVMArgs { public: const TVMValue* values; const int* type_codes;  int num_args; TVMArgs(const TVMValue* values, const int* type_codes,           int num_args) { ... }            inline int size() const { return num_args; }  inline TVMArgValue operator[](int i) const {      return TVMArgValue(values[i], type_codes[i]);  }};

四、TVMPODValue_

这是一个内部使用的基类,主要主要服务于后面介绍到的TVMArgValue和TVMRetValue,从名字可以看出,这个类主要是处理POD类型的数据,POD是plain old data的缩写,要么是scalar type,要么是trival type,要么是standard layout type,具体可参考cppreference的PODType、is_pod、is_scalar、is_trivial、is_standard_layout等章节。其实关于POD类型,可以单独写一大篇文章,但它不是本文的重点,以后有时间再专门写文章细说。
这个类的实现核心是强制类型转换运算符重载(在c++中,类型的名字,包括类的名字本身也是一种运算符,即类型强制转换运算符),如下面代码所示:
class TVMPODValue_ { public:  operator double() const { return value_.v_float64; }  operator int64_t() const { return value_.v_int64; }  operator void*() const { return value_.v_handle; } template <typename T>  T* ptr() const { return static_cast<T*>(value_.v_handle); }
protected: TVMValue value_; int type_code_;};

五、TVMArgValue

这个类继承自前面的TVMPODValue_类,用作表示PackedFunc的一个参数,它和TVMPODValue_的区别是扩充了一些数据类型的支持,比如string、PackedFunc、TypedPackedFunc等,其中对后两个的支持是在c++代码中能够调用python函数的根本原因。这个类只使用所保存的underlying data,而不会去做释放,代码如下:

class TVMArgValue : public TVMPODValue_ { public: TVMArgValue() {} TVMArgValue(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {}
operator std::string() const {} operator PackedFunc() const { return *ptr<PackedFunc>(); } const TVMValue& value() const { return value_; }
template <typename T> inline operator T() const; inline operator DLDataType() const; inline operator DataType() const;};

六、TVMRetValue

这个类也是继承自TVMPODValue_类,主要作用是作为存放调用PackedFunc返回值的容器,它和TVMArgValue的区别是,它会管理所保存的underlying data,会对其做释放。这个类主要由四部分构成:

  • 构造和析构函数

  • 对强制类型转换运算符重载的扩展

  • 对赋值运算符的重载

  • 辅助函数,包括释放资源的Clear函数

代码如下:

class TVMRetValue : public TVMPODValue_ { public: // ctor and dtor, dtor will release related buffer TVMRetValue() {} ~TVMRetValue() { this->Clear(); }
// conversion operators operator std::string() const { return *ptr<std::string>(); }  operator DLDataType() const { return value_.v_type; } operator PackedFunc() const { return *ptr<PackedFunc>(); } // Assign operators  TVMRetValue& operator=(double value) {} TVMRetValue& operator=(void* value) {}  TVMRetValue& operator=(int64_t value) {} TVMRetValue& operator=(std::string value) {}  TVMRetValue& operator=(PackedFunc f) {}   private:  // judge type_code_, release underlying data  void Clear() { if (type_code_ == kTVMStr || type_code_ == kTVMBytes) {      delete ptr<std::string>(); } else if(type_code_ == kTVMPackedFuncHandle) {      delete ptr<PackedFunc>(); } else if(type_code_ == kTVMNDArrayHandle) {      NDArray::FFIDecRef(        static_cast<TVMArrayHandle>(value_.v_handle)); } else if(type_code_ == kTVMModuleHandle      || type_code_ == kTVMObjectHandle ) {      static_cast<Object*>(value_.v_handle)->DecRef(); } type_code_ = kTVMNullptr; }};

七、TVMArgsSetter

这是一个用于给TVMValue对象赋值的辅助类,主要通过重载函数调用运算符来实现,主要实现原理如下:

class TVMArgsSetter { public: TVMArgsSetter(TVMValue* values, int* type_codes) : values_(values), type_codes_(type_codes) {}   void operator()(size_t i, double value) const { values_[i].v_float64 = value; type_codes_[i] = kDLFloat; } void operator()(size_t i, const string& value) const { values_[i].v_str = value.c_str(); type_codes_[i] = kTVMStr; } void operator()(size_t i, const PackedFunc& value) const { values_[i].v_handle = const_cast<PackedFunc*>(&value); type_codes_[i] = kTVMPackedFuncHandle; }
private: TVMValue* values_; int* type_codes_;};

八、PackedFunc

有了前面所述的数据结构作为基础,再来看PackedFunc的实现,PackedFunc的实现很简单,内部只使用了一个储存函数指针的变量,再通过重载函数调用运算符来调用这个函数指针所指向的函数,代码如下:

class PackedFunc { public: using FType = function<void(TVMArgs args, TVMRetValue* rv)>; PackedFunc() {} explicit PackedFunc(FType body) : body_(body) {}
template <typename... Args> inline TVMRetValue operator()(Args&&... args) const { const int kNumArgs = sizeof...(Args); const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; TVMValue values[kArraySize]; int type_codes[kArraySize]; detail::for_each(TVMArgsSetter(values, type_codes), std::forward<Args>(args)...); TVMRetValue rv; body_(TVMArgs(values, type_codes, kNumArgs), &rv); return rv; } inline void CallPacked(TVMArgs args, TVMRetValue* rv) const { body_(args, rv); }
private: FType body_;};

九、最后

本文只详细说了开头列出的第一部分,即PackedFunc的实现原理,还有第二第三部分没说,后面的文章会陆续再分析细说这两部分。

最后的最后,TVM的官方文档对PackedFunc机制有一段简短精辟的介绍(https://tvm.apache.org/docs/dev/runtime.html),大家可以作为参考来理解上面代码:

PackedFunc is type-erased, which means that the function signature does not restrict which input type to pass in or type to return. Under the hood, when we call a PackedFunc, it packs the input arguments to TVMArgs on stack, and gets the result back via TVMRetValue. Thanks to template tricks in C++, we can call a PackedFunc just like a normal function. Because of its type-erased nature, we can call a PackedFunc from dynamic languages like python, without additional glue code for each new type function created. 



您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存