首页
技术小册
AIGC
面试刷题
技术文章
MAGENTO
云计算
视频课程
源码下载
PDF书籍
「涨薪秘籍」
登录
注册
01 | PyTorch:网红中的顶流明星
02 | NumPy(上):核心数据结构详解
03 | NumPy(下):深度学习中的常用操作
04 | Tensor:PyTorch中最基础的计算单元
05 | Tensor变形记:快速掌握Tensor切分、变形等方法
06 | Torchvision(上):数据读取,训练开始的第一步
07 | Torchvision(中):数据增强,让数据更加多样性
08 | Torchvision(下):其他有趣的功能
09 | 卷积(上):如何用卷积为计算机“开天眼”?
10 | 卷积(下):如何用卷积为计算机“开天眼”?
11 | 损失函数:如何帮助模型学会“自省”?
12 | 计算梯度:网络的前向与反向传播
13 | 优化方法:更新模型参数的方法
14 | 构建网络:一站式实现模型搭建与训练
15 | 可视化工具:如何实现训练的可视化监控?
16|分布式训练:如何加速你的模型训练?
17 | 图像分类(上):图像分类原理与图像分类模型
18 | 图像分类(下):如何构建一个图像分类模型?
19 | 图像分割(上):详解图像分割原理与图像分割模型
20 | 图像分割(下):如何构建一个图像分割模型?
21 | NLP基础(上):详解自然语言处理原理与常用算法
22 | NLP基础(下):详解语言模型与注意力机制
23 | 情感分析:如何使用LSTM进行情感分析?
24 | 文本分类:如何使用BERT构建文本分类模型?
25 | 摘要:如何快速实现自动文摘生成?
当前位置:
首页>>
技术小册>>
PyTorch深度学习实战
小册名称:PyTorch深度学习实战
### 06 | Torchvision(上):数据读取,训练开始的第一步 在深度学习的世界里,数据是驱动模型学习与进步的核心燃料。而`Torchvision`,作为PyTorch生态系统中的一个重要库,专注于提供计算机视觉领域的工具和数据集,极大地简化了数据预处理、增强及模型评估的流程。本章将深入探讨`Torchvision`在数据读取方面的应用,这是任何深度学习项目启动的第一步,也是奠定训练成效基石的关键环节。 #### 一、Torchvision简介 `Torchvision`是围绕PyTorch构建的一个开源库,旨在促进计算机视觉研究与应用。它包含了大量预训练的模型、数据集加载器、图像转换工具等,使得研究人员和开发者能够轻松地进行图像分类、检测、分割等任务。`Torchvision`的数据集模块(`torchvision.datasets`)和转换模块(`torchvision.transforms`)是本章讨论的重点。 #### 二、数据读取的重要性 在深度学习训练中,数据的准备和预处理往往占据了大量的时间和资源。高质量的数据集不仅能提升模型的性能,还能加快训练速度,减少过拟合风险。因此,掌握如何高效地读取、预处理数据,是每个深度学习从业者必须掌握的技能之一。 #### 三、使用`torchvision.datasets`加载数据集 `torchvision.datasets`提供了许多常用的数据集加载器,如CIFAR10、CIFAR100、MNIST、ImageNet等,这些加载器自动处理数据的下载(如果本地不存在)、加载和格式化,极大地简化了数据准备过程。 ##### 示例:加载CIFAR10数据集 CIFAR10是一个包含60000张32x32彩色图像的数据集,分为10个类别,每个类别有6000张图像。下面是如何使用`torchvision`加载CIFAR10数据集的示例: ```python import torch import torchvision import torchvision.transforms as transforms # 定义数据转换操作 transform = transforms.Compose([ transforms.ToTensor(), # 将图片转换为Tensor transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化处理 ]) # 加载训练集和测试集 trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2) testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2) # 遍历训练数据 for images, labels in trainloader: print(images.shape) # 输出: torch.Size([4, 3, 32, 32]) print(labels.shape) # 输出: torch.Size([4]) break # 仅展示一个批次的数据 ``` #### 四、使用`torchvision.transforms`进行数据预处理 `torchvision.transforms`模块提供了一系列图像转换操作,如裁剪、旋转、缩放、归一化等,这些操作可以单独使用,也可以通过`Compose`类组合起来使用,以便在加载数据时进行一系列连续的预处理操作。 ##### 常见的数据预处理步骤 1. **转换为Tensor**:将PIL图像或NumPy ndarray转换为PyTorch的Tensor,这是进行深度学习计算的前提。 2. **归一化**:调整图像数据的分布,使其符合模型训练的期望输入范围,如将像素值从[0, 255]缩放到[-1, 1]或[0, 1]。 3. **数据增强**:通过随机裁剪、旋转、翻转等操作增加数据集的多样性,提高模型的泛化能力。 ##### 示例:组合多个转换操作 ```python transform = transforms.Compose([ transforms.RandomResizedCrop(224), # 随机裁剪并缩放至224x224 transforms.RandomHorizontalFlip(), # 随机水平翻转 transforms.ToTensor(), # 转换为Tensor transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 使用ImageNet的均值和标准差进行归一化 ]) ``` #### 五、数据加载与并行处理 在实际应用中,为了加快数据加载速度,通常会使用`torch.utils.data.DataLoader`来封装数据集。`DataLoader`支持多线程/多进程加载数据,自动打乱数据顺序(在训练集上),并批量返回数据,这对于加速训练过程至关重要。 ##### 示例:使用`DataLoader` ```python trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=4) for images, labels in trainloader: # 进行模型训练 pass ``` #### 六、总结与展望 通过本章的学习,我们了解了`Torchvision`在数据读取和预处理方面的重要性,掌握了如何使用`torchvision.datasets`加载常用数据集,以及如何利用`torchvision.transforms`进行数据预处理和增强。这些技能是启动深度学习训练项目的第一步,也是至关重要的一步。 未来,随着深度学习应用的不断扩展,我们将面临更多样化、更复杂的数据集。因此,继续深入探索数据预处理和增强的高级技术,如自动化数据增强策略、领域自适应方法等,将是我们不断追求的目标。同时,了解并掌握更多高级数据集加载技巧,如使用`torch.utils.data.Dataset`和`torch.utils.data.DataLoader`的自定义扩展,也将是提升我们深度学习项目效率和效果的关键。
上一篇:
05 | Tensor变形记:快速掌握Tensor切分、变形等方法
下一篇:
07 | Torchvision(中):数据增强,让数据更加多样性
该分类下的相关小册推荐:
人工智能原理、技术及应用(中)
ChatGPT使用指南
人工智能超入门丛书--情感分析
ChatGPT与提示工程(下)
快速部署大模型:LLM策略与实践(上)
TensorFlow项目进阶实战
AI智能写作: 巧用AI大模型让新媒体变现插上翅膀
深度学习与大模型基础(下)
AIGC原理与实践:零基础学大语言模型(二)
人工智能基础——基于Python的人工智能实践(上)
玩转ChatGPT:秒变AI提问和追问高手(下)
深度强化学习--算法原理与金融实践(一)