当前位置: 首页 > news >正文

使用mnist数据集和LeakyReLU高级激活函数训练神经网络示例代码

在这里插入图片描述

一、概述

神经网络中的激活函数是用于增加网络的非线性特性的函数,没有激活函数,神经网络将仅仅是一个线性模型,无法解决复杂的非线性问题。激活函数的选择对神经网络的性能有很大的影响。

基础激活函数是神经网络中使用较早、较为简单的激活函数,主要包括Sigmoid、Tanh、ReLU、ELU、SELU等,具体请参考老猿在CSDN博文《神经网络激活函数列表大全及keras中的激活函数定义 https://blog.csdn.net/LaoYuanPython/article/details/142731106 》的介绍。

随着深度学习的发展,为了解决基础激活函数的一些问题(如梯度消失、梯度爆炸、计算复杂度等),研究者们提出了一些高级激活函数,如Leaky ReLU、Parametric ReLU (PReLU)等。高级激活函数是在基础激活函数的基础上发展起来的,继承了基础激活函数的某些特性,同时引入了新的机制来改进性能,通常是为了解决基础激活函数在实际应用中遇到的问题。

高级激活函数通常比基础激活函数更复杂,可能包含更多的参数或计算步骤,通常旨在解决基础激活函数的某些限制,如梯度消失或激活函数的非单调性。基础激活函数适用于大多数情况,但高级激活函数可能在特定任务或网络结构中表现更好。
关于高级激活函数请参考老猿博文《神经网络高级激活函数大全及keras中的函数定义 https://blog.csdn.net/LaoYuanPython/article/details/142742719》

选择哪种激活函数通常取决于具体任务的需求、数据的特性以及实验的结果。在实践中,可能需要尝试不同的激活函数来找到最适合特定问题的激活函数。

二、keras.datasets.mnist介绍

keras.datasets.mnist 是 Keras 库中的一个数据集模块,它提供了对 MNIST 数据集的访问。MNIST 数据集是一个广泛使用的手写数字识别数据集,它包含了60,000个训练样本和10,000个测试样本,每个样本都是一个28x28像素的灰度图像,以及对应的标签(0到9的数字)。

以下是 keras.datasets.mnist 的一些基本用法:

  1. 加载数据集:使用 load_data() 函数可以加载 MNIST 数据集。这个函数会返回两个元组,分别代表训练集和测试集,每个元组包含图像数据和标签;
  2. 数据预处理:加载的数据通常需要进行预处理,比如归一化,以提高模型的性能;
  3. 构建模型:使用 Keras 构建一个神经网络模型来训练和测试数据;
  4. 训练模型:使用训练数据训练模型;
  5. 评估模型:使用测试数据评估模型的性能。
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, LeakyReLU
from keras.optimizers import Adam# 加载MNIST数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()# 将像素值归一化到0-1范围内
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255# 将图像展平为784个特征的一维数组
x_train = x_train.reshape((x_train.shape[0], -1))
x_test = x_test.reshape((x_test.shape[0], -1))# 将标签转换为one-hot编码
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)# 创建一个顺序模型
model = Sequential()# 添加一个全连接层
model.add(Dense(512, input_shape=(784,)))
model.add(LeakyReLU(alpha=0.01))  # 添加LeakyReLU激活函数# 继续构建模型
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.01))  # 再次添加LeakyReLU激活函数
model.add(Dense(128))
model.add(LeakyReLU(alpha=0.01))  # 添加LeakyReLU激活函数
model.add(Dense(10, activation='softmax'))  # 输出层使用softmax激活函数# 编译模型
model.compile(optimizer=Adam(), loss='categorical_crossentropy', metrics=['accuracy'])# 模型总结
model.summary()# 训练模型
model.fit(x_train, y_train, epochs=5, batch_size=128, validation_data=(x_test, y_test))# 评估模型
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

执行后的输出:

Model: "sequential"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================dense (Dense)               (None, 512)               401920    leaky_re_lu (LeakyReLU)     (None, 512)               0         dense_1 (Dense)             (None, 256)               131328    leaky_re_lu_1 (LeakyReLU)   (None, 256)               0         dense_2 (Dense)             (None, 128)               32896     leaky_re_lu_2 (LeakyReLU)   (None, 128)               0         dense_3 (Dense)             (None, 10)                1290      =================================================================
Total params: 567,434
Trainable params: 567,434
Non-trainable params: 0
_________________________________________________________________
Epoch 1/5
469/469 [==============================] - 9s 17ms/step - loss: 0.2330 - accuracy: 0.9321 - val_loss: 0.1017 - val_accuracy: 0.9694
Epoch 2/5
469/469 [==============================] - 7s 16ms/step - loss: 0.0851 - accuracy: 0.9743 - val_loss: 0.0896 - val_accuracy: 0.9710
Epoch 3/5
469/469 [==============================] - 7s 16ms/step - loss: 0.0554 - accuracy: 0.9822 - val_loss: 0.1009 - val_accuracy: 0.9680
Epoch 4/5
469/469 [==============================] - 7s 16ms/step - loss: 0.0406 - accuracy: 0.9871 - val_loss: 0.0709 - val_accuracy: 0.9791
Epoch 5/5
469/469 [==============================] - 8s 16ms/step - loss: 0.0290 - accuracy: 0.9906 - val_loss: 0.0797 - val_accuracy: 0.9789
Test loss: 0.0796799287199974
Test accuracy: 0.9789000153541565

