其他
OneFlow学习笔记:从OpExprInterpreter到OpKernel
// Excerpt of OneFlow's Global<T> holder: a set of static functions that create,
// fetch and destroy one object of type T. The enclosing template header that
// declares T, and all method bodies, are elided in this excerpt.
class Global final {
 public:
  // Get the object that was created earlier.
  static T* Get() { ... }
  // Create the object. SetAllocated presumably installs an already-allocated
  // pointer, while New presumably constructs T from `args` (bodies elided).
  static void SetAllocated(T* val) { ... }
  template<typename... Args>
  static T* New(Args&&... args) { ... }
  // Release the object.
  static void Delete() { ... }
  ...
};
def __init__(self):
if not HasAllMultiClientEnvVars():
SetDefaultMultiClientEnvVars()
self._env_cxt = create_env()
...
def create_env():
    """Create the environment context from the module-level ``default_env_proto``.

    Steps:
      1. Assert that ``default_env_proto`` lists at least one machine.
      2. Pass the proto through ``CompleteEnvProto`` (judging by its name, this
         fills in missing fields; confirm against its definition).
      3. When ``ctrl_bootstrap_conf.world_size > 1``, call
         ``check_non_localhost_proxy_and_print_warning``.

    Returns:
        Whatever ``c_api_util.GetEnvContext(default_env_proto)`` returns
        (presumably an ``oneflow._oneflow_internal.EnvContext``; TODO confirm
        against ``c_api_util``).
    """
    global default_env_proto
    assert len(default_env_proto.machine) > 0
    CompleteEnvProto(default_env_proto)
    # The proxy check only runs when more than one process takes part.
    if default_env_proto.ctrl_bootstrap_conf.world_size > 1:
        check_non_localhost_proxy_and_print_warning()
    return c_api_util.GetEnvContext(default_env_proto)
assert type(env_proto) is env_pb2.EnvProto
env_proto_str = text_format.MessageToString(env_proto)
env_ctx = oneflow._oneflow_internal.EnvContext(env_proto_str)
return env_ctx
...
Global<EnvDesc>::New(env_proto);
Global<ProcessCtx>::New();
...
#ifdef WITH_CUDA
Global<EagerNcclCommMgr>::New();
Global<CudnnConvAlgoCache>::New();
Global<embedding::EmbeddingManager>::New();
#endif
Global<vm::VirtualMachineScope>::New(Global<ResourceDesc, ForSession>::Get()->resource());
Global<EagerJobBuildAndInferCtxMgr>::New();
...
return Maybe<void>::Ok();
}
Global<VirtualMachine>::New(resource, GlobalProcessCtx::Rank());
}
namespace { \
struct OF_PP_CAT(CommandT, __LINE__) { \
OF_PP_CAT(CommandT, __LINE__)() { __VA_ARGS__; } \
}; \
OF_PP_CAT(CommandT, __LINE__) OF_PP_CAT(g_command_var, __LINE__); \
}
public:
CpuLocalCallOpKernelInstructionType() = default;
~CpuLocalCallOpKernelInstructionType() override = default;
using stream_type = vm::CpuStreamType;
};
// Register instruction type T under the name `instr_type_name`.
// T must declare a nested `stream_type` alias (for example
// `using stream_type = vm::CpuStreamType;`). The stream type handed to
// RegisterInstrTypeId<T> is the one returned by
// StaticGlobalStreamType<typename T::stream_type>().
// NOTE: the enclosing `template<typename T>` header is not shown in this excerpt.
void RegisterInstructionType(const std::string& instr_type_name) {
  RegisterInstrTypeId<T>(instr_type_name, StaticGlobalStreamType<typename T::stream_type>());
}
const InstructionType* instruction_type) {
InstrTypeId instr_type_id;
instr_type_id.__Init__(stream_type, instruction_type);
CHECK(InstrTypeId4InstructionName()->emplace(instruction_name, instr_type_id).second);
}
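初始化一个 InstrTypeId 对象，并调用其 __Init__ 方法为成员变量 stream_type_ 和 instruction_type_ 赋值。其中 stream_type 是 StaticGlobalStreamType<typename T::stream_type>() 所获取的 StreamType 对象（对 CpuLocalCallOpKernelInstructionType 而言，T::stream_type 即 vm::CpuStreamType）；instruction_type 则是指向 T 对应指令类型对象的指针。
通过 InstrTypeId4InstructionName() 拿到一个静态 HashMap<std::string, InstrTypeId> 对象的指针。
将 instruction_name（"cpu.LocalCallOpKernel"）作为 key、InstrTypeId 对象 instr_type_id 作为 value 插入这个 map。emplace 返回值的 .second 在 key 已存在时为 false，因此外层的 CHECK 保证同一个指令名只能注册一次。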
vm::InstructionMsgList instruction_list;
InstructionsBuilder instructions_builder(std::make_shared<vm::PhysicalIdGenerator>(),
&instruction_list);
JUST(Build(&instructions_builder));
JUST(vm::Run(instructions_builder.mut_instruction_list()));
return Maybe<void>::Ok();
}
return builder->LocalCallOpKernel(kernel, input_eager_blob_objects, output_eager_blob_objects, ctx, op_device);
}
...
auto phy_instr_operand = JUST(vm::LocalCallOpKernelPhyInstrOperand::New(
opkernel, input_eager_blob_objects, output_eager_blob_objects, consistent_tensor_infer_result,
ctx, *one::CurrentDevVmDepObjectConsumeMode()));
auto instruction = intrusive::make_shared<vm::InstructionMsg>(
Global<VirtualMachine>::Get()->mut_vm(), JUST(op_device->local_call_instruction_name()),
parallel_desc_sym, phy_instr_operand);
instruction_list_->EmplaceBack(std::move(instruction));
...
return Maybe<void>::Ok();
}
auto* virtual_machine = JUST(GlobalMaybe<VirtualMachine>());
JUST(virtual_machine->Receive(instr_msg_list));
return Maybe<void>::Ok();
}
intrusive::shared_ptr<InstructionMsg>&& compute_instr_msg) {
InstructionMsgList instr_msg_list;
instr_msg_list.EmplaceBack(std::move(compute_instr_msg));
return Receive(&instr_msg_list);
}
OF_PROFILER_RANGE_PUSH("vm:Receive");
INTRUSIVE_UNSAFE_FOR_EACH_PTR(compute_instr_msg, compute_instr_msg_list) {
OF_PROFILER_RANGE_PUSH(compute_instr_msg->DebugName());
OF_PROFILER_RANGE_POP();
}
bool old_list_empty = mut_pending_msg_list()->MoveFrom(compute_instr_msg_list);
OF_PROFILER_RANGE_POP();
return old_list_empty;
}
// Single-message convenience overload: wrap `compute_instr_msg` in a temporary
// InstructionMsgList and forward it to the Receive(InstructionMsgList*)
// overload, returning that overload's result unchanged.
Maybe<bool> VirtualMachineEngine::Receive(
    intrusive::shared_ptr<InstructionMsg>&& compute_instr_msg) {
  InstructionMsgList instr_msg_list;
  // The caller's shared_ptr is moved into the local list.
  instr_msg_list.EmplaceBack(std::move(compute_instr_msg));
  return Receive(&instr_msg_list);
}
: vm_threads_closed_(false) {
...
std::function<void()> SchedulerInitializer;
GetSchedulerThreadInitializer(&SchedulerInitializer);
schedule_thread_ = std::thread(&VirtualMachine::ScheduleLoop, this, SchedulerInitializer);
}
...
while (pending_notifier_.WaitAndClearNotifiedCnt() == kNotifierStatusSuccess) {
...
do {
...
do {
...
do { vm->Schedule(schedule_ctx); } while (!vm->ThreadUnsafeEmpty());
vm->MoveToGarbageMsgListAndNotifyGC(schedule_ctx);
} while (++i < kNumSchedulingPerTimoutTest);
} while (MicrosecondsFrom(start) < kWorkingMicroseconds);
}
...
}
if (...) { ReleaseFinishedInstructions(); }
if (...) { TryRunBarrierInstruction(); }
if (...) { HandleLocalPending(); }
if (...) { DispatchAndPrescheduleInstructions(); }
}
...
InstructionMsgList pending_instr_msgs;
INTRUSIVE_FOR_EACH_PTR(instr_msg, &pending_instr_msgs) {
MakeInstructions(instr_msg, /*out*/ &new_instruction_list);
}
...
INTRUSIVE_FOR_EACH_PTR(instruction, &new_instruction_list) {
ConsumeMirroredObjects(instruction);
if (likely(Dispatchable(instruction))) {
mut_ready_instruction_list()->PushBack(instruction);
new_instruction_list.Erase(instruction);
}
}
}
ReadyInstructionList tmp_ready_instruction_list;
mut_ready_instruction_list()->MoveTo(&tmp_ready_instruction_list);
INTRUSIVE_FOR_EACH(instruction, &tmp_ready_instruction_list) {
...
DispatchInstruction(instruction.Mutable());
...
}
...
}
const ScheduleCtx& schedule_ctx) {
auto* stream = instruction->mut_stream();
stream->mut_running_instruction_list()->PushBack(instruction);
if (stream->active_stream_hook().empty()) { mut_active_stream_list()->PushBack(stream); }
const auto& stream_type = stream->stream_type();
if (OnSchedulerThread(stream_type)) {
stream_type.Run(instruction);
} else {
stream->mut_thread_ctx()->mut_pending_instruction_list()->PushBack(instruction);
schedule_ctx.OnWorkerLoadPending(stream->mut_thread_ctx());
}
}
...
const auto& instruction_name = JUST(StreamRoleSwitch<GetCallInstructionName>(
stream->stream_role(), stream->device()->enum_type()));
auto instruction = intrusive::make_shared<vm::InstructionMsg>(
Global<VirtualMachine>::Get()->mut_vm(), instruction_name, parallel_desc_sym,
phy_instr_operand);
instruction_list_->EmplaceBack(std::move(instruction));
...
return Maybe<void>::Ok();
}
...
InstrTypeId instr_type_id_;
std::string instr_type_name_;
...
Stream* phy_instr_stream_;
};
const std::shared_ptr<const ParallelDesc>& phy_instr_parallel_desc,
const std::shared_ptr<PhyInstrOperand>& phy_instr_operand) {
__Init__();
if (likely(phy_instr_parallel_desc)) {
int device_id = phy_instr_parallel_desc->parallel_id2device_id().at(0);
vm->GetCachedInstrTypeIdAndPhyInstrStream(instr_type_name, device_id, mut_instr_type_id(),
&phy_instr_stream_);
}
...
}
int device_id,
InstrTypeId* instr_type_id,
Stream** stream) {
auto* cache = &instr_type_name2rt_instr_type_id_;
auto iter = cache->find(instr_type_name);
if (unlikely(iter == cache->end())) {
const auto& instr_type_id_val = LookupInstrTypeId(instr_type_name);
const auto* stream_type = &instr_type_id_val.stream_type();
auto* stream_rt_desc = this->mut_stream_type2stream_rt_desc()->FindPtr(stream_type);
iter = cache->emplace(instr_type_name, RtInstrTypeId(instr_type_id_val, stream_rt_desc)).first;
}
instr_type_id->CopyFrom(iter->second.instr_type_id());
*stream = iter->second.GetStream(device_id);
}
public:
RtInstrTypeId(const RtInstrTypeId&) = default;
RtInstrTypeId(RtInstrTypeId&&) = default;
~RtInstrTypeId() = default;
RtInstrTypeId(const InstrTypeId& instr_type_id, StreamRtDesc* stream_rt_desc)
: instr_type_id_(instr_type_id), stream_rt_desc_(stream_rt_desc) {
if (stream_rt_desc->stream_type().IsControlStreamType()) {
get_stream_ = &StreamRtDesc::GetSoleStream;
} else {
get_stream_ = &StreamRtDesc::GetDeviceStream;
}
}
const InstrTypeId& instr_type_id() const { return instr_type_id_; }
Stream* GetStream(int device_id) const { return (stream_rt_desc_->*get_stream_)(device_id); }
private:
const InstrTypeId instr_type_id_;
StreamRtDesc* stream_rt_desc_;
Stream* (StreamRtDesc::*get_stream_)(int device_id) const;
};
const auto& instr_type_id_val = LookupInstrTypeId(instr_type_name);
const auto* stream_type = &instr_type_id_val.stream_type();
auto* stream_rt_desc = this->mut_stream_type2stream_rt_desc()->FindPtr(stream_type);
iter = cache->emplace(instr_type_name, RtInstrTypeId(instr_type_id_val, stream_rt_desc)).first;
}
...
INTRUSIVE_UNSAFE_FOR_EACH_PTR(stream_desc, &vm_desc.stream_type_id2desc()) {
if (stream_desc->num_threads() == 0) { continue; }
auto stream_rt_desc = intrusive::make_shared<StreamRtDesc>(stream_desc);
mut_stream_type_id2stream_rt_desc()->Insert(stream_rt_desc.Mutable());
...
}
}
...
InstructionMsgList pending_instr_msgs;
INTRUSIVE_FOR_EACH_PTR(instr_msg, &pending_instr_msgs) {
MakeInstructions(instr_msg, /*out*/ &new_instruction_list);
}
...
INTRUSIVE_FOR_EACH_PTR(instruction, &new_instruction_list) {
ConsumeMirroredObjects(instruction);
if (likely(Dispatchable(instruction))) {
mut_ready_instruction_list()->PushBack(instruction);
new_instruction_list.Erase(instruction);
}
}
}
/*out*/ InstructionList* new_instruction_list) {
const auto& instruction_type = instr_msg->instr_type_id().instruction_type();
bool is_barrier_instruction = instruction_type.IsFrontSequential();
Stream* stream = CHECK_NOTNULL(instr_msg->phy_instr_stream());
const auto& pd = instr_msg->phy_instr_parallel_desc();
intrusive::shared_ptr<Instruction> instr = stream->NewInstruction(instr_msg, pd);
LivelyInstructionListPushBack(instr.Mutable());
if (unlikely(is_barrier_instruction)) {
mut_barrier_instruction_list()->PushBack(instr.Mutable());
} else {
new_instruction_list->PushBack(instr.Mutable());
}
}
...
if (OnSchedulerThread(stream_type)) {
stream_type.Run(instruction);
} else {
stream->mut_thread_ctx()->mut_pending_instruction_list()->PushBack(instruction);
schedule_ctx.OnWorkerLoadPending(stream->mut_thread_ctx());
}
...
}
public:
virtual ~StreamType() = default;
void Run(Instruction* instruction) const { Compute(instruction); }
virtual const char* stream_tag() const = 0;
virtual void InitDeviceCtx(std::unique_ptr<DeviceCtx>* device_ctx, Stream* stream) const = 0;
virtual void InitInstructionStatus(const Stream& stream,
InstructionStatusBuffer* status_buffer) const = 0;
virtual void DeleteInstructionStatus(const Stream& stream,
InstructionStatusBuffer* status_buffer) const = 0;
virtual bool QueryInstructionStatusDone(const Stream& stream,
const InstructionStatusBuffer& status_buffer) const = 0;
virtual void Compute(Instruction* instruction) const = 0;
virtual intrusive::shared_ptr<StreamDesc> MakeStreamDesc(const Resource& resource,
int64_t this_machine_id) const = 0;
virtual bool OnSchedulerThread() const = 0;
virtual bool SupportingTransportInstructions() const = 0;
virtual bool IsControlStreamType() const { return false; }
protected:
StreamType() = default;
};
...
{
const auto& instr_type_id = instruction->mut_instr_msg()->instr_type_id();
instr_type_id.instruction_type().Compute(instruction);
}
auto* status_buffer = instruction->mut_status_buffer();
NaiveInstrStatusQuerier::MutCast(status_buffer->mut_buffer()->mut_data())->set_done();
...
}
...
virtual void Compute(Instruction* instruction) const = 0;
virtual void ComputeInFuseMode(InstructionMsg* instr_msg) const { LOG(FATAL) << "UNIMPLEMENTED"; }
...
};
CHECK_JUST(LocalCallOpKernelUtil::Compute(instruction));
}
// Compute step for a LocalCallOpKernel instruction (excerpt). The elided lines
// prepare `operand`, `device_ctx`, `state` and `cache` from `instruction`; the
// user kernel is then run through OpKernelCompute.
static inline Maybe<void> Compute(vm::Instruction* instruction) {
  ...
  OpKernelCompute(operand, device_ctx, state, cache);
  ...
  return Maybe<void>::Ok();
}
...
};
...
// Invoke the user op kernel held by `operand` (operand->user_opkernel()),
// passing the kernel's `state` and `cache`. `compute_ctx` is built in the
// elided lines, presumably from `operand` and `device_ctx` (TODO confirm).
static inline void OpKernelCompute(LocalCallOpKernelPhyInstrOperand* operand,
                                   DeviceCtx* device_ctx, user_op::OpKernelState* state,
                                   const user_op::OpKernelCache* cache) {
  ...
  operand->user_opkernel()->Compute(compute_ctx, state, cache);
  ...
}
};
private:
void Compute(user_op::KernelComputeContext* ctx) const override {
// do computing!
}
};
https://github.com/Oneflow-Inc/oneflow