其他
注册机调研:从Caffe开始
caffe里主要有两个注册宏:
REGISTER_LAYER_CLASS
REGISTER_SOLVER_CLASS
下面通过实际代码来详细解读这两个注册宏(注:为方便理解,相关代码都做了简化和修改)
if (layer_type == "conv1d") {
return MakeConv1dLayer(param);
} else if (layer_type == "conv2d") {
return MakeConv2dLayer(param);
} else if (layer_type == "conv3d") {
return MakeConv3dLayer(param);
} else if(...) {
...
}
// 以concat_layer.cpp为例,前面是实现,最后加一句注册即可
// concat实现部分
...
void ConcatLayer::Forward_cpu(...) {...}
void ConcatLayer::Backward_cpu(...) {...}
// 注册
REGISTER_LAYER_CLASS(Concat);
#define REGISTER_LAYER_CREATOR(type, creator) \
static LayerRegisterer<float> g_creator_f_##type(#type, creator<float>); \
static LayerRegisterer<double> g_creator_d_##type(#type, creator<double>) \
#define REGISTER_LAYER_CLASS(type) \
template <typename Dtype> \
shared_ptr<Layer<Dtype>> Creator_##type##Layer(const LayerParameter& param){ \
return shared_ptr<Layer<Dtype>>(new type##Layer<Dtype>(param)); \
} \
REGISTER_LAYER_CREATOR(type, Creator_##type##Layer)
template <typename Dtype>
class LayerRegisterer {
public:
LayerRegisterer(const string& type,
shared_ptr<Layer<Dtype>> (*creator)(const LayerParameter&)) {
LayerRegistry<Dtype>::AddCreator(type, creator);
}
};
一个Registry函数,用于返回一个全局的存储数据结构
一个AddCreator函数,用于注册一个函数
一个CreateLayer函数,用于调用注册的函数
// include/caffe/layer_factory.hpp
template <typename Dtype>
class LayerRegistry {
public:
typedef shared_ptr<Layer<Dtype>>
(*Creator)(const LayerParameter&);
typedef std::map<string, Creator> CreatorRegistry;
static CreatorRegistry& Registry() {
static CreatorRegistry* g_registry_ = new CreatorRegistry();
return *g_registry_;
}
static void AddCreator(const string& type, Creator creator) {
CreatorRegistry& registry = Registry();
registry[type] = creator;
}
static shared_ptr<Layer<Dtype>> CreateLayer(
const LayerParameter& param) {
CreatorRegistry& registry = Registry();
return registry[param.type()](param);
}
};
在Caffe中,在运行一个实际的网络之前,需要把所有的Layer都创建出来,在Net类的Init函数中可以看到创建Layer的代码,只有一句话:
// src/caffe/net.cpp Init函数中的创建Layer代码
layers_.push_back(LayerRegistry<Dtype>::CreateLayer(layer_param));
// include/caffe/net.hpp中的layers_定义
vector<shared_ptr<Layer<Dtype>>> layers_;
这个注册宏是用来注册solver的,solver的作用是定义参数的更新方式,常见的有sgd、adam等方法,其中sgd算法是我们常用的,这里只关注注册部分