三、小结

本文介绍了使用mnist数据集和LeakyReLU高级激活函数训练神经网络示例代码,这个示例代码使用全连接层,激活函数在隐藏层使用的是LeakyReLU,输出层使用的是softmax。这个神经网络是比较简单的神经网络,根据训练后的测试情况,其识别精度接近98%。

更多人工智能知识学习请关注专栏《零基础机器学习入门》后续的文章。

更多人工智能知识学习过程中可能遇到的疑难问题及解决办法请关注专栏《机器学习疑难问题集》后续的文章。

写博不易,敬请支持:

如果阅读本文于您有所获,敬请点赞、评论、收藏,谢谢大家的支持!

关于老猿的付费专栏

  1. 付费专栏《https://blog.csdn.net/laoyuanpython/category_9607725.html 使用PyQt开发图形界面Python应用》专门介绍基于Python的PyQt图形界面开发基础教程,对应文章目录为《 https://blog.csdn.net/LaoYuanPython/article/details/107580932 使用PyQt开发图形界面Python应用专栏目录》;
  2. 付费专栏《https://blog.csdn.net/laoyuanpython/category_10232926.html moviepy音视频开发专栏 )详细介绍moviepy音视频剪辑合成处理的类相关方法及使用相关方法进行相关剪辑合成场景的处理,对应文章目录为《https://blog.csdn.net/LaoYuanPython/article/details/107574583 moviepy音视频开发专栏文章目录》;
  3. 付费专栏《https://blog.csdn.net/laoyuanpython/category_10581071.html OpenCV-Python初学者疑难问题集》为《https://blog.csdn.net/laoyuanpython/category_9979286.html OpenCV-Python图形图像处理 》的伴生专栏,是笔者对OpenCV-Python图形图像处理学习中遇到的一些问题个人感悟的整合,相关资料基本上都是老猿反复研究的成果,有助于OpenCV-Python初学者比较深入地理解OpenCV,对应文章目录为《https://blog.csdn.net/LaoYuanPython/article/details/109713407 OpenCV-Python初学者疑难问题集专栏目录 》
  4. 付费专栏《https://blog.csdn.net/laoyuanpython/category_10762553.html Python爬虫入门 》站在一个互联网前端开发小白的角度介绍爬虫开发应知应会内容,包括爬虫入门的基础知识,以及爬取CSDN文章信息、博主信息、给文章点赞、评论等实战内容。

前两个专栏都适合有一定Python基础但无相关知识的小白读者学习,第三个专栏请大家结合《https://blog.csdn.net/laoyuanpython/category_9979286.html OpenCV-Python图形图像处理 》的学习使用。

对于缺乏Python基础的同仁,可以通过老猿的免费专栏《https://blog.csdn.net/laoyuanpython/category_9831699.html 专栏:Python基础教程目录)从零开始学习Python。

如果有兴趣也愿意支持老猿的读者,欢迎购买付费专栏。

老猿Python,跟老猿学Python!

☞ ░ 前往老猿Python博文目录 https://blog.csdn.net/LaoYuanPython ░

http://www.mrgr.cn/news/49390.html

相关文章:

  • Springboot 使用【过滤器】实现在请求到达 Controller 之前修改请求体参数和在结果返回之前修改响应体
  • 25.1 降低采集资源消耗的收益和无用监控指标的判定依据
  • 7-2 试试多线程
  • 探索C#编程基础:从输入验证到杨辉三角的生成
  • Java——数组的定义与使用
  • AndroidLogger 使用问题
  • 大厂面试真题-AQS中节点的入队时机有哪些
  • React入门 9:React Router
  • 【汇编语言】寄存器(CPU工作原理)(七)—— 查看CPU和内存,用机器指令和汇编指令编程
  • 多语言网站的设计的探索——安企CMS多语言功能的实现记录
  • Python字符串格式
  • 鸿蒙开发 三十七 ArkTs类 class
  • HAL库常用的函数:
  • oracle存储过程
  • 位定时结构
  • 面试真题 | 百度C++研发工程师面经
  • 动态规划最大子段和讲解和【题解】——最大子段和
  • springcloud之服务提供与负载均衡调用 Eureka
  • 『香驰控股』上线采购数字化平台,企企通助推农业产业化国家重点龙头提升供应链价值
  • AttributeError: ‘str‘ Object Has No Attribute ‘x‘:字符串对象没有属性x的完美解决方法