机器学习-10-基于paddle实现神经网络

news/2024/5/4 7:57:56

文章目录

    • 总结
    • 参考
    • 本门课程的目标
    • 机器学习定义
    • 第一步:数据准备
    • 第二步:定义网络
    • 第三步:训练网络
    • 第四步:测试训练好的网络

总结

本系列是机器学习课程的系列课程,主要介绍基于paddle实现神经网络。

参考

MNIST 训练_副本

本门课程的目标

完成一个特定行业的算法应用全过程:

懂业务+会选择合适的算法+数据处理+算法训练+算法调优+算法融合
+算法评估+持续调优+工程化接口实现

机器学习定义

关于机器学习的定义,Tom Michael Mitchell的这段话被广泛引用:
对于某类任务T性能度量P,如果一个计算机程序在T上其性能P随着经验E而自我完善,那么我们称这个计算机程序从经验E中学习
在这里插入图片描述

使用MNIST数据集训练和测试模型。

第一步:数据准备

MNIST数据集

import paddle
from paddle.vision.datasets import MNIST
from paddle.vision.transforms import ToTensortrain_dataset = MNIST(mode='train', transform=ToTensor())
test_dataset = MNIST(mode='test', transform=ToTensor())

展示数据集图片

import matplotlib.pyplot as plt
import numpy as nptrain_data0, train_label_0 = train_dataset[0][0], train_dataset[0][1]
train_data0 = train_data0.reshape([28, 28])
plt.figure(figsize=(2, 2))
plt.imshow(train_data0, cmap=plt.cm.binary)
print('train_data0 的标签为: ' + str(train_label_0))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingfrom collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingfrom collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingfrom collections import Sized
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingif isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingreturn list(data) if isinstance(data, collections.MappingView) else data
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_max = np.asscalar(a_max.astype(scaled_dtype))train_data0 的标签为: [5]

第二步:定义网络

import paddle
import paddle.nn.functional as F
from paddle.nn import Conv2D, MaxPool2D, Linearclass MyModel(paddle.nn.Layer):def __init__(self):super(MyModel, self).__init__()self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)self.conv2 = Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)self.linear1 = Linear(in_features=16*5*5, out_features=120)self.linear2 = Linear(in_features=120, out_features=84)self.linear3 = Linear(in_features=84, out_features=10)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.max_pool1(x)x = F.relu(x)x = self.conv2(x)x = self.max_pool2(x)x = paddle.flatten(x, start_axis=1, stop_axis=-1)x = self.linear1(x)x = F.relu(x)x = self.linear2(x)x = F.relu(x)x = self.linear3(x)return x

模型可视化

import paddle
mnist = MyModel()
paddle.summary(mnist, (1, 1, 28, 28))
---------------------------------------------------------------------------Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================Conv2D-1       [[1, 1, 28, 28]]      [1, 6, 28, 28]          156      MaxPool2D-1     [[1, 6, 28, 28]]      [1, 6, 14, 14]           0       Conv2D-2       [[1, 6, 14, 14]]     [1, 16, 10, 10]         2,416     MaxPool2D-2    [[1, 16, 10, 10]]      [1, 16, 5, 5]            0       Linear-1          [[1, 400]]            [1, 120]           48,120     Linear-2          [[1, 120]]            [1, 84]            10,164     Linear-3          [[1, 84]]             [1, 10]              850      
===========================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.06
Params size (MB): 0.24
Estimated Total Size (MB): 0.30
---------------------------------------------------------------------------{'total_params': 61706, 'trainable_params': 61706}

第三步:训练网络

import paddle
from paddle.metric import Accuracy
from paddle.static import InputSpecinputs = InputSpec([None, 784], 'float32', 'x')
labels = InputSpec([None, 10], 'float32', 'x')# 用Model封装模型
model = paddle.Model(MyModel(), inputs, labels)# 定义损失函数
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())# 配置模型
model.prepare(optim, paddle.nn.CrossEntropyLoss(), Accuracy())
# 训练模型
model.fit(train_dataset, test_dataset, epochs=3, batch_size=64, save_dir='mnist_checkpoint', verbose=1)
The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/3
step 938/938 [==============================] - loss: 0.0208 - acc: 0.9456 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/0
Eval begin...
step 157/157 [==============================] - loss: 0.0041 - acc: 0.9777 - 19ms/step          
Eval samples: 10000
Epoch 2/3
step 938/938 [==============================] - loss: 0.0021 - acc: 0.9820 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/1
Eval begin...
step 157/157 [==============================] - loss: 2.1037e-04 - acc: 0.9858 - 19ms/step      
Eval samples: 10000
Epoch 3/3
step 938/938 [==============================] - loss: 0.0126 - acc: 0.9876 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/2
Eval begin...
step 157/157 [==============================] - loss: 4.7168e-04 - acc: 0.9884 - 19ms/step      
Eval samples: 10000
save checkpoint at /home/aistudio/mnist_checkpoint/final

