OneFlow Source Code Analysis: Ops, Kernels, and Interpreters
1 Op and Kernel Registration
1.1 Registering ReluOp
- Class definition: build/oneflow/core/framework/op_generated.h
- Op registration and part of the op implementation: build/oneflow/core/framework/op_generated.cpp
- Main implementation: oneflow/oneflow/user/ops/relu_op.cpp

The generated op_generated.cpp registers "relu" through a file-scope trigger object:
static UserOpRegisterTrigger<OpRegistry> g_register_trigger715 =
    ::oneflow::user_op::UserOpRegistryMgr::Get()
        .CheckAndGetOpRegistry("relu")
        .Input("x")
        .Output("y")
        .SetGetSbpFn(&ReluOp::GetSbp)
        .SetLogicalTensorDescInferFn(&ReluOp::InferLogicalTensorDesc)
        .SetPhysicalTensorDescInferFn(&ReluOp::InferPhysicalTensorDesc)
        .SetDataTypeInferFn(&ReluOp::InferDataType);
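The op is thus registered as a side effect of static initialization: the trigger object's constructor runs before main() and inserts the entry into a global registry. A minimal, self-contained sketch of this pattern, with simplified, hypothetical names (the real OpRegistry also records the inputs, outputs, and infer functions shown above):

#include <map>
#include <string>

// Minimal sketch of the static-trigger registration pattern (hypothetical
// names). A file-scope trigger object registers an entry before main().
struct OpMeta {
  std::string name;
};

class MiniRegistry {
 public:
  static MiniRegistry& Get() {
    static MiniRegistry instance;  // constructed once, on first use
    return instance;
  }
  void Register(const std::string& name, const OpMeta& meta) { ops_[name] = meta; }

 private:
  std::map<std::string, OpMeta> ops_;
};

struct MiniRegisterTrigger {
  MiniRegisterTrigger(const std::string& name, const OpMeta& meta) {
    MiniRegistry::Get().Register(name, meta);  // runs during static initialization
  }
};

// Each REGISTER_USER_OP expansion boils down to one such file-scope object:
static MiniRegisterTrigger g_register_relu("relu", OpMeta{"relu"});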
1.2 Registering ReluKernel
static UserOpRegisterTrigger<OpKernelRegistry> g_register_trigger0 =
    UserOpRegistryMgr::Get()
        .CheckAndGetOpKernelRegistry("relu")
        .SetCreateFn(...)
        .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kRelu, "y", "x"))
        .SetInplaceProposalFn([](const user_op::InferContext&,
                                 const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {
          OF_RETURN_IF_ERROR(AddInplaceArgPairFn("y", 0, "x", 0, true));
          return Maybe<void>::Ok();
        });
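SetCreateFn stores a factory for the kernel, while SetIsMatchedHob stores a predicate (a "higher-order bool" expression) evaluated against the runtime context; several kernels may be registered under the same op name, and dispatch picks the one whose predicate holds. A rough, hypothetical sketch of that lookup:

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical simplification: per op name, keep (predicate, factory) pairs;
// at dispatch time, create the first kernel whose predicate matches the context.
struct MiniKernelContext {
  std::string device_type;
  std::string dtype;
};

struct MiniKernelReg {
  std::function<bool(const MiniKernelContext&)> is_matched;  // stand-in for the hob
  std::function<std::shared_ptr<void>()> create;             // stand-in for SetCreateFn
};

std::map<std::string, std::vector<MiniKernelReg>> g_kernel_registry;

std::shared_ptr<void> CreateKernel(const std::string& op_name, const MiniKernelContext& ctx) {
  for (const auto& reg : g_kernel_registry[op_name]) {
    if (reg.is_matched(ctx)) { return reg.create(); }
  }
  return nullptr;  // no registered kernel matches this context
}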
The argument elided in SetCreateFn(...) above is a factory lambda that creates a UnaryPrimitiveKernel, which in turn captures a primitive-factory function:

[]() {
  return user_op::NewOpKernel<UnaryPrimitiveKernel>(
      "y", "x", [](user_op::KernelComputeContext* ctx) {
        const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("x", 0);
        const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("y", 0);
        return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(
            ctx->device_type(), ep::primitive::UnaryOp::kRelu, src->data_type(),
            dst->data_type());
      });
}
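UnaryPrimitiveKernel stores this factory; its Compute only has to create the device-specific primitive and launch it on the kernel's stream. A hedged sketch of that step, with the Launch signature as I read it in v0.8.x (the UnaryPrimitiveFn alias and exact accessor names here are illustrative, not authoritative):

// Hedged sketch of the kernel's Compute: resolve the tensors, then launch the
// elementwise-unary primitive on the kernel's stream. Names approximate
// oneflow v0.8.x; UnaryPrimitiveFn stands in for the stored factory type.
void Compute(user_op::KernelComputeContext* ctx, const UnaryPrimitiveFn& primitive_fn) {
  const user_op::Tensor* src = ctx->Tensor4ArgNameAndIndex("x", 0);
  user_op::Tensor* dst = ctx->Tensor4ArgNameAndIndex("y", 0);
  auto primitive = primitive_fn(ctx);  // the factory lambda registered above
  // ElementwiseUnary::Launch(stream, src_ptr, dst_ptr, element_count)
  primitive->Launch(ctx->stream(), src->dptr(), dst->mut_dptr(),
                    src->shape_view().elem_cnt());
}

This primitive layer (ep::primitive) is what lets one kernel registration serve every device: the factory asks for an ElementwiseUnary implementation by device type at runtime instead of hard-coding CPU or CUDA code in the kernel.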
1.3 Class relationships in Op and Kernel registration
2 The Functor

One layer up, the Python-facing relu is implemented by a Functor. ReluFunctor builds a UserOpExpr for "relu" once in its constructor, then dispatches it together with the inputs to an interpreter:

class ReluFunctor {
 public:
  ReluFunctor() { op_ = CHECK_JUST(one::OpBuilder("relu").Input("x", 1).Output("y", 1).Build()); }
  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, bool inplace) const {
    // inplace handling omitted
    return OpInterpUtil::Dispatch<Tensor>(*op_, {x});
  }

 private:
  std::shared_ptr<OpExpr> op_;
};
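The functor is exposed to Python through the functional library; in v0.8.x the binding in oneflow/core/functional/impl/activation_functor.cpp looks roughly like this (abridged):

ONEFLOW_FUNCTION_LIBRARY(m) {
  m.add_functor<impl::ReluFunctor>("Relu");  // reached from flow.relu / flow.nn.functional.relu
  // ... other activation functors elided ...
};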
[Diagram: UserOpExpr fields include base_attrs_, tensor_desc_infer_fn_, dtype_infer_fn_, device_and_stream_infer_fn_.]
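These are members of UserOpExpr (oneflow/core/framework/op_expr.h); a trimmed sketch keeping only the fields named above, with types abbreviated from my reading of v0.8.1:

// Trimmed from oneflow/core/framework/op_expr.h (v0.8.1); most members and
// methods elided. The infer functions registered in section 1.1 end up here.
class UserOpExpr final : public BuiltinOpExprImpl<UserOpConf> {
  // ...
 private:
  AttrMap base_attrs_;                                          // attrs captured at build time
  user_op::TensorDescInferFn tensor_desc_infer_fn_;             // shape/stride inference
  user_op::DataTypeInferFn dtype_infer_fn_;                     // dtype inference
  user_op::DeviceAndStreamInferFn device_and_stream_infer_fn_;  // device/stream inference
};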
3 Executing the Functor
3.1 Choosing an interpreter based on environment and inputs
Depending on the execution mode and the inputs, one of three interpreters handles the call (selection logic sketched below):

- LazyInterpreter: for lazy mode, the distributed static-graph execution mode
- EagerLocalInterpreter: for eager local mode, single-device local execution (aligned with PyTorch's single-device and DDP behavior)
- EagerGlobalInterpreter: for eager global mode, the distributed dynamic-graph execution mode
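A hedged sketch of the selection, simplified from op_interpreter_util.cpp (the real code also caches interpreters and wraps them for autograd; the Build* helper names here are illustrative):

// Simplified, illustrative selection logic: lazy mode wins first; otherwise
// global vs. local is decided by whether all inputs are global tensors.
Maybe<AutogradInterpreter> GetInterpreter(const TensorTuple& inputs,
                                          const OpExprInterpContext& ctx) {
  if (LazyMode::is_enabled()) { return BuildLazyInterpreter(); }  // static graph
  bool all_inputs_global = !inputs.empty();
  for (const auto& tensor : inputs) {
    if (!tensor->is_global()) { all_inputs_global = false; break; }
  }
  if (all_inputs_global) { return BuildEagerGlobalInterpreter(); }  // eager global
  return BuildEagerLocalInterpreter();                              // eager local
}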
3.2 Apply
Maybe<void> EagerInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs,
                                    TensorTuple* outputs, const OpExprInterpContext& ctx) const {
#define APPLY_IF(op_type)                                              \
  if (const auto* op = dynamic_cast<const op_type##Expr*>(&op_expr)) { \
    return ApplyImpl(*op, inputs, outputs, ctx);                       \
  }

  APPLY_IF(UserOp);
  APPLY_IF(VariableOp);
  APPLY_IF(CastToLocalOp);
  APPLY_IF(CastFromLocalOp);
  APPLY_IF(GlobalToGlobalOp);
  APPLY_IF(CastToGlobalOp);
  APPLY_IF(CastFromGlobalOp);
  APPLY_IF(DistributeSplitOp);
  APPLY_IF(DistributeCloneOp);
  APPLY_IF(DistributeConcatOp);
  APPLY_IF(DistributeAddOp);
  APPLY_IF(FunctionOp);
  APPLY_IF(SelectTopNOp)
#undef APPLY_IF

  OF_UNIMPLEMENTED() << "The type " << op_expr.op_type_name()
                     << " has not been supported in EagerInterpreter::Apply.";
}
For a user op, APPLY_IF(UserOp) expands to:

if (const auto* op = dynamic_cast<const UserOpExpr*>(&op_expr)) {
  return ApplyImpl(*op, inputs, outputs, ctx);
}

In the eager local interpreter, ApplyImpl simply forwards to NaiveInterpret:
Maybe<void> EagerLocalInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTuple& inputs,
TensorTuple* outputs,
const OpExprInterpContext& ctx) const {
return NaiveInterpret(op_expr, inputs, outputs, ctx);
}
3.3 NaiveInterpret
NaiveInterpret does the actual eager dispatch:

- checks that the input tensors reside on a consistent device,
- creates the output tensors,
- infers and checks shape/stride/dtype for the output tensors,
- builds the op-execution instruction and dispatches it to the VM.

The abridged implementation:
Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
                           const Symbol<Device>& default_device, TensorTuple* outputs,
                           const OpExprInterpContext& ctx) {
  const auto& attrs = ctx.attrs;
  // Check that the input tensors reside on the same device
  ...
  // Infer the output tensors' device and stream
  if (!user_op_expr.has_device_and_stream_infer_fn()) {
    stream = JUST(GetDefaultStreamByDevice(default_device));
    for (int i = 0; i < outputs->size(); i++) {
      auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));
      *JUST(tensor_impl->mut_device()) = default_device;
    }
  } else {
    need_check_mem_case = false;
    stream = JUST(user_op_expr.InferDeviceAndStream(attrs, inputs, outputs));
  }

  // Infer the output tensors' shapes and data types
  const auto& device_tag = stream->device()->type();
  JUST(user_op_expr.InferPhysicalTensorDesc(
      attrs, device_tag,
      [&](int32_t i) -> const TensorMeta* {
        return CHECK_JUST(TensorImpl4Tensor(inputs[i]))->mut_tensor_meta();
      },
      [&](int32_t i) -> TensorMeta* {
        // use a thread_local TensorMeta pointer if inplace,
        // the tensor_impl TensorMeta pointer otherwise
        return output_tensor_metas->at(i);
      }));

  // Initialize an eager_blob_object for each output tensor
  for (int i = 0; i < output_eager_blob_objects->size(); i++) {
    auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));
    if (!output_eager_blob_objects->at(i)) {
      if (!JUST(user_op_expr.SupportNonContiguous())) {
        std::shared_ptr<Stride> stride(new Stride(*tensor_impl->shape()));
        tensor_impl->mut_tensor_meta()->set_stride(stride);
      }
      const auto& dep_object = NewLocalDepObject();
      JUST(tensor_impl->InitEagerBlobObject(dep_object));
      output_eager_blob_objects->at(i) = JUST(tensor_impl->eager_blob_object());
    } else {
      // output i is inplaced:
      // check the thread_local TensorMeta against the tensor_impl TensorMeta
      CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape() == output_tensor_metas->at(i)->shape());
      CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype() == output_tensor_metas->at(i)->dtype());
    }
  }

  // Fetch the (cached) kernel for this stream from user_op_expr
  const auto& kernel = JUST(user_op_expr.MutKernel4Stream(stream));
  kernel->set_need_check_mem_case(need_check_mem_case);

  for (int64_t index : kernel->output_tuple_indexes4mut2_obns()) {
    output_eager_blob_objects->at(index)->set_is_shape_synced(false);
  }

  // Dispatch the kernel to the VM; actual execution is scheduled later
  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
    return builder->Call(kernel, input_eager_blob_objects, output_eager_blob_objects, ctx, stream);
  }));
  return Maybe<void>::Ok();
}
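builder->Call packs the kernel and its input/output eager blob objects into an instruction; PhysicalRun then hands the instruction list to the virtual machine, which schedules the actual execution asynchronously. Roughly, as I read oneflow/core/framework/instructions_builder.* in v0.8.1 (simplified):

// Simplified sketch of PhysicalRun: build an instruction list, then enqueue
// it to the VM. This returns once the instructions are queued, not when the
// kernel has run; synchronization happens later (e.g. when tensor data is
// actually accessed).
Maybe<void> PhysicalRun(const std::function<Maybe<void>(InstructionsBuilder*)>& Build) {
  vm::InstructionMsgList instruction_list;
  InstructionsBuilder instructions_builder(&instruction_list);
  JUST(Build(&instructions_builder));
  JUST(vm::Run(&instruction_list));
  return Maybe<void>::Ok();
}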
References
- OneFlow学习笔记:Op注册 (OneFlow study notes: Op registration), https://mp.weixin.qq.com/s/eF-c2irraxnH4iAesURy0Q
- 从Functor到OpExprInterpreter (From Functor to OpExprInterpreter)
- OneFlow v0.8.1 source, https://github.com/Oneflow-Inc/oneflow/tree/v0.8.1
- https://zhuanlan.zhihu.com/p/523884650