查看原文
其他

PyTorch 源码解读之流水线并行

OpenMMLab OpenMMLab 2024-04-23


随着 Transformer 架构模型的广泛应用,大语言模型的参数量也是水涨船高,像是 GPT-3 就已经达到了惊人的 175B 参数量,GPT-4 更是被曝有 1800B 的参数量。为了训练这样的大模型,并且尽可能提高 GPU 的利用率,流水线并行(Pipeline Parallelism, PP)的训练策略应运而生。PyTorch 也实现了一套流水线并行的解决方法。本文将介绍 torch.distributed.pipeline.sync 的实现细节。


相关代码位于 https://github.com/pytorch/pytorch/tree/v2.1.0-rc6/torch/distributed/pipeline/sync



1 流水线并行介绍



如图所示为谷歌提出的流水线并行算法,名为 GPipe,论文位于 https://arxiv.org/abs/1811.06965。首先将模型切分为连续的多个 stage,每个 stage 占据一台设备,从而利用多台设备容纳下单设备无法容纳的模型。其次,GPipe 将 mini-batch 切分为多个 micro-batch,每次只处理一个 micro-batch。在处理完当个 micro-batch 后,该 micro-batch 的结果将会被发送给下一台设备,同时开始处理下一个 micro-batch。


可能这一段话听起来有些烧脑,让我们举个栗子。


假如我们有四张显卡,所以我们将模型按照顺序切分为了  ,  ,  ,  四个 stage。然后我们将一个 mini-batch 的数据分为 4 个 micro-batch。在   时刻,  完成第一个 micro-batch 的前向过程,记为  。随后在 

  时刻,  接收到  的结果,从而完成第一个 micro-batch 第二阶段的前向过程,即  。同时,  完成第二个 micro-batch 的前向过程,为   。类似地,  时刻则有  ,  ,  完成。以此类推,便完成了模型的前向过程。


而反向过程与前向过程也很类似,只不过执行顺序与前向是完全相反的。在所有的 micro-batch 都完成后,模型将会根据所有 micro-batch 的梯度信息进行更新,然后等待下一个 mini-batch 的到来。但是可以注意到在前向过程与后向过程中有一段设备空闲期,该段空闲期被称为 Bubble。虽然它影响到了整体的性能,但它是该算法不可避免的。


该算法有着易理解、易实现等优点,目前 PyTorch 也是实现了这一算法。相关 API 位于 torch.distributed.pipeline.sync,我们接下来就简单介绍它的使用方法以及实现细节。



2 使用方法


在介绍算法的实现细节之前,我们先来介绍该算法的使用方法。对于使用方法,PyTorch 也是做了比较详细的介绍。同时该方法的使用也是可以开箱即用的(但是对模型做出了强要求,即必须继承于 nn.Sequential)。模型经过 Pipe 类包装后,就可以像一般的 PyTorch 模型一样训练和测试啦。


model = Pipe(torch.nn.Sequential(*module_list), chunks=chunks)model.train()result = model(data)loss = loss_func(result, target)loss.backward()


感兴趣的小伙伴可以看这里哦,https://pytorch.org/tutorials/intermediate/pipeline_tutorial.html



3 实现细节


3.1 流水线初始化

流水线初始化的相关代码位于 torch.distributed.pipeline.sync.Pipe。由于在进行流水线并行计算操作前,需要将模型切分并将相应权重分发到指定设备,因此 Pipe 类初始化过程的主要功能则是检查模型切分策略是否合法,在确认合法后将模型切分,并将模型封装为 torch.distributed.pipeline.sync.Pipeline 类,以便实现并行。


常见的模型切分策略主要分为两种,基于运行时间的切分策略与基于模块大小的切分策略。前者会使得模型的各个 stage 的运行时间(包括前向与后向)尽可能平衡,而后者则是保证模型的各个 stage 所消耗的显存大小尽可能一致。


Pipe 类的初始化过程的主要任务便是检查切分策略,并封装模型,而 Pipeline 类的初始化过程则是对传入参数的注册。因此,流水线初始化过程便不做赘述了。


3.2 流水线执行时

在流水线执行过程中,我们首先会调用到 Pipe 类的 forward 方法。


def forward(self, *inputs) -> RRef:    first_partition_device = self.devices[0] if len(self.devices) != 0 else torch.device("cpu")    microbatch.check(first_partition_device, *inputs)
   if not self.devices:        # Empty sequential module is not illegal.        return RRef(*inputs)
   # Divide a mini-batch into micro-batches.    batches = microbatch.scatter(*inputs, chunks=self.chunks)
   # Run pipeline parallelism.    self.pipeline.run(batches)
   # Merge the micro-batches into one mini-batch.    output = microbatch.gather(batches)    return RRef(output)