第四步:测试训练好的网络

import paddle
import numpy as np
import matplotlib.pyplot as plt
from paddle.metric import Accuracy
from paddle.static import InputSpecinputs = InputSpec([None, 784], 'float32', 'x')
labels = InputSpec([None, 10], 'float32', 'x')
model = paddle.Model(MyModel(), inputs, labels)
model.load('./mnist_checkpoint/final')
model.prepare(optim, paddle.nn.CrossEntropyLoss(), Accuracy())# results = model.evaluate(test_dataset, batch_size=64, verbose=1)
# print(results)results = model.predict(test_dataset, batch_size=64)test_data0, test_label_0 = test_dataset[0][0], test_dataset[0][1]
test_data0 = test_data0.reshape([28, 28])
plt.figure(figsize=(2,2))
plt.imshow(test_data0, cmap=plt.cm.binary)print('test_data0 的标签为: ' + str(test_label_0))
print('test_data0 预测的数值为:%d' % np.argsort(results[0][0])[0][-1])
Predict begin...
step 157/157 [==============================] - 27ms/step          
Predict samples: 10000
test_data0 的标签为: [7]
test_data0 预测的数值为:7/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_max = np.asscalar(a_max.astype(scaled_dtype))

在这里插入图片描述



http://www.mrgr.cn/p/75712536

相关文章

Linux:进程状态

Linux:进程状态 进程状态运行状态R状态 阻塞状态S状态D状态T状态t状态 挂起状态僵尸进程 & 孤儿进程X状态Z状态孤儿进程 进程状态 当一个可执行程序,被载入内存,获得自己的PCB,那么其就可以变成一个进程。也许你学习过一些进…

java高校办公室行政事务管理系统设计与实现(springboot+mysql源码+文档)

风定落花生,歌声逐流水,大家好我是风歌,混迹在java圈的辛苦码农。今天要和大家聊的是一款基于springboot的闲一品交易平台。项目源码以及部署相关请联系风歌,文末附上联系信息 。 项目简介: 基于mvc的高校办公室行政…

AJAX——Promise-链式调用

1.Promise链式调用 概念:依靠then()方法会返回一个新生成的Promise对象特性,继续串联下一环任务,知道结束 细节:then()回调函数中的返回值,会影响新生成的Promise对象最终状态和结果 好处:通过链式调用&…

【Java基础】23.接口

文章目录 一、接口的概念1.接口介绍2.接口与类相似点3.接口与类的区别4.接口特性5.抽象类和接口的区别 二、接口的声明三、接口的实现四、接口的继承五、接口的多继承六、标记接口 一、接口的概念 1.接口介绍 接口(英文:Interface)&#xf…

鸿蒙OpenHarmony【小型系统编写“Hello World”程序】 (基于Hi3516开发板)

编写“Hello World”程序 下方将展示如何在单板上运行第一个应用程序,其中包括新建应用程序、编译、烧写、运行等步骤,最终输出“Hello World!”。 前提条件 已参考[创建工程并获取源码],创建Hi3516开发板的源码工程。 鸿蒙开发…

如何处理Keil uVision5注释无法输入汉字且输入汉字变成问号的问题

好久没用KEIL,今天在注释中出现无法输入汉字的情况,且输入或粘贴的汉字都变成了问号,解决方法很简单,将General Editor Settings: Encoding:设置为Chinese GB2312(Simplified)即可(出现问号的当前设置是Encode in ANSI…

海外云手机怎么解决tiktok运营难题?

最近打算做TikTok的商家越来越多了,而做TikTok的第一步就面临如何养号、涨粉的困境,本文将介绍如何通过海外云手机轻松解决这些问题。 早期大家用的比较多的,是真机科学上网的方法。但是这种方法,需要自己搭建海外环境&#xff0c…

day04 51单片机-矩阵按键

1 矩阵按键 1.1 需求描述 本案例实现以下功能:按下矩阵按键SW5到SW20,数码管会显示对应的按键编号。 1.2 硬件设计 1.2.1 硬件原理图 1.2.2 矩阵按键原理 1.3软件设计 1)Int_MatrixKeyboard.h 在项目的Int目录下创建Int_MatrixKeyboard…

