深入理解TVM:RELAY_REGISTER_OP
一、简介
RELAY_REGISTER_OP是TVM中非常重要的一个注册宏,用来注册relay中的op,目前code base中大概有200多个(基于commit fe25b9e7c,因为官方的代码一直在更新,所以这个数字也会继续变化,但现在op基本趋于稳定,应该不会再有大的变化),先贴几个relay中注册的op,给大家一个直观的印象:
下面来详细介绍这个注册机制的实现和原理,为了便于理解,所有代码都做了不同程度的精简,有些语法不通的地方是混用了一些伪代码。
二、定义
这个宏定义在include/tvm/relay/op.h,具体如下:
#define RELAY_REGISTER_OP(OpName) TVM_REGISTER_OP(OpName)
可以看到这个宏其实是TVM_REGISTER_OP(定义在include/tvm/ir/op.h)这个宏的alias,TVM_REGISTER_OP的定义如下(做了化简):
#define TVM_REGISTER_OP(OpName) \
static OpRegEntry& __make_##Op__COUNTER__ = \
OpRegEntry::RegisterOrGet(OpName).set_name()
这里定义了一个OpRegEntry的static引用变量,__COUNTER__宏保证这个变量名全局唯一(参照前文注册机调研:TVM之TVM_REGISTER_GLOBAL),static变量保证在main函数执行之前完成初始化,上面RegisterOrGet的实现如下:
OpRegEntry& RegisterOrGet(const String& name) {
if (entry_map_.find(name) != entry_map_.end())
return entry_map_[name];
auto entry = make_unique<OpRegEntry>(entries_.size());
entry->name = name;
entry_map_[name] = entry.get();
entries_.emplace_back(move(entry));
return entry_map_[name];
}
三、OpRegEntry
每注册一个op,都会创建这个数据结构的一个实例,前面的一段代码可以看到实际的创建过程,最后把创建好的实例保存在了entries_这个变量里。OpRegEntry定义在include/tvm/ir/op.h,主要接口和数据成员如下:
class OpRegEntry {
public:
static OpRegEntry& RegisterOrGet(const String& name);
private:
string name;
Op op_;
public:
OpRegEntry& describe(const string& desp);
OpRegEntry& add_argument(const string& name, ...);
OpRegEntry& add_type_rel(const string& rel_name);
OpRegEntry& set_attrs_type();
OpRegEntry& set_num_inputs(int32_t n);
OpRegEntry& set_support_level(int32_t level);
OpRegEntry& set_attr(const std::string& attr_name, ...);
OpRegEntry& set_name();
};
这里使用的时候也是可以像下面这样使用链式调用(src/relay/op/nn/convolution.cc):
RELAY_REGISTER_OP("nn.conv1d")
.describe(R"code()code" TVM_ADD_FILELINE)
.set_attrs_type<Conv1DAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.add_type_rel("Conv1D", Conv1DRel<Conv1DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ...);
四、AttrRegistry
在注册的过程中,OpRegEntry只是提供一堆设置的接口,具体的设置内容是通过OpRegEntry来间接设置到AttrRegistry中去的,AttrRegistry是一个模板类,不光op的注册是用的这个基础类,TVM中还有另外的两个模块的注册也是基于这个基础类:
using OpRegistry = AttrRegistry<OpRegEntry, Op>
using TargetTagRegistry = AttrRegistry<TargetTagRegEntry, TargetTag>
using TargetKindRegistry = AttrRegistry<TargetKindRegEntry, TargetKind>
AttrRegistry是一个单例,主要数据结构和API如下(我直接把这个模板类的模板参数换成上面第一个应用中的OpRegEntry和Op):
// src/node/attr_registry.h
class AttrRegistry {
public:
// 通过创建一个OpRegEntry的方式注册一个op,
// 实现代码在前面第二节的最后已经讲过
OpRegEntry &RegisterOrGet(const String &name);
// 获取维护在entries_中的已经创建过的OpRegEntry
const OpRegEntry *Get(const String &name) const;
// 获取所有创建过的op name
Array<String> ListAllNames() const;
void UpdateAttr(const String &attr_name, ...);
// 这个类的一个单例,代码简单不再列出
static AttrRegistry *Global();
private:
// entries_用于维护创建好的OpRegEntry对象的生命周期
// entry_map_使用map来快速使用创建过的OpRegEntry
vector<unique_ptr<OpRegEntry>> entries_;
map<String, OpRegEntry *> entry_map_;
// attrs_维护了op的属性,code base其实是unordered_map,
// 这里为了省字符,包括上面的entry_map_,就都用map了
map<String, unique_ptr<AttrRegistryMapContainerMap<Op>>> attrs_;
};
上面代码中的注释已经解释了大多数的接口,还剩下UpdateAttr这个关键的API没说,这个API由前面第三节的最后的op注册示例中的set_attr一路调用而来,是用来设置op的attribute的,所以放在下一节和AttrRegistryMapContainerMap一起说
五、AttrRegistryMapContainerMap
先看AttrRegistryMapContainerMap这个数据结构,这也是个模板类,我们这里说的是op的注册,所以我们直接使用Op作为模板参数来看这个类:
class AttrRegistryMapContainerMap {
public:
// 这个API是用于检查对于给定op,是否有对应的attribute被设置
// 个人感觉这个API的名字和返回值类型都不合适,感觉改成类似:
// bool HasAttr(const Op &op) const;
// 更加合适
int count(const Op &op) const;
// 用来取给定op的attribute
const runtime::TVMRetValue& operator[](const Op& op) const;
// 也是用来取给定op的attribute,只是带默认值
// 这是个模板函数,为了简化代码直接改成了下面这样
ValueType get(const Op &op, ValueType def_value) const;
private:
String attr_name_;
vector<pair<TVMRetValue, int>> data_;
};
这个数据结构主要有两个数据成员:
attr_name_:意如其名,attr的名字
data_:具体的attr value,这是个vector,其index对应注册op时创建OpRegEntry对象时使用的index,也对应创建OpRegEntry对象时创建的一个OpNode对象中的index_这个值,总之这个index可以认为是一个op的一个唯一identifier
看完AttrRegistryMapContainerMap这个数据结构,再看上一节中的UpdateAttr这个API,这个API主要是把注册op时调用set_attr设置的attribute维护到上面介绍的AttrRegistryMapContainerMap中去:
void UpdateAttr(const String &attr_name, const Op &op,
TVMRetValue value, int plevel) {
auto &op_map = attrs_[attr_name];
if (op_map == nullptr) {
op_map.reset(new AttrRegistryMapContainerMap<Op>());
op_map->attr_name_ = attr_name;
}
auto index = op->AttrRegistryIndex();
if (op_map->data_.size() <= index) {
op_map->data_.resize(index + 1, make_pair(TVMRetValue(), 0));
}
pair<TVMRetValue, int>& p = op_map->data_[index];
if (p.second < plevel && value.type_code() != kTVMNullptr) {
op_map->data_[index] = make_pair(value, plevel);
}
}
这里第一句代码auto &op_map = attrs_[attr_name]就有两个值得注意的点:
这里使用的是operator[],attrs_是一个map,如果不存在attr_name这个item,则会自动创建一个
op_map必须是一个引用变量,因为map中的value是一个unique_ptr
所有有相同attribute名字的op的attribute都会维护在同一个AttrRegistryMapContainerMap对象的data_成员中,使用op的index_作为索引,我对注册op的set_attr设置的attribute进行了一下统计,如下图:
可见目前设置的op attr有17种,每一种的attr name和拥有这个attr的op的个数一目了然,这里可以看到最多的op num有355个,比文章开头统计的RELAY_REGISTER_OP的个数要多,是因为除了RELAY_REGISTER_OP这个宏注册OP之外,还有其它的宏来注册。
六、最后
RELAY_REGISTER_OP这个op注册宏用的非常多,和前文注册机调研:TVM之TVM_REGISTER_GLOBAL这个注册宏一样重要,这里也不知道有没有完全讲清楚,希望大家留言反馈。
写分析代码的文章不是很容易,代码中数据结构的嵌套、函数的层层调用、冗余代码,这些东西通过文字描述出来,如果不做化简、组织不好的话,很难讲的清楚,这篇简单的分析花了整整一天时间才写好,但肯定有很多自己看不到的问题,只能根据反馈边写边迭代改进,希望能越做越好。