从源码中不难看出,其主要负责了 micro-batch 的切分,调用 Pipeline 类的 run 方法来执行流水线并行,将输出结果收集并整理,这样的几个任务。返回的类型 RRef 则是 PyTorch 用于分布式训练的“分布式共享指针”,其全拼为 Remote REFerence,感兴趣的小伙伴可以在 https://pytorch.org/docs/stable/rpc/rref.html 获得更多对于 RRef 的介绍。


所以流水线并行的核心内容还在 Pipeline 类的 run 方法里面。那接下来一起来看 Pipeline 类的 run 方法,快要到达流水线最核心的部分了。


def run(self, batches: List[Batch]) -> None:    """Runs pipeline parallelism.
   It modifies the given batches in place.
   """    partitions = self.partitions    devices = self.devices    skip_layout = self.skip_layout
   m = len(batches)    n = len(partitions)
   skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches]
   for schedule in _clock_cycles(m, n):        self.fence(batches, schedule, skip_trackers)        self.compute(batches, schedule, skip_trackers)


看完 run 方法的源码,有些小伙伴一定会有疑惑,什么是 skip_trackers,schedule 和 clock_cycles 又是什么,fence 和 compute 又在干什么?


简单来讲,skip_trackers 是一些传送门,它们高效地解决了跳跃连接情况的跨设备问题;schedule 指示了某时刻流水线并行应该做什么;clock_cycles 则是负责产生 schedule;而 fence 作为执行过程中的”篱笆“,充当了计算步骤间”中场休息“的作用,让算法可以做好相应的准备以处理下一次计算;而 compute 则是流水线并行算法中最核心的一部分,它完成了模型的前向与后向过程中的全部计算任务。


接下来,让我们一步一步来解释这些概念。


3.3 Skip:张量传送门

考虑到模型中可能存在跳跃连接的问题,PyTorch 为了进一步提升算法性能,设计了一套张量传送门(tensor portal)来解决跳跃连接中的张量传递问题。



图片来自 torchgpipe: On-the-fly Pipeline Parallelism for Training Giant Models。左图为无张量传送门情况下的解决办法,右图则为存在 portal 情况下的跳跃连接解决方法。


我们假设有一个张量,在模型的第一个 stage (位于 Device 0)中产生,而其跳跃连接的目标却是模型的 stage 3(位于 Device 2)。在无张量传送门的情况下,该张量需要先从 Device 0 搬移到 Device 1,再搬移到 Device 2。这一过程中产生了两次通信。然而在有张量传送门的情况下,我们只需要将该张量在 Device 0 丢进传送门,它就会在 Device 2 出现,从而减少了一次通信过程。


由于这一部分并不涉及流水线并行的核心部分,故不再进一步展开介绍。感兴趣的小伙伴可以前往 https://github.com/pytorch/pytorch/tree/v2.1.0-rc6/torch/distributed/pipeline/sync/skip


3.4 Schedule 和 clock

在我们最开始介绍什么是流水线并行的时候,我们提到过,在某一个时刻需要完成某一些前向或是后向任务。这一过程所需的时钟信号的产生与任务调度便是由 clock_cycle 产生的 schedule 来控制的。


def _clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]:    """Generates schedules for each clock cycle."""    # m: number of micro-batches    # n: number of partitions    # i: index of micro-batch    # j: index of partition    # k: clock number    #    # k (i,j) (i,j) (i,j)    # - ----- ----- -----    # 0 (0,0)    # 1 (1,0) (0,1)    # 2 (2,0) (1,1) (0,2)    # 3       (2,1) (1,2)    # 4             (2,2)    for k in range(m + n - 1):        yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]


执行这一部分代码,我们就可以看到该函数第一次输出了[(0,0)],对应  时刻需要完成  任务。类似的,第二次调用该函数,便会得到输出[(1,0), (0,1)],即第一个 stage 完成第二个 micro-batch,第二个 stage 完成第一个 micro-batch。这与  时刻的  ,  (注意角标是相反顺序的)任务也是恰好对应的。同时 schedule 的每一项(诸如  )的第一项指示了micro-batch id。,第二项指示了 stage id。

因此该函数主要用来指示流水线并行过程中,每一时刻所需要执行并完成的任务,就好像是流水线并行的”大脑“一样。


