首页
技术小册
AIGC
面试刷题
技术文章
MAGENTO
云计算
视频课程
源码下载
PDF书籍
「涨薪秘籍」
登录
注册
01 | PyTorch:网红中的顶流明星
02 | NumPy(上):核心数据结构详解
03 | NumPy(下):深度学习中的常用操作
04 | Tensor:PyTorch中最基础的计算单元
05 | Tensor变形记:快速掌握Tensor切分、变形等方法
06 | Torchvision(上):数据读取,训练开始的第一步
07 | Torchvision(中):数据增强,让数据更加多样性
08 | Torchvision(下):其他有趣的功能
09 | 卷积(上):如何用卷积为计算机“开天眼”?
10 | 卷积(下):如何用卷积为计算机“开天眼”?
11 | 损失函数:如何帮助模型学会“自省”?
12 | 计算梯度:网络的前向与反向传播
13 | 优化方法:更新模型参数的方法
14 | 构建网络:一站式实现模型搭建与训练
15 | 可视化工具:如何实现训练的可视化监控?
16|分布式训练:如何加速你的模型训练?
17 | 图像分类(上):图像分类原理与图像分类模型
18 | 图像分类(下):如何构建一个图像分类模型?
19 | 图像分割(上):详解图像分割原理与图像分割模型
20 | 图像分割(下):如何构建一个图像分割模型?
21 | NLP基础(上):详解自然语言处理原理与常用算法
22 | NLP基础(下):详解语言模型与注意力机制
23 | 情感分析:如何使用LSTM进行情感分析?
24 | 文本分类:如何使用BERT构建文本分类模型?
25 | 摘要:如何快速实现自动文摘生成?
当前位置:
首页>>
技术小册>>
PyTorch深度学习实战
小册名称:PyTorch深度学习实战
### 05 | Tensor变形记:快速掌握Tensor切分、变形等方法 在PyTorch的深度学习之旅中,Tensor(张量)作为数据处理与模型计算的核心单元,其操作的灵活性和高效性直接关系到模型设计的复杂度和训练效率。本章“Tensor变形记:快速掌握Tensor切分、变形等方法”将深入探讨PyTorch中Tensor的切分(Splitting)、变形(Reshaping)、转置(Transposing)以及合并(Concatenating)等关键操作,帮助读者快速掌握这些技巧,从而在构建和训练深度学习模型时游刃有余。 #### 一、Tensor基础回顾 在深入讲解之前,我们先简要回顾一下Tensor的基本概念。Tensor是PyTorch中的一个核心概念,它类似于NumPy的ndarray,但专为GPU加速计算而设计。Tensor可以看作是一个高维数组,其维度由`torch.Size`对象表示,可以通过`.shape`属性访问。Tensor的元素可以是任何数据类型,但最常见的是浮点数(用于神经网络计算)。 #### 二、Tensor切分(Splitting) Tensor的切分是将一个大的Tensor沿着指定的维度分割成多个较小的Tensor的过程。PyTorch提供了`torch.split()`和`torch.chunk()`两个主要函数来实现这一功能。 ##### 2.1 torch.split() `torch.split()`函数允许你根据指定的分割大小或分割段数来切分Tensor。其语法如下: ```python torch.split(tensor, split_size_or_sections, dim=0) ``` - `tensor`:待切分的Tensor。 - `split_size_or_sections`:可以是整数或整数列表。如果是整数,表示每个分割段的大小;如果是整数列表,则直接指定每个段的具体大小。 - `dim`:沿着哪个维度进行切分,默认为0。 **示例**: ```python import torch x = torch.arange(9).reshape(3, 3) # 使用整数分割大小 split1 = torch.split(x, 2, dim=1) # 沿第1维(列)分割,每块2列 # 使用整数列表指定具体大小 split2 = torch.split(x, [1, 2, 1], dim=0) # 沿第0维(行)分割,分别为1, 2, 1行 print(list(split1)) print(list(split2)) ``` ##### 2.2 torch.chunk() `torch.chunk()`函数根据指定的分割段数来切分Tensor,每个段的大小尽可能相等(除了最后一个段可能较小)。其语法如下: ```python torch.chunk(tensor, chunks, dim=0) ``` - `tensor`:待切分的Tensor。 - `chunks`:分割成的段数。 - `dim`:沿着哪个维度进行切分,默认为0。 **示例**: ```python chunks = torch.chunk(x, 3, dim=1) # 沿第1维(列)分割成3块 print(list(chunks)) ``` #### 三、Tensor变形(Reshaping) Tensor的变形操作不改变其元素的总数,只改变其形状。PyTorch提供了多种方法来实现Tensor的变形,其中`torch.reshape()`和`torch.view()`最为常用。 ##### 3.1 torch.reshape() `torch.reshape()`函数返回一个新的Tensor,其数据与输入Tensor相同,但形状已改变。如果新形状与原始形状的乘积不相等,则会抛出错误。 ```python y = x.reshape(1, 9) # 将3x3的Tensor变形为1x9 print(y) ``` ##### 3.2 torch.view() `torch.view()`函数与`reshape()`类似,但它会尝试在不复制数据的情况下返回一个新的视图。如果新形状与原始数据的存储方式不兼容(例如,跨越了非连续的内存块),则会抛出错误。 ```python z = x.view(9, 1) # 将3x3的Tensor变形为9x1 print(z) ``` #### 四、Tensor转置(Transposing) Tensor的转置操作是交换其两个维度的过程。在PyTorch中,可以使用`.transpose(dim0, dim1)`或`.T`(仅对二维Tensor有效)来实现。 ##### 4.1 .transpose(dim0, dim1) ```python transposed = x.transpose(0, 1) # 交换第0维和第1维 print(transposed) ``` ##### 4.2 .T(仅限二维) ```python if x.ndim == 2: print(x.T) # 对于二维Tensor,.T是.transpose(0, 1)的简写 ``` #### 五、Tensor合并(Concatenating) Tensor的合并操作是将多个Tensor沿着指定的维度拼接起来,形成一个更大的Tensor。PyTorch提供了`torch.cat()`函数来实现这一功能。 ```python torch.cat(tensors, dim=0) ``` - `tensors`:待合并的Tensor序列。 - `dim`:沿着哪个维度进行合并,默认为0。 **示例**: ```python a = torch.arange(3).reshape(1, 3) b = torch.arange(3, 6).reshape(1, 3) concatenated = torch.cat((a, b), dim=0) # 沿第0维(行)合并 print(concatenated) ``` #### 六、高级变形技巧 除了上述基本的Tensor操作外,PyTorch还提供了更多高级变形技巧,如`torch.squeeze()`(去除单维度条目)、`torch.unsqueeze()`(在指定位置增加单维度)、`torch.flatten()`(展平Tensor,默认展平所有维度但保留一个维度)等,这些技巧在处理复杂模型结构和数据预处理时非常有用。 #### 七、总结 本章通过详细介绍Tensor的切分、变形、转置和合并等操作,展示了PyTorch在处理多维数据时的强大能力。掌握这些基础而关键的Tensor操作技巧,对于设计复杂的深度学习模型、优化模型训练过程以及进行高效的数据预处理都具有重要意义。希望读者能够通过实践,进一步加深对这些操作的理解和应用,从而在PyTorch的深度学习之路上越走越远。
上一篇:
04 | Tensor:PyTorch中最基础的计算单元
下一篇:
06 | Torchvision(上):数据读取,训练开始的第一步
该分类下的相关小册推荐:
ChatGPT使用指南
AI 时代的软件工程
大规模语言模型:从理论到实践(下)
AI时代程序员:ChatGPT与程序员(中)
AI写作宝典:如何成为AI写作高手
深度学习与大模型基础(下)
AI Agent 智能体实战课
人工智能原理、技术及应用(上)
AI时代产品经理:ChatGPT与产品经理(中)
ChatGPT原理与实战:大型语言模型(中)
可解释AI实战PyTorch版(上)
NLP入门到实战精讲(上)