使用TensorFlow训练深度学习模型实战(下)

news/2024/5/17 14:47:42

大家好,本文接TensorFlow训练深度学习模型的上半部分继续进行讲述,下面将介绍有关定义深度学习模型、训练模型和评估模型的内容。

定义深度学习模型

数据准备完成后,下一步是使用TensorFlow搭建神经网络模型,搭建模型有两个选项:

可以使用各种层,包括Dense、Conv2D和LSTM,从头开始搭建模型。这些层定义了模型的架构及数据流经过它的方式,可基于TensorFlow Hub提供的预训练模型搭建模型。这些模型已经在大型数据集上进行了训练,并可以在特定数据集上进行微调,以达到在较短的训练时间内达到较高的准确度。

可以根据TensorFlow Hub中的预训练模型来建立模型。这些模型已经在大型数据集上进行了训练,并且可以在你的特定数据集上进行微调,以达到较少的训练时间,达到较高的准确性。

  • 从头开始定义深度学习模型

TensorFlow中的tf.keras.Sequential函数允许我们逐层定义神经网络模型,我们可以选择各种层,如Dense、Conv2D和LSTM,来搭建定制的模型架构。以下是示例: 

# 定义模型架构
model = tf.keras.Sequential([tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),tf.keras.layers.MaxPooling2D(),tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dropout(0.5),tf.keras.layers.Dense(10)
])

在这个示例中,我们定义了一个模型,包含以下六个层(4个隐藏层):

  1. Conv2D层,具有32个过滤器,3x3的内核大小和ReLU激活。此层以形状为(28,28,1)的输入图像作为输入。

  2. MaxPooling2D层,具有默认的2x2池大小。此层对从上一层获得的特征映射进行下采样。

  3. Flatten层,将2D特征映射展平为1D向量。

  4. Dense层,具有128个神经元和ReLU激活。此层对展平的特征映射执行完全连接操作。

  5. Dropout层,在训练期间随机丢弃50%的连接以防止过拟合。

  6. Dense层,具有十个神经元,无激活函数。此层表示模型的输出层,神经元的数量对应于分类任务中的类别数目。

这个模型遵循典型的卷积神经网络架构,包括多个卷积层和池化层,以及一个或多个全连接层。

  • 从预训练模型定义深度学习模型 

利用TensorFlow Hub提供的预训练模型可能是一个不错的选择,因为它们已经在大量的数据集上进行了训练,可以帮助在减少训练时间的同时实现高准确度。在实现任何这些模型之前,让我们先了解一些TensorFlow Hub提供的常见预训练模型。

  1. VGG:The Visual Geometry Group(VGG)模型是由牛津大学开发的。这些模型广泛用于图像分类任务,并在各种基准数据集上取得了最先进的结果。

  2. ResNet:The Residual Network(ResNet)模型是由微软研究院开发的。这些模型具有独特的架构,可以训练非常深的神经网络(高达1000层)。

  3. Inception:Inception模型是由Google开发的。这些模型具有独特的架构,使用不同尺度的多个并行卷积,Inception模型广泛用于目标检测和图像分类任务。

  4. MobileNet:MobileNet模型是由Google开发的。这些模型具有针对移动设备和嵌入式设备进行优化的独特架构,MobileNet模型广泛用于移动设备上的图像分类和目标检测任务。

可以通过向预训练模型添加额外层并在特定数据集上训练模型来应用迁移学习。与从头开始训练模型相比,这种技术可以节省大量时间和计算资源。但是,在选择预训练模型并将数据集转换为该格式以确保兼容之前,了解预训练模型所需的输入格式非常重要。

在这个示例中,MobileNet模型被作为基本模型使用。在使用基本模型之前,检查模型所需的格式非常重要, 在本示例中,格式为(224,224,3)。然而,MNIST数据集是一个灰度图像,大小为(28,28,1),其中单个值表示像素的亮度。图像大小也比所需的格式要小得多。因此,需要重新调整数据集。以下是调整大小的主要思路:

使用image.resize函数将图像调整为所需的大小。该函数使用双线性插值来保留原始图像中的信息,同时将其调整为新大小。因此,此步骤可以将原始形状(28,28,1)调整为(224,224,1)的形状。

使用image.grayscale_to_rgb函数将图像转换为新的RGB图像,通过将单个灰度通道复制到新的RGB图像的所有三个通道中,从而将原始形状(224,224,1)调整为(224,224,3)的形状。

# 调整输入图像的大小为224x224,并将其转换为三通道的RGB图像
X_train = tf.image.grayscale_to_rgb(tf.image.resize(X_train, [224, 224]))
X_test = tf.image.grayscale_to_rgb(tf.image.resize(X_test, [224, 224]))

 现在让我们基于MobileNet模型定义我们的模型:

# 加载MobileNet模型,不包括顶层
base_model = MobileNet(include_top=False, input_shape=(224, 224, 3))# 添加一个全局平均池化层和一个全连接输出层
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.5)(x)
x = Dense(10, activation='softmax')(x)# 将基础模型和新层结合起来,创建完整的模型
model = tf.keras.models.Model(inputs=base_model.input, outputs=x)# 冻结基础模型中的各层
for layer in base_model.layers:layer.trainable = False

在上面的示例中,我们定义了一个模型,如下所示:

  1. 使用MobileNet()定义基本模型

  2. GlobalAveragePooling2D层,使用基本模型的最后一个卷积层的输出,计算每个特征映射的平均值,从而得到一个固定长度的向量,总结了特征映射中的空间信息。

  3. Dropout层,在训练期间随机丢弃50%的连接以防止过拟合。

  4. Dense层,使用十个单元的完全连接层和softmax激活。它接收来自上一层的输出并生成覆盖十个可能类别的概率分布。

编译和训练模型 

在创建模型之后,必须通过指定在训练期间使用的损失函数、优化器和指标来编译它。以下是一个编译模型的示例代码:

# 编译该模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])

由于这是一个多分类问题,因此此示例代码使用了稀疏交叉熵损失函数,我们使用的是Adam优化器和准确率指标。

在训练模型之后可以在测试集上评估它,以查看它在未见过的数据上的表现如何,以下是一个评估模型的示例代码:

# 在测试数据上评估该模型
test_loss, test_acc = model.evaluate(X_test, y_test)
print('Test loss: ', test_loss)
print('Test Acc: ', test_acc)

 在此示例代码中,我们在测试集上评估模型,并输出测试损失和准确率。

进行预测

一旦训练和评估了模型,就可以使用它来预测新数据。以下是一个进行预测的示例代码:

# 对新数据进行预测
y_pred = model.predict(X_test)
y_pred_labels = np.argmax(y_pred, axis=1)
print(y_pred_labels)

在此示例代码中,我们在模型上使用predict()方法对整个测试集进行预测。

如果我们想要预测单个图像并返回预测标签与真实标签,那么就需要对Keras模型的predict()方法进行更改。因为Keras模型的predict()方法期望输入数据形式为一批图像,而我们想要传递单个图像给predict()方法,所以需要将其重新调整为批次大小为1。

def predict_and_compare(model, X_test, y_test, index):# 从X_test中获取给定索引的例子example = X_test[index]# 将例子重塑为预期的输入形状example = np.reshape(example, (1, 28, 28, 1))# 预测这个例子的标签y_pred = model.predict(example)# 将预测的概率转换为类别标签y_pred_label = np.argmax(y_pred, axis=1)[0]# 使用索引从y_test获取真实标签y_test_array = y_test.values# Get the label for the first example in the test set y_true = y_test_array[index]# 输出预测的和真实的标签print("Predicted label:", y_pred_label)print("True label:", y_true)# 返回预测的和真实的标签return y_pred_label, y_true# 预测并比较测试集中第一个例子的标签
y_pred_label, y_true = predict_and_compare(model, X_test, y_test, 0)