3.5 Fence

Fence 函数作为计算步骤间的“篱笆”,是计算步骤间的“中场休息”。Fence 阶段会让算法做完相关的数据搬移过程,以处理下一次计算。


def fence(    self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],) -> None:    """Copies micro-batches after computation for the previous    micro-batches.    """    copy_streams = self.copy_streams    skip_layout = self.skip_layout        # i 表示 micro-batch id, j 表示 stage id    for i, j in schedule:        # Ensure that batches[i-1] is executed after batches[i] in        # backpropagation by an explicit dependency.        if i != 0 and j != 0:            # 在 stage j 处理 micro-batch i 之前,需要            # 先等待 stage j 完全完成 micro-batch i-1            # 的前向与后向过程完全完成            _depend(batches[i - 1], batches[i])
       next_stream = copy_streams[j][i]
       for prev_j, ns, name in skip_layout.copy_policy(j):            prev_stream = copy_streams[prev_j][i]            skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name)
       if j != 0:            # 获取上一个 stage 所用的数据拷贝流            prev_stream = copy_streams[j - 1][i]            # 将数据从上一个拷贝流搬移到下一个拷贝流            _copy(batches[i], prev_stream, next_stream)


接下来我们逐行解释源码。


首先根据 schedule 的每一项,如果当前要执行的 micro-batch id 不是 0 并且 stage id 不是 0,那么在进行当前 micro-batch 的计算前,先等待上一个 micro-batch 的计算(包括前向与后向)完全完成。因为上一个 micro-batch 在上一时刻也是由 stage j 负责计算的,故在 stage j 承担新的计算任务前需要先等待上一个任务的完成。


接着我们获得 micro-batch i, stage j 所对应的数据拷贝流(是一种 cuda stream,只用来承担将数据移动到指定设备的任务),并根据 skip_trackers 的信息,完成依赖 portal 的数据传递。


由于 micro-batch i 刚刚被 stage j-1 完成计算,因此需要获得 micro-batch i,stage j-1 所对应的数据拷贝流,并将 micro-batch i 的数据由 stage j-1 流所对应的设备转移到 stage j 流所对应的设备上。


我们举个例子,当 i = 2,j= 1 的时候,根据指示,现在需要让模型的 stage 1 完成 micro-batch 2 的计算任务。首先,我们需要等待 stage 1 完成 micro-batch 1 的计算任务,才能开始计算 micro-batch 2。在等待结束后,由于 micro-batch 2 刚刚被 stage 0 计算过,其数据还在 device 0,因此我们需要利用数据拷贝流将 micro-batch 2 移动到 stage 1 所对应的 device 1。


这样一来,在 fence 阶段我们完成了将对应的数据转移到了对应的设备上的任务,从而在接下来的计算任务中不会出现设备不统一的错误。


3.6 Compute

接下来就是 Compute 函数,也就是流水线并行中最核心的部分了。这一部分完成了模型的前向与后向计算任务。


def compute( self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],) -> None: """Runs tasks with synchronization to copy streams.""" partitions = self.partitions devices = self.devices copy_streams = self.copy_streams checkpoint_stop = self.checkpoint_stop
# Disable checkpointing if in eval mode. if not self.partitions[0].training: checkpoint_stop = 0
n = len(partitions) # 为每一台设备创建执行时所用流 streams = [current_stream(d) for d in devices] exc_info: Optional[ExcInfo] = None # i 表示 micro-batch id, j 表示 stage id for i, j in schedule: batch = batches[i] partition = partitions[j]
# Synchronize with the copied input. ([1] in the diagram) # 等待数据拷贝流把数据移动到指定设备 if j != 0: _wait(batch, copy_streams[j][i], streams[j])
# Determine whether checkpointing or not. checkpoint = i < checkpoint_stop # 根据检查点启用情况创建任务 if checkpoint:
def function( *inputs, partition: nn.Module = partition, skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], chunk_id: int = i, part_id: int = j,) -> TensorOrTensors: with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): return partition(*inputs)
chk = Checkpointing(function, batch) # type: ignore[arg-type] # 启用检查点情况下,finalize 函数即为重计算过程 task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute) del function, chk
else:
def compute( batch: Batch = batch, partition: nn.Module = partition, skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], chunk_id: int = i, part_id: int = j,) -> Batch: with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): return batch.call(partition) # 不启用检查点情况下,无需重计算 task = Task(streams[j], compute=compute, finalize=None) del compute
# Compute tasks in parallel. ([2] in the diagram) # 向 in_queue 中添加任务 self.in_queues[j].put(task)


