首页
技术小册
AIGC
面试刷题
技术文章
MAGENTO
云计算
视频课程
源码下载
PDF书籍
「涨薪秘籍」
登录
注册
01 | 课程介绍:AI进阶需要落地实战
02 | 内容综述:如何快速⾼效学习AI与TensorFlow 2
03 | TensorFlow 2新特性
04 | TensorFlow 2核心模块
05 | TensorFlow 2 vs TensorFlow 1.x
06 | TensorFlow 2落地应⽤
07 | TensorFlow 2开发环境搭建
08 | TensorFlow 2数据导入与使⽤
09 | 使用tf.keras.datasets加载数据
10 | 使用tf.keras管理Sequential模型
11 | 使用tf.keras管理functional API
12 | Fashion MNIST数据集介绍
13 | 使用TensorFlow2训练分类网络
14 | 行业背景:AI新零售是什么?
15 | 用户需求:线下门店业绩如何提升?
16 | 长期⽬标:货架数字化与业务智能化
17 | 短期目标:自动化陈列审核和促销管理
18 | 方案设计:基于深度学习的检测/分类的AI流水线
19 | 方案交付:支持在线识别和API调用的AI SaaS
20 | 基础:目标检测问题定义与说明
21 | 基础:深度学习在目标检测中的应用
22 | 理论:R-CNN系列二阶段模型综述
23 | 理论:YOLO系列一阶段模型概述
24 | 应用:RetinaNet 与 Facol Loss 带来了什么
25 | 应用:检测数据标注方法与流程
26 | 应用:划分检测训练集与测试集
27 | 应用:生成 CSV 格式数据集与标注
28 | 应用:使用TensorFlow 2训练RetinaNet
29 | 应用:使用RetinaNet检测货架商品
30 | 扩展:目标检测常用数据集综述
31 | 扩展:目标检测更多应用场景介绍
32 | 基础:图像分类问题定义与说明
33 | 基础:越来越深的图像分类网络
34 | 应⽤:检测SKU抠图与分类标注流程
35 | 应⽤:分类训练集与验证集划分
36 | 应⽤:使⽤TensorFlow 2训练ResNet
37 | 应用:使用ResNet识别货架商品
38 | 扩展:图像分类常用数据集综述
39 | 扩展:图像分类更多应⽤场景介绍
40 | 串联AI流程理论:商品检测与商品识别
41 | 串联AI流程实战:商品检测与商品识别
42 | 展现AI效果理论:使用OpenCV可视化识别结果
43 | 展现AI效果实战:使用OpenCV可视化识别结果
44 | 搭建AI SaaS理论:Web框架选型
45 | 搭建AI SaaS理论:数据库ORM选型
46 | 搭建AI SaaS理论:10分钟快速开发AI SaaS
47 | 搭建AI SaaS实战:10 分钟快速开发AI SaaS
48 | 交付AI SaaS:10分钟快速掌握容器部署
49 | 交付AI SaaS:部署和测试AI SaaS
50 | 使⽤TensorFlow 2实现图像数据增强
51 | 使⽤TensorFlow 2实现分布式训练
52 | 使⽤TensorFlow Hub迁移学习
53 | 使⽤@tf.function提升性能
54 | 使⽤TensorFlow Serving部署云端服务
55 | 使⽤TensorFlow Lite实现边缘智能
当前位置:
首页>>
技术小册>>
TensorFlow项目进阶实战
小册名称:TensorFlow项目进阶实战
### 章节 36 | 应用:使用TensorFlow 2训练ResNet 在深度学习领域,卷积神经网络(CNN)尤其是残差网络(ResNet)已成为处理图像识别、分类等任务的首选模型之一。ResNet通过引入残差连接(residual connections)解决了深层网络训练中的梯度消失或梯度爆炸问题,使得构建极深的网络结构成为可能。本章节将详细介绍如何在TensorFlow 2框架下,从零开始构建并训练一个ResNet模型,用于图像分类任务。 #### 36.1 引言 随着计算机视觉技术的飞速发展,对图像数据的处理和分析能力成为衡量AI系统性能的重要指标。ResNet作为深度学习中里程碑式的模型,其强大的特征提取能力和高效的训练效率,使得它在ImageNet等大型图像分类竞赛中屡获佳绩。本章节旨在通过实战演练,帮助读者掌握使用TensorFlow 2构建和训练ResNet模型的全过程。 #### 36.2 TensorFlow 2简介 TensorFlow 2是Google推出的第二代开源机器学习框架,它简化了模型构建、训练和部署的流程。TensorFlow 2引入了Eager Execution(动态图执行)作为默认模式,使得代码更加直观易懂,同时也保留了静态图(Graph Execution)的高性能优势。此外,TensorFlow 2还集成了Keras高级API,进一步降低了深度学习应用的门槛。 #### 36.3 ResNet基础 **36.3.1 残差连接** 残差连接是ResNet的核心思想,它通过直接将输入特征映射(identity mapping)与卷积层的输出相加,实现了跨层的信息传递。这种设计使得网络在训练过程中能够更容易地学习到恒等映射,从而避免了深层网络中的退化问题。 **36.3.2 基本块与瓶颈块** ResNet的基本构建单元包括基本块(Basic Block)和瓶颈块(Bottleneck Block)。基本块由两个3x3的卷积层组成,适用于较浅的ResNet(如ResNet18、ResNet34)。而瓶颈块则采用1x1、3x3、1x1三个卷积层的组合,先通过1x1卷积降维,再通过3x3卷积提取特征,最后通过1x1卷积升维,这种设计减少了计算量,适用于较深的ResNet(如ResNet50、ResNet101等)。 #### 36.4 TensorFlow 2中构建ResNet **36.4.1 导入必要的库** 首先,我们需要导入TensorFlow 2及其相关库,以及数据预处理和模型评估所需的库。 ```python import tensorflow as tf from tensorflow.keras import layers, models, optimizers, regularizers from tensorflow.keras.applications.imagenet_utils import preprocess_input, decode_predictions from tensorflow.keras.preprocessing import image from tensorflow.keras.datasets import cifar10 import numpy as np ``` **36.4.2 定义ResNet块** 接下来,我们定义ResNet的基本块和瓶颈块。 ```python def basic_block(x, filters, kernel_size=3, stride=1, conv_shortcut=True): # 实现基本块逻辑 pass def bottleneck_block(x, filters, kernel_size=3, stride=1, conv_shortcut=True): # 实现瓶颈块逻辑 pass ``` **36.4.3 构建ResNet模型** 基于定义好的ResNet块,我们可以构建完整的ResNet模型。 ```python def ResNet(input_shape, depth, num_classes=10): if (depth - 2) % 9 != 0: raise ValueError('depth should be 6n+2 (e.g., 50, 101, 152)') # 初始化模型 num_blocks = (depth - 2) // 9 num_filters = [64, 128, 256, 512] inputs = tf.keras.Input(shape=input_shape) x = layers.Conv2D(64, (7, 7), strides=(2, 2), padding='same', kernel_initializer='he_normal')(inputs) x = layers.BatchNormalization()(x) x = layers.Activation('relu')(x) x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x) # 堆叠残差块 for i in range(3): for j in range(num_blocks): strides = 1 if j == 0 and i != 0 else 2 x = bottleneck_block(x, num_filters[i], stride=strides) # 全局平均池化 x = layers.GlobalAveragePooling2D()(x) # 分类层 outputs = layers.Dense(num_classes, activation='softmax')(x) model = models.Model(inputs=inputs, outputs=outputs, name='ResNet' + str(depth)) return model # 实例化模型 model = ResNet((32, 32, 3), 50, num_classes=10) model.summary() ``` **36.4.4 编译模型** 使用适当的优化器、损失函数和评估指标编译模型。 ```python model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) ``` #### 36.5 数据准备与训练 **36.5.1 数据加载与预处理** 以CIFAR-10数据集为例,展示如何加载和预处理数据。 ```python (x_train, y_train), (x_test, y_test) = cifar10.load_data() x_train, x_test = x_train / 255.0, x_test / 255.0 # 数据增强 train_datagen = tf.keras.preprocessing.image.ImageDataGenerator( rotation_range=15, width_shift_range=0.1, height_shift_range=0.1, horizontal_flip=True, zoom_range=0.1, ) train_generator = train_datagen.flow(x_train, y_train, batch_size=32) ``` **36.5.2 训练模型** 使用训练生成器训练模型,并监控验证集上的性能。 ```python history = model.fit(train_generator, steps_per_epoch=len(x_train) // 32, epochs=10, validation_data=(x_test, y_test)) ``` #### 36.6 模型评估与预测 **36.6.1 评估模型** 在测试集上评估模型的性能。 ```python test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2) print('\nTest accuracy:', test_acc) ``` **36.6.2 预测新图像** 加载并预处理一张新图像,使用训练好的模型进行预测。 ```python img_path = 'path_to_image.jpg' img = image.load_img(img_path, target_size=(32, 32)) x = image.img_to_array(img) x = np.expand_dims(x, axis=0) x = preprocess_input(x) predictions = model.predict(x) print('Predicted:', decode_predictions(predictions, top=3)[0]) ``` #### 36.7 结论 通过本章节的学习,我们深入了解了ResNet的基本原理,并在TensorFlow 2框架下成功构建和训练了一个ResNet模型。从数据准备、模型构建、编译、训练到评估与预测,每一步都详细阐述了实现过程。希望读者能够通过实践,掌握使用TensorFlow 2进行深度学习项目开发的技能,为进一步探索更复杂的计算机视觉任务打下坚实的基础。
上一篇:
35 | 应⽤:分类训练集与验证集划分
下一篇:
37 | 应用:使用ResNet识别货架商品
该分类下的相关小册推荐:
人工智能超入门丛书--数据科学
AIGC:内容生产力的时代变革
人工智能基础——基于Python的人工智能实践(上)
ChatGPT与提示工程(上)
机器学习训练指南
ChatGPT大模型:技术场景与商业应用(下)
文心一言:你的百倍增效工作神器
AI 大模型系统实战
机器学习入门指南
AI时代程序员:ChatGPT与程序员(中)
ChatGPT大模型:技术场景与商业应用(中)
深度强化学习--算法原理与金融实践(五)