查看原文
其他

深入理解TVM:RELAY_REGISTER_OP

月踏 知知爸爸是码农 2022-06-13

一、简介

RELAY_REGISTER_OP是TVM中非常重要的一个注册宏,用来注册relay中的op,目前code base中大概有200多个(基于commit fe25b9e7c,因为官方的代码一直在更新,所以这个数字也会继续变化,但现在op基本趋于稳定,应该不会再有大的变化),先贴几个relay中注册的op,给大家一个直观的印象:

下面来详细介绍这个注册机制的实现和原理,为了便于理解,所有代码都做了不同程度的精简,有些语法不通的地方是混用了一些伪代码。

二、定义

这个宏定义在include/tvm/relay/op.h,具体如下:

#define RELAY_REGISTER_OP(OpName) TVM_REGISTER_OP(OpName)

可以看到这个宏其实是TVM_REGISTER_OP(定义在include/tvm/ir/op.h)这个宏的alias,TVM_REGISTER_OP的定义如下(做了化简):

#define TVM_REGISTER_OP(OpName)                   \  static OpRegEntry& __make_##Op__COUNTER__ =     \      OpRegEntry::RegisterOrGet(OpName).set_name()

这里定义了一个OpRegEntry的static引用变量,__COUNTER__宏保证这个变量名全局唯一(参照前文注册机调研:TVM之TVM_REGISTER_GLOBAL),static变量保证在main函数执行之前完成初始化,上面RegisterOrGet的实现如下:

OpRegEntry& RegisterOrGet(const String& name) {  if (entry_map_.find(name) != entry_map_.end())     return entry_map_[name];  auto entry = make_unique<OpRegEntry>(entries_.size());  entry->name = name;  entry_map_[name] = entry.get();  entries_.emplace_back(move(entry));  return entry_map_[name];}

三、OpRegEntry

每注册一个op,都会创建这个数据结构的一个实例,前面的一段代码可以看到实际的创建过程,最后把创建好的实例保存在了entries_这个变量里。OpRegEntry定义在include/tvm/ir/op.h,主要接口和数据成员如下:

class OpRegEntry {public: static OpRegEntry& RegisterOrGet(const String& name);private: string name; Op op_;public:  OpRegEntry& describe(const string& desp);  OpRegEntry& add_argument(const string& name, ...);  OpRegEntry& add_type_rel(const string& rel_name); OpRegEntry& set_attrs_type(); OpRegEntry& set_num_inputs(int32_t n);  OpRegEntry& set_support_level(int32_t level);  OpRegEntry& set_attr(const std::string& attr_name, ...);  OpRegEntry& set_name();};

这里使用的时候也是可以像下面这样使用链式调用(src/relay/op/nn/convolution.cc):

RELAY_REGISTER_OP("nn.conv1d")    .describe(R"code()code" TVM_ADD_FILELINE) .set_attrs_type<Conv1DAttrs>() .set_num_inputs(2) .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(2) .add_type_rel("Conv1D", Conv1DRel<Conv1DAttrs>)    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ...);

四、AttrRegistry

在注册的过程中,OpRegEntry只是提供一堆设置的接口,具体的设置内容是通过OpRegEntry来间接设置到AttrRegistry中去的,AttrRegistry是一个模板类,不光op的注册是用的这个基础类,TVM中还有另外的两个模块的注册也是基于这个基础类:

  1. using OpRegistry = AttrRegistry<OpRegEntry, Op>

  2. using TargetTagRegistry = AttrRegistry<TargetTagRegEntry, TargetTag>

  3. using TargetKindRegistry = AttrRegistry<TargetKindRegEntry, TargetKind>

AttrRegistry是一个单例,主要数据结构和API如下(我直接把这个模板类的模板参数换成上面第一个应用中的OpRegEntry和Op):

// src/node/attr_registry.hclass AttrRegistry {public: // 通过创建一个OpRegEntry的方式注册一个op, // 实现代码在前面第二节的最后已经讲过 OpRegEntry &RegisterOrGet(const String &name); // 获取维护在entries_中的已经创建过的OpRegEntry const OpRegEntry *Get(const String &name) const; // 获取所有创建过的op name  Array<String> ListAllNames() const;    void UpdateAttr(const String &attr_name, ...);
  // 这个类的一个单例,代码简单不再列出  static AttrRegistry *Global();