在上面的示例中,我们通过添加一个额外的维度来代表批次大小,从而将输入图像从(28,28,1)调整为(1,28,28,1)。这样,我们就可以传递单个图像给predict()方法,并获得该图像的预测结果。当我们调用上面的函数时,可以自定义要预测的图像:

 这就是在TensorFlow中实现深度学习的步骤。当然,这只是一个基本示例。你可以搭建具有更多层、不同类型的层和不同超参数的更复杂的模型,以便在数据集上获得更好的性能。

综上,本文我们演示了如何对数据进行预处理、搭建和训练模型、在单独的测试集上评估其性能以及使用简单的卷积神经网络(CNN)进行图像分类的预测,通过学习可以获得如何在TensorFlow中构建深度学习模型以及如何将这些概念应用于真实世界数据集的理解。


http://www.mrgr.cn/p/54032346

相关文章

Android 中 app freezer 原理详解(一):S 版本

基于版本:Android S 0. 前言 在之前的两篇博文《Android 中app内存回收优化(一)》和 《Android 中app内存回收优化(二)》中详细剖析了 Android 中 app 内存优化的流程。这个机制的管理通过 CachedAppOptimizer 类管理,为什么叫这个名字,而不…

k8s一站式使用笔记

前言 个人感觉比较磨心态,要坐住,因为细节太多,建议:一遍看个大概,二遍回来细品,不要当成任务,把握零碎时间 一、k8s安装 1、配置准备 硬件要求 内存:2GB或更多RAMCPU: 2核CPU或更…

【RabbitMQ】Linux系统服务器安装RabbitMQ

一、下载 首先应该下载erlang,rabbitmq运行需要有erland环境。 官网地址:https://www.erlang.org/downloads 下载rabbitmq 官网环境:https://www.rabbitmq.com/download.html 注意:el7对应centos7,el8对应centos8…

centos下安装ftp-读取目录列表失败-

1.下载安装ftp服务器端和客户端 #1.安装yum -y install vsftpdyum -y install ftp #2.修改配置文件vim /etc/vsftpd.conflocal_enablesYESwrite_enableYESanonymous_enableYESanon_mkdir_write_enableYES //允许匿名用户在FTP上创建目录anon_upload_enableYES //允许匿名用户…

数值线性代数: 共轭梯度法

本文总结线性方程组求解的相关算法,特别是共轭梯度法的原理及流程。 零、预修 0.1 LU分解 设,若对于,均有,则存在下三角矩阵和上三角矩阵,使得。 设,若对于,均有,则存在唯一的下三…

kotlin 编写一个简单的天气预报app(四)

编写界面来显示返回的数据 用户友好性&#xff1a;通过界面设计和用户体验优化&#xff0c;可以使天气信息更易读、易理解和易操作。有效的界面设计可以提高用户满意度并提供更好的交互体验。 增加城市名字的TextView <TextViewandroid:id"id/textViewCityName"…

matlab使用教程(5)—矩阵定义和基本运算

本博客介绍如何在 MATLAB 中创建矩阵和执行基本矩阵计算。 MATLAB 环境使用矩阵来表示包含以二维网格排列的实数或复数的变量。更广泛而言&#xff0c;数组为向量、矩阵或更高维度的数值网格。MATLAB 中的所有数组都是矩形&#xff0c;在这种意义上沿任何维度的分量向量的长度…

【英杰送书第三期】Spring 解决依赖版本不一致报错 | 文末送书

Yan-英杰的主 悟已往之不谏 知来者之可追 C程序员&#xff0c;2024届电子信息研究生 目录 问题描述 报错信息如下 报错描述 解决方法 总结 【粉丝福利】 【文末送书】 目录&#xff1a; 本书特色&#xff1a; 问题描述 报错信息如下 Description:An attempt…

三、前端高德地图、测量两个点之前的距离

点击测距工具可以开启测量&#xff0c;再次点击关闭测量&#xff0c;清除地图上的点、连线、文字 再次点击测量工具的时候清除。 首先 上面的功能条河下面的地图我搞成了两个组件&#xff0c;他们作为兄弟组件存在&#xff0c;所以简单用js写了个事件监听触发的对象&#xff…

视频怎么加水印?这几种加水印方法非常简单

