查看原文
其他

【他山之石】Pytorch 基础-tensor 数据结构

“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。

作者:知乎—唐奋

地址:https://www.zhihu.com/people/tang-fen-44-49


01

torch.Tensor

torch.Tensor 是一种包含单一数据类型元素的多维矩阵,类似于 numpy 的 array。Tensor 可以使用 torch.tensor() 转换 Python 的 list 或序列数据生成,生成的是dtype 默认是 torch.FloatTensor。

注意 torch.tensor() 总是拷贝 data。如果你有一个 Tensor data 并且仅仅想改变它的 requires_grad 属性,可用 requires_grad_() 或者 detach() 来避免拷贝。如果你有一个 numpy 数组并且想避免拷贝,请使用 torch.as_tensor()。
1,指定数据类型的 Tensor 可以通过传递参数 torch.dtype 和/或者 torch.device 到构造函数生成:
注意为了改变已有的 tensor 的 torch.device 和/或者 torch.dtype, 考虑使用 to() 方法.
>>> torch.ones([2,3], dtype=torch.float64, device="cuda:0")tensor([[1., 1., 1.], [1., 1., 1.]], device='cuda:0', dtype=torch.float64)>>> torch.ones([2,3], dtype=torch.float32)tensor([[1., 1., 1.], [1., 1., 1.]])
2,Tensor 的内容可以通过 Python索引或者切片访问以及修改:
>>> matrix = torch.tensor([[2,3,4],[5,6,7]])>>> print(matrix[1][2])tensor(7)>>> matrix[1][2] = 9>>> print(matrix)tensor([[2, 3, 4], [5, 6, 9]])
3,使用 torch.Tensor.item() 或者 int() 方法从只有一个值的 Tensor中获取 Python Number:
>>> x = torch.tensor([[4.5]])>>> xtensor([[4.5000]])>>> x.item()4.5>>> int(x)4
4,Tensor可以通过参数 requires_grad=True 创建, 这样 torch.autograd 会记录相关的运算实现自动求导:
>>> x = torch.tensor([[1., -1.], [1., 1.]], requires_grad=True)>>> out = x.pow(2).sum()>>> out.backward()>>> x.gradtensor([[ 2.0000, -2.0000], [ 2.0000, 2.0000]])
5,每一个 tensor都有一个相应的 torch.Storage 保存其数据。tensor 类提供了一个多维的、strided 视图, 并定义了数值操作。

02

Tensor 数据类型
Torch 定义了七种 CPU tensor 类型和八种 GPU tensor 类型:
torch.Tensor 是默认的 tensor 类型(torch.FloatTensor)的简称,即 32 位浮点数数据类型。


03

Tensor 的属性
Tensor 有很多属性,包括数据类型、Tensor 的维度、Tensor 的尺寸。
  • 数据类型:可通过改变 torch.tensor() 方法的 dtype 参数值,来设定不同的 tensor 数据类型。
  • 维度:不同类型的数据可以用不同维度(dimension)的张量来表示。标量为 0 维张量,向量为 1 维张量,矩阵为 2 维张量。彩色图像有 rgb 三个通道,可以表示为 3 维张量。视频还有时间维,可以表示为 4 维张量,有几个中括号 [ 维度就是几。可使用 dim() 方法 获取 tensor 的维度。
  • 尺寸:可以使用 shape属性或者 size()方法查看张量在每一维的长度,可以使用 view()方法或者reshape() 方法改变张量的尺寸。
样例代码如下:
matrix = torch.tensor([[[1,2,3,4],[5,6,7,8]], [[5,4,6,7], [5,6,8,9]]], dtype = torch.float64)print(matrix) # 打印 tensorprint(matrix.dtype) # 打印 tensor 数据类型print(matrix.dim()) # 打印 tensor 维度print(matrix.size()) # 打印 tensor 尺寸print(matrix.shape) # 打印 tensor 尺寸matrix2 = matrix.view(4, 2, 2) # 改变 tensor 尺寸print(matrix2)
程序输出结果如下:


04

view 和 reshape 的区别
  • 两个方法都是用来改变 tensor 的 shape,view() 只适合对满足连续性条件(contiguous)的 tensor 进行操作,而 reshape() 同时还可以对不满足连续性条件的 tensor 进行操作。
  • 在满足 tensor 连续性条件(contiguous)时,a.reshape() 返回的结果与a.view() 相同,都不会开辟新内存空间;不满足 contiguous 时, 直接使用 view() 方法会失败,reshape() 依然有用,但是会重新开辟内存空间,不与之前的 tensor 共享内存,即返回的是 ”副本“(等价于先调用 contiguous() 方法再使用 view() 方法)。更多理解参考这篇文章


05

Tensor 与 ndarray
1,张量和 numpy 数组。可以用 .numpy() 方法从 Tensor 得到 numpy 数组,也可以用 torch.from_numpy 从 numpy 数组得到Tensor。这两种方法关联的 Tensor 和 numpy 数组是共享数据内存的。可以用张量的 clone方法拷贝张量,中断这种关联。
arr = np.random.rand(4,5)print(type(arr))tensor1 = torch.from_numpy(arr)print(type(tensor1))arr1 = tensor1.numpy()print(type(arr1))"""<class 'numpy.ndarray'><class 'torch.Tensor'><class 'numpy.ndarray'>"""
2,item() 方法和 tolist() 方法可以将张量转换成 Python 数值和数值列表
# item方法和tolist方法可以将张量转换成Python数值和数值列表scalar = torch.tensor(5) # 标量s = scalar.item()print(s)print(type(s))
tensor = torch.rand(3,2) # 矩阵t = tensor.tolist()print(t)print(type(t))"""1.0<class 'float'>[[0.8211846351623535, 0.20020723342895508], [0.011571824550628662, 0.2906131148338318]]<class 'list'>"""


06

创建 Tensor
创建 tensor ,可以传入数据或者维度,torch.tensor() 方法只能传入数据,torch.Tensor() 方法既可以传入数据也可以传维度,强烈建议 tensor() 传数据,Tensor() 传维度,否则易搞混。

07

传入维度的方法
样例代码:
>>> torch.normal(2, 3, size=(1, 4))tensor([[3.6851, 3.2853, 1.8538, 3.5181]])>>> torch.full([2, 2], 4)tensor([[4, 4], [4, 4]])>>> torch.arange(0,10,2)tensor([0, 2, 4, 6, 8])>>> torch.eye(3,3)tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])

参考资料

PyTorch:view() 与 reshape() 区别详解

https://blog.csdn.net/Flag_ing/article/details/109129752

torch.rand和torch.randn和torch.normal和linespace()

https://zhuanlan.zhihu.com/p/115997577

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“他山之石”历史文章


更多他山之石专栏文章,

请点击文章底部“阅读原文”查看



分享、点赞、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存