private:  // entries_用于维护创建好的OpRegEntry对象的生命周期  // entry_map_使用map来快速使用创建过的OpRegEntry vector<unique_ptr<OpRegEntry>> entries_; map<String, OpRegEntry *> entry_map_;  // attrs_维护了op的属性,code base其实是unordered_map,  // 这里为了省字符,包括上面的entry_map_,就都用map了 map<String, unique_ptr<AttrRegistryMapContainerMap<Op>>> attrs_;};

上面代码中的注释已经解释了大多数的接口,还剩下UpdateAttr这个关键的API没说,这个API由前面第三节的最后的op注册示例中的set_attr一路调用而来,是用来设置op的attribute的,所以放在下一节和AttrRegistryMapContainerMap一起说

五、AttrRegistryMapContainerMap

先看AttrRegistryMapContainerMap这个数据结构,这也是个模板类,我们这里说的是op的注册,所以我们直接使用Op作为模板参数来看这个类:

class AttrRegistryMapContainerMap {public:  // 这个API是用于检查对于给定op,是否有对应的attribute被设置  // 个人感觉这个API的名字和返回值类型都不合适,感觉改成类似:  // bool HasAttr(const Op &op) const;  // 更加合适 int count(const Op &op) const;  // 用来取给定op的attribute const runtime::TVMRetValue& operator[](const Op& op) const;  // 也是用来取给定op的attribute,只是带默认值  // 这是个模板函数,为了简化代码直接改成了下面这样 ValueType get(const Op &op, ValueType def_value) const;
private: String attr_name_; vector<pair<TVMRetValue, int>> data_;};

这个数据结构主要有两个数据成员:

  • attr_name_:意如其名,attr的名字

  • data_:具体的attr value,这是个vector,其index对应注册op时创建OpRegEntry对象时使用的index,也对应创建OpRegEntry对象时创建的一个OpNode对象中的index_这个值,总之这个index可以认为是一个op的一个唯一identifier

看完AttrRegistryMapContainerMap这个数据结构,再看上一节中的UpdateAttr这个API,这个API主要是把注册op时调用set_attr设置的attribute维护到上面介绍的AttrRegistryMapContainerMap中去:

void UpdateAttr(const String &attr_name, const Op &op,                        TVMRetValue value, int plevel) { auto &op_map = attrs_[attr_name]; if (op_map == nullptr) { op_map.reset(new AttrRegistryMapContainerMap<Op>()); op_map->attr_name_ = attr_name; }
  auto index = op->AttrRegistryIndex(); if (op_map->data_.size() <= index) { op_map->data_.resize(index + 1, make_pair(TVMRetValue(), 0)); } pair<TVMRetValue, int>& p = op_map->data_[index]; if (p.second < plevel && value.type_code() != kTVMNullptr) { op_map->data_[index] = make_pair(value, plevel); }}

这里第一句代码auto &op_map = attrs_[attr_name]就有两个值得注意的点:

  • 这里使用的是operator[],attrs_是一个map,如果不存在attr_name这个item,则会自动创建一个

  • op_map必须是一个引用变量,因为map中的value是一个unique_ptr

这个API总体来说是在操作attrs_这个变量,AttrRegistry管理注册op设置的attr有点不是很直观,我画了一个图,希望能能直观好理解一些:

所有有相同attribute名字的op的attribute都会维护在同一个AttrRegistryMapContainerMap对象的data_成员中,使用op的index_作为索引,我对注册op的set_attr设置的attribute进行了一下统计,如下图:

可见目前设置的op attr有17种,每一种的attr name和拥有这个attr的op的个数一目了然,这里可以看到最多的op num有355个,比文章开头统计的RELAY_REGISTER_OP的个数要多,是因为除了RELAY_REGISTER_OP这个宏注册OP之外,还有其它的宏来注册。

六、最后

RELAY_REGISTER_OP这个op注册宏用的非常多,和前文注册机调研:TVM之TVM_REGISTER_GLOBAL这个注册宏一样重要,这里也不知道有没有完全讲清楚,希望大家留言反馈。

写分析代码的文章不是很容易,代码中数据结构的嵌套、函数的层层调用、冗余代码,这些东西通过文字描述出来,如果不做化简、组织不好的话,很难讲的清楚,这篇简单的分析花了整整一天时间才写好,但肯定有很多自己看不到的问题,只能根据反馈边写边迭代改进,希望能越做越好。


您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存