给视频加水印是一种保护知识产权的方法。水印是一种数字标记&#xff0c;可以包括作者的名称、品牌标识或其他信息&#xff0c;以便识别和追踪视频的来源。通过给视频加水印&#xff0c;能够有效地防止视频被盗用或未经授权的使用&#xff0c;让我们的知识产权得到更好的保护。…

fpga开发——蜂鸣器

蜂鸣器的原理 有源蜂鸣器和无源蜂鸣器 无源蜂鸣器利用电磁感应现象&#xff0c;为音圈接入交变电流后形成的电磁铁与永磁铁相吸或相斥而推动振膜发声&#xff0c;接入直流电只能持续推动振膜而无法产生声音&#xff0c;只能在接通或断开时产生声音。无源蜂鸣器的工作原理与扬声…

Ueditor 百度强大富文本Springboot 项目集成使用(包含上传文件和上传图片的功能使用)简单易懂,举一反三

Ueditor 百度强大富文本Springboot 项目集成使用 首先如果大家的富文本中不考虑图片或者附件的情况下&#xff0c;只考虑纯文本且排版的情况下我们可以直接让前端的vue来继承UEditor就可以啦。但是要让前端将那几个上传图片和附件的哪些功能给阉割掉&#xff01; 然后就是说如…

Cisco 路由器配置管理

大多数网络中断的最常见原因是错误的配置更改。对网络设备配置的每一次更改都伴随着造成网络中断、安全问题甚至性能下降的风险。计划外更改使网络容易受到意外中断的影响。 Network Configuration Manager 网络更改和配置管理 &#xff08;NCCM&#xff09;解决方案&#xff…

14个最强大的建筑设计AI工具

在整个行业中&#xff0c;建筑师在他们的创造性追求中正在拥抱一个新的合作伙伴&#xff1a;AI。 一旦受到重复和单调的困扰&#xff0c;建筑工人发现自己正处于数字革命的风口浪尖&#xff0c;其中比特和字节掌握着自动化和曾经难以想象的可能性的关键。 推荐&#xff1a;用 …

【Linux】网络基础之TCP协议

目录 &#x1f308;前言&#x1f338;1、基本概念&#x1f33a;2、TCP协议报文结构&#x1f368;2.1、源端口号和目的端口号&#x1f369;2.2、4位首部长度&#x1f36a;2.3、32位序号和确认序号&#xff08;重点&#xff09;&#x1f36b;2.4、16位窗口大小&#x1f36c;2.5、…

golang单元测试及mock总结

文章目录 一、前言1、单测的定位2、vscode中生成单测 二、构造测试case的注意事项1、项目初始化2、构造空interface{}3、构造结构体的time.Time类型4、构造json格式的test case 三、运行单测文件1、整体运行单测文件2、运行单个单测文件报错&#xff08;1&#xff09;command-l…

IO流(2)-缓冲流

1. 缓冲流的简单介绍 我们上贴说到了 FileInputStream&#xff0c;FileOutputStream&#xff0c;FileReader&#xff0c;FileWriter。 其实这四个流&#xff0c;我们通常把它叫做原始流&#xff0c;它们是比较偏底层的&#xff1b;而今天我们要说的四个缓冲流&#xff0c;如…

微服务 云原生:搭建 K8S 集群

为节约时间和成本&#xff0c;仅供学习使用&#xff0c;直接在两台虚拟机上模拟 K8S 集群搭建 踩坑之旅 系统环境&#xff1a;CentOS-7-x86_64-Minimal-2009 镜像&#xff0c;为方便起见&#xff0c;直接在 root 账户下操作&#xff0c;现实情况最好不要这样做。 基础准备 关…

vins调试的注意事项

1、摄像头的内参和畸变矫正系数 这个系数不对&#xff0c;没法做&#xff0c;因为下一步没法做对。这个会导致系统无法初始化。 2、对畸变的像素点&#xff0c;求得归一化坐标的方法 理解不同矫正模型的原理&#xff0c;确保矫正对了&#xff0c;得到z1平面的去畸变点。 3、摄…