深入理解TVM:Python/C++互调(上)
TVM已经是一个很庞大的系统,包含了很多的功能模块,其中python和c++的互相调用这个功能模块,没有使用第三方的开源库(boost.python、pybind11等),而是自己实现了一套复杂但精致高效强大的机制,值得好好研究学习。这部分内容很多,一篇文章很难说清楚,我准备把这部分分成上、中、下三篇来说,尽可能的把实现原理讲清楚:
上篇:最底层的c++数据结构支撑(围绕c++端PackedFunc) 中篇:基于PackedFunc的函数注册(围绕TVM_REGISTER_GLOBAL) 下篇:偏上层的python的调用细节(围绕ctypes内置库和python端PackedFunc)
本文讲第一部分,也就是围绕PackedFunc这个类来说,它是python和c++互调的桥梁,此类实现代码在include/tvm/runtime/packed_func.h文件中,这里面还有一个TypedPackedFunc类,它只是PackedFunc的一个wrapper,主要增加了类型检查的功能,开发TVM的c++代码要尽可能的使用这个类,但是我们为了把问题尽可能的简化,只关注PackedFunc这个最底层类,其中用到了下面这几个关键的数据结构:
TVMValue
TVMArgs
TVMPODValue_
TVMArgValue
TVMRetValue
TVMArgsSetter
下面结合代码,逐个来说(注:本文基于fe25b9e7c这个commit,下面所有列出的代码都做了相当大量的精简和修改,一来只为讲清楚原理,二来限于篇幅,大家如果有兴趣了解更多的细节,需要再去github上看实际的实现)
二、TVMValue
这是最基本的一个数据结构,是一个union,主要是为了储存c++和其它语言交互时所支持的几种类型的数据,代码很简单(其中DLDataType和DLDevice是两个复合数据类型,限于篇幅,这里不能全部列出来,大家需要自己到github追下细节):
// include/tvm/runtime/c_runtime_api.h
typedef union {
int64_t v_int64;
double v_float64;
void* v_handle;
const char* v_str;
DLDataType v_type;
DLDevice v_device;
} TVMValue;
三、TVMArgs
这个类主要是为了封装传给PackedFunc的所有参数,这个类也比较简单原始,主要基于TVMValue、参数类型编码、参数个数来实现,代码如下:
class TVMArgs {
public:
const TVMValue* values;
const int* type_codes;
int num_args;
TVMArgs(const TVMValue* values,
const int* type_codes,
int num_args) { ... }
inline int size() const { return num_args; }
inline TVMArgValue operator[](int i) const {
return TVMArgValue(values[i], type_codes[i]);
}
};
四、TVMPODValue_
class TVMPODValue_ {
public:
operator double() const { return value_.v_float64; }
operator int64_t() const { return value_.v_int64; }
operator void*() const { return value_.v_handle; }
template <typename T>
T* ptr() const { return static_cast<T*>(value_.v_handle); }
protected:
TVMValue value_;
int type_code_;
};
五、TVMArgValue
这个类继承自前面的TVMPODValue_类,用作表示PackedFunc的一个参数,它和TVMPODValue_的区别是扩充了一些数据类型的支持,比如string、PackedFunc、TypedPackedFunc等,其中对后两个的支持是在c++代码中能够调用python函数的根本原因。这个类只使用所保存的underlying data,而不会去做释放,代码如下:
class TVMArgValue : public TVMPODValue_ {
public:
TVMArgValue() {}
TVMArgValue(TVMValue value, int type_code)
: TVMPODValue_(value, type_code) {}
operator std::string() const {}
operator PackedFunc() const { return *ptr<PackedFunc>(); }
const TVMValue& value() const { return value_; }
template <typename T>
inline operator T() const;
inline operator DLDataType() const;
inline operator DataType() const;
};
六、TVMRetValue
这个类也是继承自TVMPODValue_类,主要作用是作为存放调用PackedFunc返回值的容器,它和TVMArgValue的区别是,它会管理所保存的underlying data,会对其做释放。这个类主要由四部分构成:
构造和析构函数
对强制类型转换运算符重载的扩展
对赋值运算符的重载
辅助函数,包括释放资源的Clear函数
代码如下:
class TVMRetValue : public TVMPODValue_ {
public:
// ctor and dtor, dtor will release related buffer
TVMRetValue() {}
~TVMRetValue() { this->Clear(); }
// conversion operators
operator std::string() const { return *ptr<std::string>(); }
operator DLDataType() const { return value_.v_type; }
operator PackedFunc() const { return *ptr<PackedFunc>(); }
// Assign operators
TVMRetValue& operator=(double value) {}
TVMRetValue& operator=(void* value) {}
TVMRetValue& operator=(int64_t value) {}
TVMRetValue& operator=(std::string value) {}
TVMRetValue& operator=(PackedFunc f) {}
private:
// judge type_code_, release underlying data
void Clear() {
if (type_code_ == kTVMStr || type_code_ == kTVMBytes) {
delete ptr<std::string>();
} else if(type_code_ == kTVMPackedFuncHandle) {
delete ptr<PackedFunc>();
} else if(type_code_ == kTVMNDArrayHandle) {
NDArray::FFIDecRef(
static_cast<TVMArrayHandle>(value_.v_handle));
} else if(type_code_ == kTVMModuleHandle
|| type_code_ == kTVMObjectHandle ) {
static_cast<Object*>(value_.v_handle)->DecRef();
}
type_code_ = kTVMNullptr;
}
};
七、TVMArgsSetter
这是一个用于给TVMValue对象赋值的辅助类,主要通过重载函数调用运算符来实现,主要实现原理如下:
class TVMArgsSetter {
public:
TVMArgsSetter(TVMValue* values, int* type_codes)
: values_(values), type_codes_(type_codes) {}
void operator()(size_t i, double value) const {
values_[i].v_float64 = value;
type_codes_[i] = kDLFloat;
}
void operator()(size_t i, const string& value) const {
values_[i].v_str = value.c_str();
type_codes_[i] = kTVMStr;
}
void operator()(size_t i, const PackedFunc& value) const {
values_[i].v_handle = const_cast<PackedFunc*>(&value);
type_codes_[i] = kTVMPackedFuncHandle;
}
private:
TVMValue* values_;
int* type_codes_;
};
八、PackedFunc
有了前面所述的数据结构作为基础,再来看PackedFunc的实现,PackedFunc的实现很简单,内部只使用了一个储存函数指针的变量,再通过重载函数调用运算符来调用这个函数指针所指向的函数,代码如下:
class PackedFunc {
public:
using FType = function<void(TVMArgs args, TVMRetValue* rv)>;
PackedFunc() {}
explicit PackedFunc(FType body) : body_(body) {}
template <typename... Args>
inline TVMRetValue operator()(Args&&... args) const {
const int kNumArgs = sizeof...(Args);
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
TVMValue values[kArraySize];
int type_codes[kArraySize];
detail::for_each(TVMArgsSetter(values, type_codes),
std::forward<Args>(args)...);
TVMRetValue rv;
body_(TVMArgs(values, type_codes, kNumArgs), &rv);
return rv;
}
inline void CallPacked(TVMArgs args, TVMRetValue* rv) const {
body_(args, rv);
}
private:
FType body_;
};
九、最后
本文只详细说了开头列出的第一部分,即PackedFunc的实现原理,还有第二第三部分没说,后面的文章会陆续再分析细说这两部分。
最后的最后,TVM的官方文档对PackedFunc机制有一段简短精辟的介绍(https://tvm.apache.org/docs/dev/runtime.html),大家可以作为参考来理解上面代码:
PackedFunc is type-erased, which means that the function signature does not restrict which input type to pass in or type to return. Under the hood, when we call a PackedFunc, it packs the input arguments to TVMArgs on stack, and gets the result back via TVMRetValue. Thanks to template tricks in C++, we can call a PackedFunc just like a normal function. Because of its type-erased nature, we can call a PackedFunc from dynamic languages like python, without additional glue code for each new type function created.