Aligning with PyTorch: A Deep Dive into OneFlow's DataLoader Implementation
Written by | Zhao Luyang
New dynamic-graph features: OneFlow now runs in dynamic-graph (eager) mode by default, which makes building networks, debugging, and validating algorithms easier than in static-graph (graph) mode.
The object-oriented dynamic-graph interface nn.Module lets users familiar with PyTorch get started easily.
"Convert between OneFlow and PyTorch networks with one line of code": the number of operators aligned with PyTorch has grown to 200+. This has been verified via import oneflow as torch and import torch as flow on a dozen or so common networks such as ResNet50 and AlexNet. Note: this feature is designed to help users migrate from PyTorch to OneFlow; it is not a promise of full PyTorch compatibility.
Object-oriented static-graph interface: the new object-oriented static-graph interface nn.Graph preserves OneFlow's static-graph performance advantage while bringing the programming barrier of static graphs close to that of dynamic graphs; we look forward to more algorithm engineers putting OneFlow's high performance to work. Here is an example of building ResNet50 with nn.Graph.
Easy and efficient distributed training: distributed training is the way forward. The Consistent Tensor added in this release lets users operate an entire cluster the way they would a single device on a single machine, and see the effect immediately. The new launch module and DDP module, together with OneFlow's consistent view, make it easy to launch distributed training; data parallelism, model parallelism, and pipeline parallelism are all natively supported by OneFlow, easy to use and efficient.
https://github.com/Oneflow-Inc/oneflow/pull/5406
https://github.com/Oneflow-Inc/oneflow/pull/5500
https://github.com/Oneflow-Inc/oneflow/pull/5644
https://github.com/Oneflow-Inc/oneflow/pull/6280
DataLoader overview
How the DataLoader works
DataLoader workflow
How the multiprocessing DataLoader works
2
How the DataLoader Works
Core components
DataLoader
Dataset
Sampler
Fetcher
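Putting the four pieces together, their division of labor can be sketched with a toy, stdlib-only version (the class names mirror the real components, but the bodies are simplified for illustration and are not OneFlow's actual implementation):

```python
import random

class ToyDataset:
    """Map-style dataset: defines __getitem__ and __len__."""
    def __init__(self, samples):
        self.samples = samples
    def __getitem__(self, index):
        return self.samples[index]
    def __len__(self):
        return len(self.samples)

class ToySampler:
    """Yields dataset indices, optionally shuffled."""
    def __init__(self, size, shuffle=False):
        self.indices = list(range(size))
        self.shuffle = shuffle
    def __iter__(self):
        if self.shuffle:
            random.shuffle(self.indices)
        return iter(self.indices)

class ToyFetcher:
    """Fetches the samples for a batch of indices and collates them."""
    def __init__(self, dataset, collate_fn):
        self.dataset = dataset
        self.collate_fn = collate_fn
    def fetch(self, batch_indices):
        return self.collate_fn([self.dataset[i] for i in batch_indices])

class ToyDataLoader:
    """Drives the sampler and fetcher, yielding one batch per step."""
    def __init__(self, dataset, batch_size=2):
        self.sampler = ToySampler(len(dataset))
        self.fetcher = ToyFetcher(dataset, collate_fn=list)
        self.batch_size = batch_size
    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield self.fetcher.fetch(batch)
                batch = []
        if batch:  # last, possibly smaller batch (drop_last=False behavior)
            yield self.fetcher.fetch(batch)

loader = ToyDataLoader(ToyDataset(["a", "b", "c", "d", "e"]))
print(list(loader))  # [['a', 'b'], ['c', 'd'], ['e']]
```

The real classes add shuffling generators, batch samplers, collation into Tensors, and worker processes on top of exactly this skeleton.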
Usage examples
1. MNIST
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
dataset1 = datasets.MNIST('../data', train=True, download=True,
transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], int(self.targets[index])
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
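In the transform pipeline above, transforms.ToTensor converts the PIL image to a tensor with values in [0, 1], after which transforms.Normalize applies (x - mean) / std per channel; 0.1307 and 0.3081 are MNIST's global mean and standard deviation. The arithmetic in plain Python:

```python
MEAN, STD = 0.1307, 0.3081  # MNIST's dataset-wide pixel statistics

def normalize(x, mean=MEAN, std=STD):
    """What transforms.Normalize does to each pixel value: (x - mean) / std."""
    return (x - mean) / std

print(round(normalize(0.0), 4))  # black pixel  -> -0.4242
print(round(normalize(1.0), 4))  # white pixel  -> 2.8215
```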
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
....
2. ImageNet
train_dataset = datasets.ImageFolder(
traindir,
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
else:
train_sampler = None
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])),
batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True)
class ImageFolder(DatasetFolder):
r"""A generic data loader where the images are arranged in this way by default:
.. code-block:: shell
root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png
This class inherits from :class:`~vision.datasets.DatasetFolder` so
the same methods can be overridden to customize the dataset.
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
is_valid_file (callable, optional): A function that takes path of an Image file
and check if the file is a valid file (used to check of corrupt files)
Attributes:
classes (list): List of the class names sorted alphabetically.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
loader: Callable[[str], Any] = default_loader,
is_valid_file: Optional[Callable[[str], bool]] = None,
):
super(ImageFolder, self).__init__(
root,
loader,
IMG_EXTENSIONS if is_valid_file is None else None,
transform=transform,
target_transform=target_transform,
is_valid_file=is_valid_file,
)
self.imgs = self.samples
root: path to the image folder
transform: the transforms applied to the PIL image produced by the loader, such as the Resize and CenterCrop above
loader: an image loader that reads an image given its path; by default images are read with PIL
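The directory layout shown in the docstring is exactly what drives the dataset's construction: subdirectory names become the classes, which are sorted, mapped to indices, and expanded into (path, class_index) samples. A stdlib-only sketch of that scan (make_samples is a hypothetical helper for illustration, not the real DatasetFolder code):

```python
import os

def make_samples(root):
    """Sketch of the ImageFolder scan: each subfolder of root is one class."""
    # classes sorted alphabetically, as in the Attributes docstring above
    classes = sorted(d.name for d in os.scandir(root) if d.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        # walk the class folder (including nested dirs like root/dog/[...]/)
        for dirpath, _, filenames in os.walk(os.path.join(root, name)):
            for f in sorted(filenames):
                samples.append((os.path.join(dirpath, f), class_to_idx[name]))
    return classes, class_to_idx, samples
```

Pointing this at a root containing cat/ and dog/ subfolders returns classes ["cat", "dog"], class_to_idx {"cat": 0, "dog": 1}, and one (path, class_index) tuple per image; the real implementation additionally filters by IMG_EXTENSIONS or is_valid_file.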
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
for i, (images, target) in enumerate(train_loader):
# measure data loading time
data_time.update(time.time() - end)
if args.gpu is not None:
images = images.cuda(args.gpu, non_blocking=True)
if torch.cuda.is_available():
target = target.cuda(args.gpu, non_blocking=True)
.....
3
DataLoader Workflow
A custom dataset must override the __getitem__ method, which defines how a sample is fetched for a given index; it may also optionally override the __len__ method, which reports the size of the dataset.
class Dataset(Generic[T_co]):
r"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~flow.utils.data.Sampler` implementations and the default options
of :class:`~flow.utils.data.DataLoader`.
.. note::
:class:`~flow.utils.data.DataLoader` by default constructs a index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
"""
def __getitem__(self, index) -> T_co:
raise NotImplementedError
def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":
return ConcatDataset([self, other])
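Note the __add__ method above: dataset_a + dataset_b yields a ConcatDataset that chains the two. Its indexing semantics can be sketched as follows (a simplified stand-in for illustration; the real ConcatDataset likewise precomputes cumulative sizes and locates the target dataset with bisect):

```python
import bisect

class MiniConcat:
    """Simplified ConcatDataset: index into the first dataset until it is
    exhausted, then continue into the next."""
    def __init__(self, datasets):
        self.datasets = datasets
        # cumulative sizes, e.g. sizes [3, 2] -> [3, 5]
        self.cumsizes = []
        total = 0
        for d in datasets:
            total += len(d)
            self.cumsizes.append(total)
    def __len__(self):
        return self.cumsizes[-1]
    def __getitem__(self, index):
        ds = bisect.bisect_right(self.cumsizes, index)  # which dataset
        if ds > 0:
            index -= self.cumsizes[ds - 1]              # local offset
        return self.datasets[ds][index]

cat = MiniConcat([["a", "b", "c"], ["x", "y"]])
print(len(cat), cat[0], cat[3])  # 5 a x
```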
class DataLoader(Generic[T_co]):
def __init__(
self,
dataset: Dataset[T_co],
batch_size: Optional[int] = 1,
shuffle: bool = False,
sampler: Optional[Sampler[int]] = None,
batch_sampler: Optional[Sampler[Sequence[int]]] = None,
num_workers: int = 0,
collate_fn: Optional[_collate_fn_t] = None,
drop_last: bool = False,
timeout: float = 0,
worker_init_fn: Optional[_worker_init_fn_t] = None,
multiprocessing_context=None,
generator=None,
*,
prefetch_factor: int = 2,
persistent_workers: bool = False
):
...
...
# We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
# since '_BaseDataLoaderIter' references 'DataLoader'.
def __iter__(self) -> "_BaseDataLoaderIter":
# When using a single worker the returned iterator should be
# created everytime to avoid reseting its state
# However, in the case of a multiple workers iterator
# the iterator is only created once in the lifetime of the
# DataLoader object so that workers can be reused
if self.persistent_workers and self.num_workers > 0:
if self._iterator is None:
self._iterator = self._get_iterator()
else:
self._iterator._reset(self)
return self._iterator
else:
return self._get_iterator()
def _get_iterator(self) -> "_BaseDataLoaderIter":
if self.num_workers == 0 or self.num_workers == 1:
return _SingleProcessDataLoaderIter(self)
else:
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIter(self)
When the DataLoader is iterated, data and labels are obtained through its __iter__ method. Inside __iter__, the _get_iterator method returns the appropriate DataLoaderIter instance: _SingleProcessDataLoaderIter in the single-process case and _MultiProcessingDataLoaderIter in the multi-process case. Both inherit from _BaseDataLoaderIter.
class _BaseDataLoaderIter(object):
def __init__(self, loader: DataLoader) -> None:
self._dataset = loader.dataset
self._dataset_kind = loader._dataset_kind
self._IterableDataset_len_called = loader._IterableDataset_len_called
self._auto_collation = loader._auto_collation
self._drop_last = loader.drop_last
self._index_sampler = loader._index_sampler
self._num_workers = loader.num_workers
self._prefetch_factor = loader.prefetch_factor
self._pin_memory = False
self._timeout = loader.timeout
self._collate_fn = loader.collate_fn
self._sampler_iter = iter(self._index_sampler)
self._base_seed = flow.tensor([0], dtype=flow.int64).uniform_().numpy().item()
# TODO: flow.empty()
# self._base_seed = flow.empty((), dtype=flow.int64).random_(generator=loader.generator).item()
self._persistent_workers = loader.persistent_workers
self._num_yielded = 0
self._profile_name = "enumerate(DataLoader)#{}.__next__".format(
self.__class__.__name__
)
def __iter__(self) -> "_BaseDataLoaderIter":
return self
def _reset(self, loader, first_iter=False):
self._sampler_iter = iter(self._index_sampler)
self._num_yielded = 0
self._IterableDataset_len_called = loader._IterableDataset_len_called
def _next_index(self):
return next(self._sampler_iter) # may raise StopIteration
def _next_data(self):
raise NotImplementedError
def __next__(self) -> Any:
if self._sampler_iter is None:
self._reset()
data = self._next_data()
self._num_yielded += 1
if (
self._dataset_kind == _DatasetKind.Iterable
and self._IterableDataset_len_called is not None
and self._num_yielded > self._IterableDataset_len_called
):
warn_msg = (
"Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
"samples have been fetched. "
).format(self._dataset, self._IterableDataset_len_called, self._num_yielded)
if self._num_workers > 1:
warn_msg += "Multiprocessing dataloader is not support yet!"
warnings.warn(warn_msg)
return data
def __len__(self) -> int:
return len(self._index_sampler)
def __getstate__(self):
raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert 0 <= self._num_workers <= 1
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind,
self._dataset,
self._auto_collation,
self._collate_fn,
self._drop_last,
)
def _next_data(self):
index = self._next_index() # may raise StopIteration
if self._pin_memory:
raise NotImplementedError("Dataloader pin memory is not support yet!")
return self._dataset_fetcher.fetch(index)
Each iteration invokes the iterator's __next__ method, which in turn calls the subclass's _next_data implementation to obtain the data. Taking _SingleProcessDataLoaderIter as an example: index = self._next_index() obtains this iteration's dataset indices from the Sampler, and self._dataset_fetcher.fetch(index) has the Fetcher retrieve the corresponding data by index.
class _MapDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(
dataset, auto_collation, collate_fn, drop_last
)
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
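In the fetch above, collate_fn receives a list of per-sample (sample, target) pairs. The essence of default collation is a transpose into one batch of samples and one batch of targets; a stdlib-only sketch (the real default collate_fn recurses over nested structures and stacks the results into Tensors rather than returning lists):

```python
def toy_collate(batch):
    """Turn [(img0, y0), (img1, y1), ...] into
    ([img0, img1, ...], [y0, y1, ...]).
    The real default collate would stack each list into a Tensor."""
    samples, targets = zip(*batch)
    return list(samples), list(targets)

batch = [("img0", 7), ("img1", 2), ("img2", 1)]
print(toy_collate(batch))  # (['img0', 'img1', 'img2'], [7, 2, 1])
```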
Taking the _MapDatasetFetcher subclass as an example, the Fetcher's main work is: data = [self.dataset[idx] for idx in possibly_batched_index] fetches one sample for each index in the batch, and return self.collate_fn(data) calls the collate_fn method to collect this batch of samples and pack them into Tensors that can be used directly for training/validation.
4
How the Multiprocessing DataLoader Works
How it works
When a _MultiProcessingDataLoaderIter is constructed, it first primes the prefetch loop, assigning prefetch_factor * num_workers indices to the worker processes in advance:
# prime the prefetch loop
for _ in range(self._prefetch_factor * self._num_workers):
self._try_put_index()
Workflow
def _next_data(self):
    # The DataLoaderIter obtains each iteration's data through this method,
    # which mainly delegates to _get_data
def _get_data(self):
    # _get_data in turn obtains the data mainly by calling _try_get_data()
def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
    # Fetch data from the main process's _data_queue
    ...
    try:
        data = self._data_queue.get(timeout=timeout)
        return (True, data)
    except Exception as e:
        ...
def _process_data(self, data):
    # Main work: 1. put the index for the next iteration into an active
    #    worker process via _try_put_index();
    # 2. advance the _rcvd_idx marker by 1.
    self._rcvd_idx += 1
    self._try_put_index()
    if isinstance(data, ExceptionWrapper):
        data.reraise()
    return data
def _try_put_index(self):
    # Main work: iterate over all workers, find the first active one
    # (identified by worker_queue_idx), and put the index together with
    # _send_idx into that worker's index_queue.
    # Each worker owns its own index_queue and starts working as soon as
    # it receives something from it.
    assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
    try:
        index = self._next_index()
    except StopIteration:
        return
    for _ in range(self._num_workers):  # find the next active worker, if any
        worker_queue_idx = next(self._worker_queue_idx_cycle)
        if self._workers_status[worker_queue_idx]:
            break
    else:
        # not found (i.e., didn't break)
        return
    self._index_queues[worker_queue_idx].put((self._send_idx, index))
    self._task_info[self._send_idx] = (worker_queue_idx,)
    self._tasks_outstanding += 1
    self._send_idx += 1
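The for ... else scan over workers above can be isolated into a small round-robin sketch (the worker statuses here are made up for illustration; in the real iterator they track live worker processes):

```python
from itertools import cycle

num_workers = 4
workers_status = [True, False, True, True]        # pretend worker 1 shut down
worker_queue_idx_cycle = cycle(range(num_workers))

def next_active_worker():
    """Mimic _try_put_index's scan: try each worker once, round-robin."""
    for _ in range(num_workers):
        idx = next(worker_queue_idx_cycle)
        if workers_status[idx]:
            return idx
    return None  # no active worker found (the loop's `else` branch)

picks = [next_active_worker() for _ in range(4)]
print(picks)  # [0, 2, 3, 0] -- the dead worker 1 is skipped
```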
The worker side of this exchange lives in the _worker_loop() method of oneflow/python/oneflow/utils/data/_utils/worker.py:
while watchdog.is_alive():
try:
r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
continue
if isinstance(r, _ResumeIteration):
# Acknowledge the main process
data_queue.put((r, None))
iteration_end = False
# Recreate the fetcher for worker-reuse policy
fetcher = _DatasetKind.create_fetcher(
dataset_kind, dataset, auto_collation, collate_fn, drop_last
)
continue
elif r is None:
# Received the final signal
assert done_event.is_set() or iteration_end
break
elif done_event.is_set() or iteration_end:
# `done_event` is set. But I haven't received the final signal
# (None) yet. I will keep continuing until get it, and skip the
# processing steps.
continue
idx, index = r
data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
if init_exception is not None:
data = init_exception
init_exception = None
else:
try:
data = fetcher.fetch(index)
except Exception as e:
if (
isinstance(e, StopIteration)
and dataset_kind == _DatasetKind.Iterable
):
data = _IterableDatasetStopIteration(worker_id)
# Set `iteration_end`
# (1) to save future `next(...)` calls, and
# (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
iteration_end = True
else:
# It is important that we don't store exc_info in a variable.
# `ExceptionWrapper` does the correct thing.
# See NOTE [ Python Traceback Reference Cycle Problem ]
data = ExceptionWrapper(
where="in DataLoader worker process {}".format(worker_id)
)
data_queue.put((idx, data))
del data, idx, index, r # save memory
except KeyboardInterrupt:
# Main process will raise KeyboardInterrupt anyways.
pass
Once r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) receives an index from its index_queue, the worker starts working: idx, index = r unpacks the task, and data = fetcher.fetch(index) fetches the data exactly as in the single-process DataLoader workflow described earlier. Finally, data_queue.put((idx, data)) puts the result into the data queue, where it waits for the DataLoader main thread to pick it up.
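The full index_queue/data_queue handshake can be mimicked with threads and stdlib queues (a deliberately simplified analogy: real workers are separate processes communicating over multiprocessing queues, with watchdogs, timeouts, and exception wrapping omitted):

```python
import queue
import threading

dataset = [x * x for x in range(8)]          # a trivial "dataset"
index_queue = queue.Queue()                  # main -> worker: (send_idx, indices)
data_queue = queue.Queue()                   # worker -> main: (send_idx, batch)

def worker_loop():
    """Simplified _worker_loop: take indices, 'fetch' the data, send it back."""
    while True:
        r = index_queue.get()
        if r is None:                        # final signal, as in the real loop
            break
        idx, indices = r
        data = [dataset[i] for i in indices] # stands in for fetcher.fetch(index)
        data_queue.put((idx, data))

w = threading.Thread(target=worker_loop)
w.start()
for send_idx, batch in enumerate([[0, 1], [2, 3], [4, 5]]):
    index_queue.put((send_idx, batch))       # like _try_put_index
index_queue.put(None)                        # tell the worker to exit

results = sorted(data_queue.get() for _ in range(3))
w.join()
print(results)  # [(0, [0, 1]), (1, [4, 9]), (2, [16, 25])]
```

The send_idx carried alongside each task is what lets the main side reorder results (the real _rcvd_idx/_task_info bookkeeping) when several workers finish out of order.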