【创建型模式】单例模式

一、单例模式概述 单例模式的定义:又叫单件模式,确保一个类只有一个实例,并提供一个全局访问点。(对象创建型) 要点: 1.某个类只能有一个实例;2.必须自行创建这个实例;3.必须自行向整…

探索RadSystems:低代码开发的新选择(二)

系列文章目录 探索RadSystems:低代码开发的新选择(一)🚪 文章目录 系列文章目录前言一、RadSystems Studio是什么?二、用户认证三、系统角色许可四、用户记录管理五、时间戳记录总结 前言 在数字化时代,低…

HarmonyOs开发:导航tabs组件封装与使用

前言 主页的底部导航以及页面顶部的切换导航,无论哪个系统,哪个App,都是最常见的功能之一,虽然说在鸿蒙中有现成的组件tabs可以很快速的实现,但是在使用的时候,依然有几个潜在的问题存在,第一&a…

CountDownLatch倒计时器源码解读与使用

🏷️个人主页:牵着猫散步的鼠鼠 🏷️系列专栏:Java全栈-专栏 🏷️个人学习笔记,若有缺误,欢迎评论区指正 目录 1. 前言 2. CountDownLatch有什么用 3. CountDownLatch底层原理 3.1. count…

不如你把我杀了吧 | 绘制自定义的 3D 地图

如何根据自己的json数据绘制类似这种地图,仅供参考 1、准备数据。 因为自定义,所以全部的数据都来源自己。我们需要准备地图数据(包括但不限于地图轮廓数据,点数据) 这里我的数据使用的是arcgis导出json数据,因此数据格式足够规范,这省去了很多的麻烦。 2、导入相关库、…

【智能算法】寄生捕食算法(PPA)原理及实现

目录 1.背景2.算法原理2.1算法思想2.2算法过程 3.结果展示4.参考文献 1.背景 2020年,AAA Mohamed等人受到自然界乌鸦-布谷鸟-猫寄生系统启发,提出了寄生捕食算法(Parasitism – Predation Algorithm, PPA)。 2.算法原理 2.1算法…

面向对象设计与分析40讲(25)中介模式、代理模式、门面模式、桥接模式、适配器模式

文章目录 门面模式代理模式中介模式 之所以把这几个模式放到一起写,是因为它们的界限比较模糊,结构上没有明显的差别,差别只是语义上。 这几种模式在结构上都类似: 代理将原本A–>C的直接调用变成: A–>B–>…

针对窗口数量多导致窗口大小显示受限制的问题,使用滚动条控制窗口

建议:首先观察结果展示,判断是否可以满足你的需求。 目录 1. 问题分析 2. 解决方案 2.1 界面设计 2.2 生成代码 2.3 源码实现 3. 结果展示 1. 问题分析 项目需要显示的窗口数量颇多,主界面中,如果一次性显示全部窗口&#x…

30 消息队列

原理 操作系统可以通过页表映射在共享区创建一块共享内存,也可以申请一个队列。A进程和B进程可以向这个队列发送数据块,两个进程接收数据块来通信 函数 申请数据块 参数中的key来自于ftok函数 删除消息队列 同样消息队列也有数据结构管理&#xff…

service类型及功能简介+pod类型

功能定义及简介: 定义:K8s 中的Service是一种抽象,用于定义一组Pod的逻辑集合,并为它们提供统一的网络入口。Service充当了Pod的负载平衡器和服务发现器,为应用程序提供了稳定的网络地址,使得应用程序可以访问与之关联的Pod而无需了解其具体的IP地址或端口。 功能:Servi…

记一次自己的蠢币操作

今天在Hyper-V上安装了一个Ubuntu 22.04的虚拟机,折腾了老半天安装完以后随便搜了一篇换apt源的博客,点击查看镜像源(Ubuntu 20.04) #添加清华源 deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal main restricted universe multiverse deb-src https://mirrors.tuna…

Dubbo应用可观测性升级指南与踩坑记录

应用从dubbo-3.1.*升级到dubbo-*:3.2.*最新稳定版本,提升应用的可观测性和度量数据准确性。 1. dubbo版本发布说明(可不关注) dubbo版本发布 https://github.com/apache/dubbo/releases 【升级兼容性】3.1 升级到 3.2 2. 应用修改点 应用一般只需要升级dubbo-s…