其他
【他山之石】Pytorch技巧:DataLoader的collate_fn参数使用详解
“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。
地址:https://www.zhihu.com/people/AI_team-WSF
class torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=<function default_collate>,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None)
shuffle:设置为True的时候,每个世代都会打乱数据集。 collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能。 drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留。
import torch
import torch.utils.data as Data
import numpy as np
test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])
inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))
torch_dataset = Data.TensorDataset(inputing,target)
batch = 3
loader = Data.DataLoader(
dataset=torch_dataset,
batch_size=batch,
)
for (i, j) in loader:
print(i)
print(j)
tensor([[0, 1, 2],
[1, 2, 3],
[2, 3, 4]], dtype=torch.int32)
tensor([[0],
[1],
[2]], dtype=torch.int32)
tensor([[3, 4, 5],
[4, 5, 6],
[5, 6, 7]], dtype=torch.int32)
tensor([[3],
[4],
[5]], dtype=torch.int32)
tensor([[ 6, 7, 8],
[ 7, 8, 9],
[ 8, 9, 10]], dtype=torch.int32)
tensor([[6],
[7],
[8]], dtype=torch.int32)
tensor([[ 9, 10, 11]], dtype=torch.int32)
tensor([[9]], dtype=torch.int32)
collate_fn=lambda x:(
torch.cat(
[x[i][j].unsqueeze(0) for i in range(len(x))], 0
) for j in range(len(x[0]))
)
collate_fn=lambda x:x
for i in loader:
print(i)
[(tensor([0, 1, 2], dtype=torch.int32), tensor([0], dtype=torch.int32)), (tensor([1, 2, 3], dtype=torch.int32), tensor([1], dtype=torch.int32)), (tensor([2, 3, 4], dtype=torch.int32), tensor([2], dtype=torch.int32))]
[(tensor([3, 4, 5], dtype=torch.int32), tensor([3], dtype=torch.int32)), (tensor([4, 5, 6], dtype=torch.int32), tensor([4], dtype=torch.int32)), (tensor([5, 6, 7], dtype=torch.int32), tensor([5], dtype=torch.int32))]
[(tensor([6, 7, 8], dtype=torch.int32), tensor([6], dtype=torch.int32)), (tensor([7, 8, 9], dtype=torch.int32), tensor([7], dtype=torch.int32)), (tensor([ 8, 9, 10], dtype=torch.int32), tensor([8], dtype=torch.int32))]
[(tensor([ 9, 10, 11], dtype=torch.int32), tensor([9], dtype=torch.int32))]
每个i都是一个列表,每个列表包含batch_size个元组,每个元组包含TensorDataset的单独数据。所以要将重新组合成每个batch包含3*3的input和3*1的target,就要重新解包并打包。看看我们的collate_fn:
collate_fn=lambda x:(
torch.cat(
[x[i][j].unsqueeze(0) for i in range(len(x))], 0
) for j in range(len(x[0]))
)
collate_fn=lambda x:(
torch.cat(
[x[i][j].unsqueeze(0) for i in range(len(x))], 0
).unsqueeze(0) for j in range(len(x[0]))
)
tensor([[[0, 1, 2],
[1, 2, 3],
[2, 3, 4]]], dtype=torch.int32)
tensor([[[0],
[1],
[2]]], dtype=torch.int32)
tensor([[[3, 4, 5],
[4, 5, 6],
[5, 6, 7]]], dtype=torch.int32)
tensor([[[3],
[4],
[5]]], dtype=torch.int32)
tensor([[[ 6, 7, 8],
[ 7, 8, 9],
[ 8, 9, 10]]], dtype=torch.int32)
tensor([[[6],
[7],
[8]]], dtype=torch.int32)
tensor([[[ 9, 10, 11]]], dtype=torch.int32)
tensor([[[9]]], dtype=torch.int32)
def detection_collate(batch):
"""Custom collate fn for dealing with batches of images that have a different
number of associated object annotations (bounding boxes).
Arguments:
batch: (tuple) A tuple of tensor images and lists of annotations
Return:
A tuple containing:
1) (tensor) batch of images stacked on their 0 dim
2) (list of tensors) annotations for a given image are stacked on
0 dim
"""
targets = []
imgs = []
for sample in batch:
imgs.append(sample[0])
targets.append(torch.FloatTensor(sample[1]))
return torch.stack(imgs, 0), targets
# 代码只写出了collate_fn部分,其余的省略了。
dataloader = torch.utils.data.DataLoader(
collate_fn=detection_collate,
)
参考
本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。
“他山之石”历史文章
Pytorch优化器及其内置优化算法原理介绍
神经网络学习 | 鸢尾花分类的实现
Pytorch 基础-tensor 数据结构
Transformer风险评分:实体嵌入+注意力机制
Pytorch:eval()的用法比较
ONNX模型文件->可执行文件 C Runtime通路 具体实现方法
Pytorch mixed precision 概述(混合精度)
Weights & Biases (兼容多种深度学习框架的可视化工具WB中文简介)
GCN实现及其中的归一化
Pytorch Lightning 完全攻略
Tensorflow之TFRecord的原理和使用心得
从零开始实现一个卷积神经网络
斯坦福大规模网络数据集
超轻量的YOLO-Nano
MMAction2: 新一代视频理解工具箱
更多他山之石专栏文章,
请点击文章底部“阅读原文”查看
分享、点赞、在看,给个三连击呗!