Backward Recomputation in OneFlow: Trading Time for Space to Drastically Cut GPU Memory Usage
In OneFlow, Checkpointing is implemented mainly through static memory reuse: once a forward Tensor's lifetime ends, other Tensors can reuse that piece of memory, which is how memory gets reused and saved.
This article covers:
1. Usage of sublinear memory optimization
2. Design of sublinear memory optimization
3. Code walkthrough
1. Usage of Sublinear Memory Optimization
# Usage:
with flow.experimental.scope.config(checkpointing=True):
    # your network, such as:
    # input layernorm
    norm1 = layernorm("layernorm_1", h)
    # attention
    h = h + self.attn(norm1)
    # output layernorm
    norm2 = layernorm("layernorm_2", h)
    # mlp
    h = h + self.mlp(norm2)

The scope with checkpointing = True marks the part of the network to be recomputed. As can be seen, enabling checkpointing greatly reduces GPU memory usage when training GPT-2; at batch size = 4, the memory saving exceeds 50%.
2. Design of Sublinear Memory Optimization
In the article series "In-Depth Analysis: Mastering the System Design of the OneFlow Framework (Parts 1, 2, and 3)", we introduced OneFlow's OpNode/OpGraph abstractions and the Actor, SBP, and other system designs built on top of them; it is these well-crafted designs and abstractions that let OneFlow perform so well across a wide range of tasks.
The activation checkpointing pass rewrites the job as follows (a toy sketch of the rewrite follows this list):
1. Generate a fake subgraph and make it, instead of the real subgraph, the input of the backward consumers
2. Add control edges in the fake subgraph from the end op to all source nodes
3. Add the fake subgraph to the job builder (which then manages it)
4. Update all backward consumer ops in the job builder
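To make these steps concrete before reading the real pass, here is a minimal, self-contained toy sketch of the rewrite. This is not OneFlow code; all names, including the Fake- prefix, are made up for illustration. It duplicates a checkpointed forward chain into fake ops, redirects the backward consumer to the fake outputs, and delays the fake source op with a control edge from the last real forward producer:

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

struct ToyOp {
  std::string name;
  std::vector<std::string> inputs;       // names of producer ops
  std::vector<std::string> ctrl_in_ops;  // control-edge dependencies
};

int main() {
  const std::string kFakePrefix = "Fake-";  // illustrative only
  // Forward chain under checkpointing: h -> layernorm_1 -> attn;
  // the backward op attn_grad originally consumes attn's output.
  std::map<std::string, ToyOp> graph = {
      {"layernorm_1", {"layernorm_1", {"h"}, {}}},
      {"attn", {"attn", {"layernorm_1"}, {}}},
      {"attn_grad", {"attn_grad", {"attn"}, {}}},
  };
  const std::set<std::string> ckpt_ops = {"layernorm_1", "attn"};

  // Duplicate the checkpointed subgraph as fake ops, rewiring edges inside it.
  for (const auto& op : ckpt_ops) {
    ToyOp fake = graph.at(op);
    fake.name = kFakePrefix + op;
    for (auto& in : fake.inputs) {
      if (ckpt_ops.count(in)) { in = kFakePrefix + in; }
    }
    graph[fake.name] = fake;
  }
  // Redirect the backward consumer from the real subgraph to the fake subgraph.
  for (auto& in : graph.at("attn_grad").inputs) {
    if (ckpt_ops.count(in)) { in = kFakePrefix + in; }
  }
  // A control edge from the last real forward producer delays the recomputation.
  graph.at(kFakePrefix + "layernorm_1").ctrl_in_ops.push_back("attn");

  for (const auto& kv : graph) {
    std::cout << kv.first << " <- ";
    for (const auto& in : kv.second.inputs) { std::cout << in << " "; }
    std::cout << "\n";
  }
  return 0;
}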
Code implementation:
https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job/checkpointing_config_def.cpp
https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job/job_build_and_infer_ctx.cpp#L989
https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/checkpointing_pass.cpp
3. Code Walkthrough
3.1 Collecting all ops in the forward pass
https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/checkpointing_pass.cpp#L65
// Collect all qualifying op nodes that belong to the forward pass and store them in the HashMap
void CollectAllCheckpointingOpsInForwardPass(
    const OpGraph& op_graph, HashMap<std::string, const OpNode*>* checkpointing_op_name2op_node) {
  // NOTE(chengcheng):
  //   ignore batch_norm ops because of recompute bn will repeat the calculation of 'm' and 'v'.
  //   in the future, we need to support the recomputation version of batch_norm which do NOT
  //   update forward variables.
  HashSet<std::string> ignore_op_type_names = {"normalization", "normalization_add_relu",
                                               "cudnn_fused_normalization_add_relu"};
  op_graph.ForEachNode([&](const OpNode* op_node) {
    const OperatorConf& op_conf = op_node->op().op_conf();
    // Skip op_nodes that have no user_conf, and those whose op type is listed in
    // ignore_op_type_names
    if (!op_conf.has_user_conf()) { return; }
    if (ignore_op_type_names.find(op_conf.user_conf().op_type_name())
        != ignore_op_type_names.end()) {
      return;
    }
    // An op_node whose scope has checkpointing enabled and is marked as ForwardPass is a
    // target node; insert it into the HashMap
    if (IsForwardPass7CheckpointingScope(Scope4OpNode(op_node))) {
      CHECK(checkpointing_op_name2op_node->emplace(op_conf.name(), op_node).second);
    }
  });
}
The IsForwardPass7CheckpointingScope() method is used to filter the qualifying op nodes:

bool IsForwardPassScope(const Scope& scope) {
  // A node whose scope has calculation_pass_name == kForwardPass takes part in the
  // forward computation
  return scope.scope_proto().calculation_pass_name() == kForwardPass;
}

bool IsForwardPass7CheckpointingScope(const Scope& scope) {
  // True if the node's pass is kForwardPass and its scope has checkpointing enabled
  return IsForwardPassScope(scope) && scope.Bool("checkpointing");
}

IsForwardPass7CheckpointingScope() uses the node's scope to decide whether the op node directly takes part in the forward computation (its scope contains kForwardPass) and whether "checkpointing" is enabled; a node that satisfies both is a target node and is inserted into the hashmap (checkpointing_op_name2op_node).
3.2 Collecting all subgraphs of these ops
After the op nodes within the checkpointing scope have been selected, all connected subgraphs must be generated from these nodes. Some of these subgraphs are irrelevant to backward recomputation, while others are the target subgraphs it needs; the outputs of the target subgraphs are consumed as inputs by backward op nodes. These subgraphs are the smallest units of forward recomputation in the activation checkpointing design (a simplified standalone sketch of the grouping logic is given right after the implementation below).
// Generate all subgraphs from the checkpointing ops and store them in a vector
// step 2. get all connected subgraphs in checkpointing ops.
std::vector<HashSet<const OpNode*>> checkpointing_subgraphs;
GenConnectedCheckpointingSubgraphs(checkpointing_op_name2op_node, &checkpointing_subgraphs);
// Generate the connected subgraphs
void GenConnectedCheckpointingSubgraphs(
    const HashMap<std::string, const OpNode*>& checkpointing_op_name2op_node,
    std::vector<HashSet<const OpNode*>>* checkpointing_subgraphs) {
  HashSet<const OpNode*> visited_nodes;
  for (const auto& pair : checkpointing_op_name2op_node) {
    const OpNode* node = pair.second;
    if (visited_nodes.find(node) != visited_nodes.end()) { continue; }
    // new subgraph
    checkpointing_subgraphs->push_back(HashSet<const OpNode*>());
    CHECK(!checkpointing_subgraphs->empty());
    auto& subgraph = checkpointing_subgraphs->back();
    CHECK(subgraph.empty());
    // bfs search all node in checkpointing ops
    CHECK(visited_nodes.insert(node).second);
    std::queue<const OpNode*> queued_nodes;
    queued_nodes.push(node);
    while (!queued_nodes.empty()) {
      const OpNode* cur_node = queued_nodes.front();
      queued_nodes.pop();
      CHECK(subgraph.insert(cur_node).second);
      cur_node->ForEachNodeOnInOutEdge([&](const OpNode* next_node) {
        const std::string& next_op_name = next_node->op().op_name();
        if (checkpointing_op_name2op_node.find(next_op_name) != checkpointing_op_name2op_node.end()
            && cur_node->parallel_desc() == next_node->parallel_desc()
            && visited_nodes.find(next_node) == visited_nodes.end()) {
          queued_nodes.push(next_node);
          CHECK(visited_nodes.insert(next_node).second);
        }
      });
    }
  }
}
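The function above is essentially a breadth-first grouping of the checkpointing ops into connected components, where an edge is only followed when both endpoints are checkpointing ops on the same parallel_desc. For readers unfamiliar with the OpGraph/OpNode types, here is a simplified standalone sketch of the same grouping logic over a toy adjacency list (the parallel_desc check is omitted and all names are illustrative):

#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>

// Partition the `selected` nodes of an undirected graph into connected subgraphs,
// following only edges whose both endpoints are selected.
std::vector<std::set<std::string>> ConnectedSubgraphs(
    const std::map<std::string, std::vector<std::string>>& adj,
    const std::set<std::string>& selected) {
  std::vector<std::set<std::string>> subgraphs;
  std::set<std::string> visited;
  for (const auto& start : selected) {
    if (visited.count(start)) { continue; }
    subgraphs.emplace_back();  // new subgraph
    auto& subgraph = subgraphs.back();
    std::queue<std::string> q;
    q.push(start);
    visited.insert(start);
    while (!q.empty()) {  // BFS over the selected nodes
      std::string cur = q.front();
      q.pop();
      subgraph.insert(cur);
      for (const auto& next : adj.at(cur)) {
        if (selected.count(next) && !visited.count(next)) {
          visited.insert(next);
          q.push(next);
        }
      }
    }
  }
  return subgraphs;
}

int main() {
  // a-b-c form one chain; e is isolated among the selected nodes (d is not selected)
  std::map<std::string, std::vector<std::string>> adj = {
      {"a", {"b"}}, {"b", {"a", "c"}}, {"c", {"b", "d"}}, {"d", {"c", "e"}}, {"e", {"d"}}};
  std::set<std::string> selected = {"a", "b", "c", "e"};
  for (const auto& g : ConnectedSubgraphs(adj, selected)) {
    for (const auto& n : g) { std::cout << n << " "; }
    std::cout << "\n";  // prints "a b c" then "e"
  }
  return 0;
}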
3.3 Generating the fake subgraph
This step covers design steps 1-3 above: generate the fake subgraph and make it, instead of the real subgraph, the input of the backward consumers; add control edges in the fake subgraph from the end op to all source nodes; and add the fake subgraph to the job builder (which then manages it).
https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/checkpointing_pass.cpp#L148-L290
The pass iterates over the subgraphs with for (auto& subgraph : checkpointing_subgraphs). At the very start of each iteration, subgraphs that do not qualify for activation checkpointing (those with no direct edge to a backward op) are skipped:

for (auto& subgraph : checkpointing_subgraphs) {
  // step 3.1 ignore this subgraph if there is no direct edge to backward pass op.
  HashSet<const OpNode*> bw_consumers;
  for (const OpNode* node : subgraph) {
    node->ForEachNodeOnOutEdge([&](const OpNode* out_node) {
      if (!IsForwardPassScope(Scope4OpNode(out_node))) {
        bw_consumers.insert(out_node);
        CHECK(subgraph.find(out_node) == subgraph.end());
      }
    });
  }
  if (bw_consumers.empty()) { continue; }
https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/checkpointing_pass.cpp#L168-L222
const OpNode* first_bw_consumer = nullptr;
int32_t first_bw_order = std::numeric_limits<int32_t>::max();
// Change the backward consumers' inputs from the real subgraph to the fake subgraph ops
// step 3.3 change bw consumers input from subgraph to fake subgraph
for (const OpNode* node : bw_consumers) {
  std::string bw_consumer_name = node->op().op_name();
  OperatorConf bw_consumer_op_conf;
  // NOTE(chengcheng):
  //   reuse bw conumer op conf if it has been existed in map.
  if (total_bw_consumers_op_name2conf.find(bw_consumer_name)
      != total_bw_consumers_op_name2conf.end()) {
    bw_consumer_op_conf = total_bw_consumers_op_name2conf.at(bw_consumer_name);
  } else {
    bw_consumer_op_conf = node->op().op_conf();
  }
  CHECK_EQ(bw_consumer_name, bw_consumer_op_conf.name());
  auto* user_conf = bw_consumer_op_conf.mutable_user_conf();
  // Rewrite the blob names of the backward op inputs that come from the subgraph
  // change input lbns if in subgraph
  for (auto& pair : *(user_conf->mutable_input())) {
    auto& list_s = pair.second;
    for (int i = 0; i < list_s.s_size(); ++i) {
      std::string old_lbn = list_s.s(i);
      LogicalBlobId old_lbi = GenLogicalBlobId(old_lbn);
      std::string old_input_op_name = old_lbi.op_name();
      if (subgraph_op_name2op_node.find(old_input_op_name) != subgraph_op_name2op_node.end()) {
        list_s.set_s(i, kCheckpointingFakeOpNamePrefix + old_lbn);
      }
    }
  }
  // NOTE(chengcheng):
  //   emplace maybe repeated, so do not check the return value
  total_bw_consumers_op_name2conf.emplace(bw_consumer_name, bw_consumer_op_conf);
  CHECK(op_node2order.find(node) != op_node2order.end());
  int32_t this_order = op_node2order.at(node);
  if (this_order < first_bw_order) {
    first_bw_consumer = node;
    first_bw_order = this_order;
  }
}
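The key detail in the loop above is the lbn (logical blob name) rewrite: an lbn has the form "op_name/blob_name", so prepending the fake-op prefix to the whole lbn redirects the consumer to the output of the duplicated fake op, which is presumably named with the same prefix plus the original op name. A tiny illustration, with a made-up prefix value and blob name standing in for the real kCheckpointingFakeOpNamePrefix:

#include <iostream>
#include <string>

int main() {
  // Illustrative value only; the real constant is kCheckpointingFakeOpNamePrefix.
  const std::string kFakePrefix = "Fake-Fw-Op_";
  // An lbn is "producer_op_name/blob_name"; prefixing it points the consumer at the fake op.
  const std::string old_lbn = "layernorm_1/out_0";
  const std::string new_lbn = kFakePrefix + old_lbn;
  std::cout << old_lbn << " -> " << new_lbn << "\n";  // layernorm_1/out_0 -> Fake-Fw-Op_layernorm_1/out_0
  return 0;
}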
Adding control edges from the end op to all source nodes in the fake subgraph. The end op is the in-edge producer of the first backward consumer with the largest topological order; the control edges from it to the fake subgraph's source nodes delay the recomputation until just before the backward pass needs its outputs, so the recomputed activations do not occupy memory during the forward pass.
https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/checkpointing_pass.cpp#L267-L284
// step 3.4 add control edge from End Op to all source node in fake subgraph
CHECK(first_bw_consumer != nullptr);
std::string end_op_name = kCheckpointingBadOpName;
int32_t end_order = -1;
first_bw_consumer->ForEachNodeOnInEdge([&](const OpNode* end_node) {
  CHECK(op_node2order.find(end_node) != op_node2order.end());
  int32_t this_order = op_node2order.at(end_node);
  if (this_order > end_order) {
    end_order = this_order;
    end_op_name = end_node->op().op_name();
  }
});
CHECK_NE(end_order, -1);
CHECK_NE(end_op_name, kCheckpointingBadOpName);
CHECK_LT(end_order, first_bw_order);
for (const auto& source_op_name : source_node_in_fake_subgraph) {
  fake_op_name2conf.at(source_op_name).add_ctrl_in_op_name(end_op_name);
}
Adding the fake subgraph to the job builder (which then manages it)
// Hand the ops of the fake subgraph over to the job_builder for management (graph rewriting)
// step 3.5 add fake subgraph ops to job builder
std::vector<OperatorConf> fake_op_confs;
for (auto& pair : fake_op_name2conf) { fake_op_confs.push_back(pair.second); }
job_builder->AddOps(parallel_conf, fake_op_confs);
3.4 Updating all backward consumer ops
// Update all backward consumer ops in the job builder
// step 4. update bw consumers in job builder only once
std::vector<OperatorConf> total_bw_consumer_op_confs;
for (auto& pair : total_bw_consumers_op_name2conf) {
  total_bw_consumer_op_confs.push_back(pair.second);
}
job_builder->MutOpsOnlyOnce(total_bw_consumer_op_confs);
return Maybe<void>::Ok();
OneFlow GPT example: https://github.com/Oneflow-Inc/OneFlow-Benchmark/tree/master/LanguageModeling/GPT