首先,为每一台可用的设备创建执行流(也是 cuda stream,但是用于模型计算过程)。


接着,检测模型是否处于训练状态。如果模型不是训练状态,那么完全没必要启用检查点技术(梯度检查点,在反向传播过程中,使用重计算的方法来重新获得激活值与梯度信息,使得模型在前向过程中不需要保存激活值,从而减少显存开销,MMEngine 现已支持检查点技术,相关 PR 位于https://github.com/open-mmlab/mmengine/pull/1319)。


接下来,首先让执行流等待 fence 阶段的数据拷贝流完成数据的搬移工作,然后进入数据发送阶段。在数据发送阶段,算法会根据 schedule 的指示,计算出该当前模型 stage 是否启用检查点,并相应创建任务,放进与 stage id 对应的 in_queue 中。小伙伴们这时候可能又会有疑惑了,把任务放进 in_queue 中,那么谁来执行任务呢?所以这里我们给出相应的源码。


def worker(in_queue: InQueue, out_queue: OutQueue, device: torch.device) -> None:    """The main loop of a worker thread."""    with use_device(device):        while True:            # 获取任务            task = in_queue.get()
           if task is None:                break
           try:                # 执行任务                batch = task.compute()            except Exception:                # 获取报错信息                exc_info = cast(ExcInfo, sys.exc_info())                out_queue.put((False, exc_info))                continue            # 向 out_queue 中放置执行结果            out_queue.put((True, (task, batch)))
   done = (False, None)    out_queue.put(done)


这段源代码便展示了 PyTorch 流水线并行中的执行单元 worker。主线程首先为每一台用于计算的设备创建一个 in_queue 和一个 out_queue,并创建一个子线程,将 worker,in_queue 与 out_queue 绑定在一起。worker 则在子线程中进行,持续在 in_queue 中获取任务,执行后将结果放进 out_queue,我们从而可以通过向 in_queue 中添加任务,获取 out_queue 中输出的方法来执行任务。由于 worker 在子线程中进行,所以这一执行过程是并发的。


看完了 in_queue 与 worker,我们回到 compute 函数中来。


   for i, j in schedule:        # 获取执行状态与结果        ok, payload = self.out_queues[j].get()
       # Hold the first exception.        if exc_info is not None:            continue        elif not ok:            exc_info = cast(ExcInfo, payload)            continue        # 在任务成功执行结束后,拿到任务(准备可能存在的重计算)        # 以及计算结果,即 batch        task, batch = cast(Tuple[Task, Batch], payload)
       # The copy stream synchronizes to copy the output. ([3] in the        # diagram)        if j != n - 1:            # 非最后一个 stage 情况下,数据拷贝流需要等待执行流结束            # 以准备进入下一个 fence 过程            _wait(batch, streams[j], copy_streams[j][i])
       # Finalize tasks. If checkpointing is enabled, here the        # recomputation is scheduled at backpropagation. ([4] in the        # diagram)        with use_device(devices[j]):            task.finalize(batch)
       batches[i] = batch
   # Fail at the first exception.    if exc_info is not None:        raise exc_info[0].with_traceback(exc_info[1], exc_info[2])


在经过了数据发送阶段后,进入数据接收阶段。在这一阶段中,算法根据 schedule 指示,对相应的 out_queue 中的结果进行收集。如果有报错信息,那么先收集,等循环跳出后报错。如果不是最后一阶段,那么需要让数据拷贝流等待执行流上的任务结束。等到任务执行完成后,调用 finalize 函数彻底完成任务。finalize 函数在启用检查点技术的时候的主要作用即为重计算梯度,而未启用检查点情况下则直接跳过。


看到这里的小伙伴们是否对于 run 函数有了更深刻的认识呢?整体 run 函数可以用以下流程图展示:




4 总结


本文介绍了 PyTorch 流水线并行的使用方法与实现细节。目前 MMEngine 也正在支持仅适用于模型推理的流水线并行方法,相关 PR 位于https://github.com/open-mmlab/mmengine/pull/1355,欢迎大家关注以及在 PR 留下自己的 comment!


三模态版ChatGPT震撼来袭,OpenAI卷起来了

2023-09-26

你的3D感知模型够鲁棒吗?Robo3D告诉你答案!

2023-09-25

如何在秋招脱颖而出,从理论到实战的经验分享 |《offer来了》第2期

2023-09-22




继续滑动看下一个
向上滑动看下一个

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存