OneFlow学习笔记:从Functor到OpExprInterpreter
撰文|月踏
更新|赵露阳
此前写过的《OneFlow学习笔记:Python到C++调用过程分析》,从Python代码追到了Functor这一层,本文从Functor开始继续往下追,后面就是OpExprInterpreter。
1
Functor回顾
Functor层作为OneFlow的基础设施,为Python端和C++端提供了op操作的统一入口,这在《Python到C++调用过程分析》中有详细分析,其中使用了Relu作为示例,这是为了尽可能地减小理解成本。本文继续以Relu作为示例来往下追代码,前文已经列过ReluFunctor的代码,这里为了方便衔接上下文,再简单列一下:
// Functor for the "relu" op: creates the op expression once in the constructor
// and dispatches it through OpInterpUtil on every call.
class ReluFunctor {
public:
// Builds the op expression with OpBuilder("relu"), declaring input "x" and output "y",
// and stores the result of Build() (unwrapped by CHECK_JUST) in op_.
ReluFunctor() { op_ = CHECK_JUST(one::OpBuilder("relu").Input("x", 1).Output("y", 1).Build()); }
// Applies the op to tensor `x` and returns the dispatch result.
// NOTE: the code that uses `inplace` is elided ("...") in this excerpt.
Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, bool inplace) const {
...
// Hand the stored op expression and the single input tensor to the dispatch entry point.
return OpInterpUtil::Dispatch<Tensor>(*op_, {x});
}
private:
// Op expression assigned in the constructor and passed to Dispatch in operator().
std::shared_ptr<OpExpr> op_;
};
代码很简单,可以分成三部分来看:
1. 定义了数据结构:也就是类成员变量op_,它是OpExpr类型,这是下面第二节主要讲的部分
2. 构造函数:使用OpBuilder这个辅助类对op_进行了初始化,关键在于最后调用的Build(),其内部调用了第二节讲到的UserOpExpr中的静态函数New来进行创建
3. 函数调用运算符重载函数:这里通过一个Dispatch函数来对具体的计算进行调度,最终会在某个具体的设备上真正进行计算,这里面的细节太多了,本文的第三节先讲一部分的内容,完整的链条后续会再继续总结出来
std::string op_name_;
std::shared_ptr<const ArgTuple> input_arg_tuple_;
std::shared_ptr<const ArgTuple> output_arg_tuple_;
};
// Builtin op expression that stores an op proto and a gradient-function handle.
// NOTE(review): `ProtoType` is presumably a template parameter whose `template<...>`
// header is omitted from this excerpt — confirm against the full source.
class BuiltinOpExprImpl : public BuiltinOpExpr {
// Proto object of type ProtoType held by this op expression.
ProtoType op_proto_;
// Shared handle to the op's gradient function; `mutable` allows it to be assigned
// from const member functions (presumably created lazily — confirm).
mutable std::shared_ptr<OpExprGradFunctionIf> op_grad_func_;
};
// oneflow/core/framework/user_op_conf.proto
// Configuration message of a user op: type name, input/output maps, attributes,
// and two ordering lists.
message UserOpConf {
// Wrapper around a repeated string, used as the value type of the input/output maps.
message ListString { repeated string s = 1; }
// Type name of the op (required field).
required string op_type_name = 1;
// Maps a string key to a list of strings for the op's inputs
// (presumably argument name -> bound tensor names; confirm).
map<string, ListString> input = 2;
// Same layout as `input`, for the op's outputs.
map<string, ListString> output = 3;
// Maps an attribute name to its AttrValue.
map<string, AttrValue> attr = 4;
// List of strings for inputs (presumably the declaration order of the input keys; confirm).
repeated string input_order = 5;
// List of strings for outputs (presumably the declaration order of the output keys; confirm).
repeated string output_order = 6;
}
AttrMap base_attrs_;
user_op::TensorDescInferFn shape_infer_fn_;
user_op::DataTypeInferFn dtype_infer_fn_;
user_op::DeviceInferFn device_infer_fn_;
mutable HashMap<Symbol<Device>, std::shared_ptr<StatefulLocalOpKernel>> device2kernel_;
std::shared_ptr<ConsistentTensorInferCache> consistent_tensor_infer_cache_;
public:
static Maybe<UserOpExpr> New(const std::string& op_name, ...);
};
std::shared_ptr<OpExprInterpreter> internal_;
public:
Maybe<void> Apply(const OpExpr& op_expr, ...) const { ... }
};
const OpExpr& op_expr,
const TensorTuple& inputs,
TensorTuple* outputs,
const OpExprInterpContext& ctx) {
return JUST(GetInterpreter(inputs, ctx, op_expr))->Apply(op_expr, inputs, outputs, ctx);
}
...
AttrMap attrs;
Optional<Symbol<Device>> device;
Optional<Symbol<ParallelDesc>> parallel_desc;
Optional<Symbol<cfg::NdSbp>> nd_sbp;
std::shared_ptr<user_op::OpKernelState> state;
};
const OpExpr& op_expr) {
static const auto& g_lazy_interpreter = BuildLazyInterpreter();
static const auto& g_eager_consistent_interpreter = BuildEagerInterpreter(/*is_mirrored=*/false);
static const auto& g_eager_mirrored_interpreter = BuildEagerInterpreter(/*is_mirrored=*/true);
if (!LazyMode::is_enabled()) {
if (inputs.empty()) {
if (ctx.parallel_desc.has_value()) {
JUST(ctx.nd_sbp);
CHECK_OR_RETURN(!ctx.device.has_value());
return g_eager_consistent_interpreter;
} else {
CHECK_OR_RETURN(!ctx.nd_sbp.has_value());
return g_eager_mirrored_interpreter;
}
}
...
TensorTuple* outputs, const OpExprInterpContext& ctx) const {
bool requires_grad = false;
if (autograd::GradMode::is_enabled() && !JUST(op_expr.IsGradDisabled())) {
requires_grad =
std::any_of(inputs.begin(), inputs.end(),
[](const std::shared_ptr<Tensor>& tensor) { return tensor->requires_grad(); });
}
{
autograd::AutoGradMode mode(false);
JUST(internal_->Apply(op_expr, inputs, outputs, ctx));
}
// Lazy mode will construct backward compute graph in passes, so disable autograd if lazy mode.
std::shared_ptr<OpExprGradClosure> grad_closure(nullptr);
if (requires_grad && !LazyMode::is_enabled()) {
grad_closure = JUST(op_expr.GetOrCreateOpGradClosure());
auto backward_fn =
std::make_shared<std::function<Maybe<void>(const TensorTuple&, TensorTuple*, bool)>>(
[=](const TensorTuple& out_grads, TensorTuple* in_grads,
bool create_graph) -> Maybe<void> {
autograd::AutoGradMode mode(create_graph);
JUST(grad_closure->Apply(out_grads, in_grads));
return Maybe<void>::Ok();
});
JUST(GetThreadLocalAutogradEngine()->AddBackwardFuncPtr(op_expr.op_type_name() + "_backward",
backward_fn, inputs, outputs));
}
// Update outputs autograd meta
// Note: if requires_grad is True, we will create a new autograd meta for each output
// in `AddBackwardFuncPtr` to support inplace operation, so the update should after
// `AddBackwardFuncPtr`
for (auto& output : *outputs) {
output->set_is_leaf(inputs.size() == 0 || !requires_grad);
if (!output->requires_grad()) {
JUST(output->set_requires_grad(
requires_grad && IsSupportRequireGradDataType(output->dtype()->data_type())));
}
}
if (requires_grad && !LazyMode::is_enabled()) {
// Capture inputs and outputs after `AddBackwardFuncPtr` because of that grad function
// node has been attached to them.
JUST(grad_closure->Capture(inputs, *outputs, ctx));
}
return Maybe<void>::Ok();
}
// APPLY_IF(op_type): if `op_expr` can be dynamic_cast to `const op_type##Expr*`,
// forward to ApplyImpl with the concrete type and return its result; otherwise
// the cast yields nullptr and control falls through to the next statement.
// Relies on `op_expr`, `inputs`, `outputs` and `ctx` being in scope at the use site.
#define APPLY_IF(op_type) \
if (const auto* op = dynamic_cast<const op_type##Expr*>(&op_expr)) { \
return ApplyImpl(*op, inputs, outputs, ctx); \
}
APPLY_IF(UserOp);
APPLY_IF(VariableOp);
APPLY_IF(CastToMirroredOp);
...
}
const TensorTuple& inputs, TensorTuple* outputs,
const OpExprInterpContext& ctx) const {
return NaiveInterpret(op_expr, inputs, outputs, ctx);
}
std::unique_ptr<Blob> blob_;
std::unique_ptr<char[]> header_buffer_;
std::shared_ptr<TensorStorage> tensor_storage_;
std::atomic<bool> is_shape_synced_;
int64_t storage_offset_;
intrusive::shared_ptr<LocalDepObject> compute_local_dep_object_;
};
...
JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
return builder->LocalCallOpKernel(
kernel,
input_eager_blob_objects,
output_eager_blob_objects,
ctx,
op_device);
}));
return Maybe<void>::Ok();
}
// APPLY_IF(op_type): tries to dynamic_cast `op_expr` to `const op_type##Expr*`;
// on success it returns ApplyImpl(*op, inputs, outputs, ctx), on failure nothing
// happens and the following statement runs.
// Relies on `op_expr`, `inputs`, `outputs` and `ctx` being in scope at the use site.
#define APPLY_IF(op_type) \
if (const auto* op = dynamic_cast<const op_type##Expr*>(&op_expr)) { \
return ApplyImpl(*op, inputs, outputs, ctx); \
}
APPLY_IF(UserOp);
APPLY_IF(VariableOp);
APPLY_IF(CastToMirroredOp);
...
}
const TensorTuple& inputs, TensorTuple* outputs,
const OpExprInterpContext& ctx) const {
return InterpretThenInitConsistentId(op_expr, inputs, outputs, ctx);
}
// Helper that routes a function pointer through `Decorator`.
// NOTE(review): the template header introducing `Decorator` is not shown in this
// excerpt; from its use below it is a template taking <T, Args...> — confirm.
struct WithDecorator final {
// Primary template, declared but not defined here; only the function-pointer
// specialization below provides a definition in this excerpt.
template<typename T, typename = void>
struct Decorate;
// Specialization for function pointers returning T and taking Args...
template<typename T, typename... Args>
struct Decorate<T (*)(Args...)> final {
// `func` is the wrapped function, supplied as a non-type template argument.
// Call forwards its arguments to Decorator<T, Args...>::Call<func>.
template<T (*func)(Args...)>
static T Call(Args... args) {
return Decorator<T, Args...>::template Call<func>(args...);
}
};
};
// DECORATE(fn_ptr, decorator): evaluates to the address of
// WithDecorator<decorator>::Decorate<decltype(fn_ptr)>::Call<fn_ptr>, i.e. a function
// pointer that calls `fn_ptr` through `decorator`.
#define DECORATE(fn_ptr, decorator) \
(&WithDecorator<decorator>::Decorate<decltype(fn_ptr)>::Call<fn_ptr>)
// Specialization for functions of the form
// Maybe<void>(Arg0, Arg1, TensorTuple*, Args...). It calls the wrapped function and
// runs InitConsistentId on the outputs only when the outermost call has finished.
// NOTE: the `template<...>` header of this specialization is omitted in this excerpt.
struct NonRecursiveInitConsistentId<Maybe<void>, Arg0, Arg1, TensorTuple*, Args...> {
// `func` is the wrapped function, supplied as a non-type template argument.
template<Maybe<void> (*func)(Arg0, Arg1, TensorTuple*, Args...)>
static Maybe<void> Call(Arg0 arg0, Arg1 arg1, TensorTuple* outputs, Args... args) {
// Depth counter obtained from MutThreadLocalConsistentIdDepth(); it is incremented
// before and decremented after the wrapped call, so it tracks nesting of Call.
auto* recursive_depth = MutThreadLocalConsistentIdDepth();
++*recursive_depth;
Maybe<void> ret = func(arg0, arg1, outputs, args...);
--*recursive_depth;
// Only when the depth is back to 0 (outermost call) and the wrapped call succeeded
// is InitConsistentId invoked on the outputs; nested calls skip it.
if (*recursive_depth == 0 && ret.IsOk()) { JUST(InitConsistentId(outputs)); }
return ret;
}
};
step 1:创建前文《Global View的相关概念和实现》第三节中讲到的ConsistentTensorMeta信息,存于ConsistentTensorInferResult这个数据结构中
step 2:为output创建相应的EagerConsistentTensorImpl和ConsistentTensor
step 3:根据输入输出Tensor,创建前面图3展示的vm::EagerBlobObject对象,这些对象会在OneFlow的虚拟机中被用到,这中间可能会做boxing的操作,这部分目前不太熟悉,以后熟悉了再单独总结
step 4:进入虚拟机,调度并执行当前的这个op
TensorTuple* outputs, const OpExprInterpContext& ctx) {
// step 1
const auto& infer_args = JUST(ConsistentTensorMetaInferArgs::New(ctx.attrs, inputs));
std::shared_ptr<const ConsistentTensorInferResult> result =
JUST(user_op_expr.mut_consistent_tensor_infer_cache()->GetOrInfer(*infer_args));
const auto& output_tensor_metas = result->output_tensor_metas();
// step 2
for (int i = 0; i < outputs->size(); ++i) {
if (!outputs->at(i)) {
const auto& tensor_impl = JUST(EagerConsistentTensorImpl::New(
output_tensor_metas.at(i), tensor_device, parallel_id, false, false));
outputs->at(i).reset(new ConsistentTensor(tensor_impl));
}
}
// step 3
for (int i = 0; i < inputs.size(); ++i) {
const auto& local_tensor = JUST(input->cur_rank_phy_tensor());
input_eager_blob_objects->at(i) = JUST(local_tensor->eager_blob_object());
}
for (int i = 0; i < outputs->size(); ++i) {
const auto& local_tensor = JUST(outputs->at(i)->cur_rank_phy_tensor());
output_eager_blob_objects->at(i) = JUST(local_tensor->eager_blob_object());
}
// step 4
JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
return builder->LocalCallOpKernel(kernel, input_eager_blob_objects, output_eager_blob_objects,
result, ctx, result->op_device());
}));
return Maybe<void>::Ok();
}
https://github.com/Oneflow-Inc/oneflow