深度学习故障诊断实战 | 数据预处理之创建Dataloader数据集

news/2024/5/15 21:36:32

前言

本期给大家分享介绍如何用Dataloader创建数据集

背景

在这里插入图片描述
在这里插入图片描述

示例代码

from torch import nn
import torch
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
import torch.functional as F
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
#####----绘制混淆矩阵图----#
from sklearn.metrics import confusion_matrix
from matplotlib.pylab import style
# style.use('ggplot')    
from matplotlib import rcParamsconfig = {"font.family": 'serif', # 衬线字体"font.size": 10, # 相当于小四大小"font.serif": ['SimSun'], # 宋体"mathtext.fontset": 'stix', # matplotlib渲染数学字体时使用的字体,和Times New Roman差别不大'axes.unicode_minus': False # 处理负号,即-号
}
rcParams.update(config)

加载数据、创建数据集

def data_read(file_path):""":fun: 读取cwru mat格式数据:param file_path: .mat文件路径  eg: r'D:.../01_示例数据/1750_12k_0.021-OuterRace3.mat':return accl_data: 读取到的加速度数据"""import scipy.io as sciodata = scio.loadmat(file_path)  # 加载mat数据data_key_list = list(data.keys())  # mat文件为字典类型,将key变为list类型accl_key = data_key_list[3]  # mat文件为字典类型,其加速度列在key_list的第4个accl_data = data[accl_key].flatten()  # 获取加速度信号,并展成1维数据accl_data = (accl_data-np.mean(accl_data))/np.std(accl_data) #Z-score标准化数据集return accl_data
def data_spilt(data, num_2_generate=20, each_subdata_length=1024):""":Desription:  将数据分割成n个小块。输入数据data采样点数是400000,分成100个子样本数据,每个子样本数据就是4000个数据点:param data:  要输入的数据:param num_2_generate:  要生成的子样本数量:param each_subdata_length: 每个子样本长度:return spilt_datalist: 分割好的数据,类型为2维list"""data = list(data)total_length = len(data)start_num = 0   # 子样本起始值end_num = each_subdata_length  # 子样本终止值step_length = int((total_length - each_subdata_length) / (num_2_generate - 1))  # step_length: 向前移动长度i = 1spilt_datalist = []while i <= num_2_generate:each_data = data[start_num: end_num]each_data = (each_data-np.mean(each_data))/(np.std(each_data)) # 做Z-score归一化spilt_datalist.append(each_data)start_num = 0 + i * step_length;end_num = each_subdata_length + i * step_lengthi = i + 1spilt_data_arr = np.array(spilt_datalist)return spilt_data_arr

划分子样本

data_base_dir = r'D:\22-学习记录\01_自己学习积累\03_基于CNN的轴承故障诊断(4分类)\dataset/train_data'
fault_type_list = os.listdir(data_base_dir)
file_path_list = list()
train_data = []
train_label = []
for fault_type in fault_type_list:file_name_list = os.listdir(os.path.join(data_base_dir, fault_type))for file_name in file_name_list:print(file_name)num_2_generate = 60   # 每个mat文件生成60个子样本if 'Normal' in file_name:num_2_generate = num_2_generate*6  # Normal状态,每个mat文件生成360个子样本file_path = os.path.join(data_base_dir, fault_type, file_name)fault_type2fault_label = {'BF':'0', 'OF':'1', 'IF':'2', 'Normal':'3'}  ##=========获取数据===========##data = data_read(file_path)   # 读取单个mat数据##======基于滑动窗方法划分获取更多子样本数据======##sub_data_list = data_spilt(data=data, num_2_generate=num_2_generate)train_label.extend(list(fault_type2fault_label[fault_type]*len(sub_data_list)))  # 训练标签 eg: ['0', '0', ...,'3',]train_data.extend(sub_data_list)   # 训练数据

输出结果:

1730_12k_0.007-Ball.mat
1730_12k_0.014-Ball.mat
1730_12k_0.021-Ball.mat
1730_48k_0.007-Ball.mat
1730_48k_0.014-Ball.mat
1730_48k_0.021-Ball.mat
...
1772_12k_0.021-OuterRace3.mat
1772_48k_0.007-OuterRace3.mat
1772_48k_0.014-OuterRace6.mat
1772_48k_0.021-OuterRace3.mat
data_base_dir = r'D:\22-学习记录\01_自己学习积累\03_基于CNN的轴承故障诊断(4分类)/dataset/test_data'
fault_type_list = os.listdir(data_base_dir)
file_path_list = list()
test_data = []
test_label = []
for fault_type in fault_type_list:file_name_list = os.listdir(os.path.join(data_base_dir, fault_type))for file_name in file_name_list:print(file_name)num_2_generate = 30          # 每个mat文件生成30个子样本 if 'Normal' in file_name:num_2_generate = num_2_generate*6   # Normal状态,每个mat文件生成180个子样本file_path = os.path.join(data_base_dir, fault_type, file_name)fault_type2fault_label = {'BF':'0', 'OF':'1', 'IF':'2', 'Normal':'3'}##=========获取数据===========##data = data_read(file_path)   # 读取单个mat数据##======基于滑动窗方法划分获取更多子样本数据======##sub_data_list = data_spilt(data=data, num_2_generate=num_2_generate)test_label.extend(list(fault_type2fault_label[fault_type]*len(sub_data_list)))   # 测试标签 eg: ['0', '0', ...,'3',]test_data.extend(sub_data_list)  # 测试数据

输出结果:

1797_12k_0.007-Ball.mat
1797_12k_0.014-Ball.mat
1797_12k_0.021-Ball.mat
1797_48k_0.007-Ball.mat
...
1797_12k_0.021-OuterRace3.mat
1797_48k_0.007-OuterRace3.mat
1797_48k_0.014-OuterRace6.mat
1797_48k_0.021-OuterRace3.mat

创建dataloader

from torch.utils.data import DataLoader
from torch.utils.data import TensorDatasettrain_data = np.array(train_data)  #训练数据
train_label = np.array((train_label), dtype=int)  #训练标签 array([0, 0, 0, ..., 1, 1, 1])train_x = torch.from_numpy(train_data).type(torch.float32)  ##torch.Size([2880, 1024])
train_y = torch.from_numpy(train_label).type(torch.LongTensor) ##torch.Size([2880])
train_x = torch.unsqueeze(train_x, dim=1) ##扩展成3维 torch.Size([2880, 1, 1024])test_data = np.array(test_data)  #测试数据
test_label = np.array((test_label), dtype=int)  #测试标签array([0, 0, 0, ..., 1, 1, 1])test_x = torch.from_numpy(test_data).type(torch.float32)  ##torch.Size([720, 1024])
test_y = torch.from_numpy(test_label).type(torch.LongTensor)  ##torch.Size([720])
test_x = torch.unsqueeze(test_x, dim=1) ##torch.Size([720, 1, 1024])

看看train_data的维度

train_data.shape
# 输出结果
(4320, 1024)
#创建dataset
from sklearn.model_selection import train_test_split
train_ds = TensorDataset(train_x, train_y)
train_ds,valid_ds = train_test_split(train_ds, test_size = 0.2,random_state = 42) #将训练集分隔维训练集和测试集,分割比例5:1
test_ds = TensorDataset(test_x, test_y)
#创建dataloader
batch_size = 64   # 批次大小
train_dl = DataLoader(dataset = train_ds, batch_size = batch_size, shuffle=True)  #shuffle是进行乱序
valid_dl = DataLoader(dataset = valid_ds, shuffle=True, batch_size = batch_size)
test_dl = DataLoader(dataset = test_ds, batch_size = batch_size, shuffle=False)
for i, (imgs, labels) in enumerate(train_dl):print(i, imgs, labels)
## 输出结果
0 
tensor([[[-0.3515,  0.3796,  1.1076,  ..., -1.5466, -1.2220, -0.3242]],[[ 1.3776,  1.4640,  0.7423,  ...,  0.1370,  0.3632,  0.6326]],[[-0.4793, -0.6978, -0.5840,  ..., -0.6917, -1.3257, -1.4273]],...,[[-0.2387, -0.5382, -1.1767,  ...,  0.8166,  0.7749, -2.3905]],[[-0.9560, -1.3880, -1.7271,  ...,  0.4720,  0.2392,  0.1732]],[[ 0.2588,  0.0780, -0.5720,  ..., -0.0647,  0.4554,  1.2545]]]) tensor([3, 3, 3, 1, 2, 3, 3, 0, 2, 3, 2, 1, 3, 2, 2, 2, 3, 2, 2, 0, 2, 1, 2, 2,1, 2, 3, 3, 1, 2, 1, 0, 2, 1, 3, 0, 1, 3, 1, 3, 3, 1, 1, 3, 2, 0, 0, 1,0, 0, 0, 2, 1, 0, 0, 3, 1, 3, 3, 0, 2, 2, 0, 3])...

http://www.mrgr.cn/p/21644804

相关文章

Sysbench安装与使用

Sysbench Sysbench安装curl -s https://packagecloud.io/install/repositories/akopytov/sysbench/script.deb.sh | sudo bash sudo apt -y install sysbench sysbench --versionSysbench使用 内存测试 查看内存帮助信息 sysbench memory help # 查看内存帮助信息 sysbench 1.0…

为什么Python不适合写游戏?

知乎上有热门个问题&#xff1a;Python 能写游戏吗&#xff1f;有没有什么开源项目&#xff1f; Python可以开发游戏&#xff0c;但不是好的选择 Python作为脚本语言&#xff0c;一般很少用来开发游戏&#xff0c;但也有不少大型游戏有Python的身影&#xff0c;比如&#xff1…

数据可视化是怎样帮助网络安全建设的?

在当今数字化时代,网络安全已成为各大企业乃至国家安全的重要组成部分。随着网络攻击的日益复杂和隐蔽,传统的网络安全防护措施已难以满足需求,急需新型的解决方案以增强网络防护能力。数据可视化技术,作为一种将复杂数据转换为图形或图像表示的方法,正在成为提升网络安全…

Java八股文(JVM)

Java八股文のJVM JVM JVM 什么是Java虚拟机&#xff08;JVM&#xff09;&#xff1f; Java虚拟机是一个运行Java字节码的虚拟机。 它负责将Java程序翻译成机器代码并执行。 JVM的主要组成部分是什么&#xff1f; JVM包括以下组件&#xff1a; ● 类加载器&#xff08;ClassLoa…

Sentinel 对分布式服务进行流量控制

可以下载sentinel的jar包,用java -jar命令直接启动默认端口就是8080,这里随便写一下演示,其他修改还是直接看Sentinel网站吧 java -jar -Dserver.port=8080 sentinel的jar包名.jar对需要进行流量控制的服务进行依赖导入(这个依赖直接在父工程引入似乎无效,不知为何) <…

pod 入门

pod 入门 入门级的放通网络 我们使用 nginx 创建的pod,必须要能本地访问它才行,在不引入复杂的网络环境的前提下,我们可以通过映射端口或者映射本机网络的方式进行访问。 apiVersion: v1 kind: Pod metadata:name: web01labels:app: nginxversion: "1.24"namespac…

verilog设计-cdc:多比特信号跨时钟域(DMUX)

一、前言 多比特一般为数据&#xff0c;其在跨时钟域传输的过程中有多种处理方式&#xff0c;比如DMUX&#xff0c;异步FIFO&#xff0c;双口RAM&#xff0c;握手处理。本文介绍通过DMUX的方式传输多比特信号。 二、DMUX同步跨时钟域数据 dmux表示数据分配器&#xff0c;该方…

由浅到深认识Java语言(25):正则表达式

该文章Github地址&#xff1a;https://github.com/AntonyCheng/java-notes 在此介绍一下作者开源的SpringBoot项目初始化模板&#xff08;Github仓库地址&#xff1a;https://github.com/AntonyCheng/spring-boot-init-template & CSDN文章地址&#xff1a;https://blog.c…

殊荣双至,天翼云边缘计算再获两项大奖!

近日,全球边缘计算大会北京站在新世界大酒店成功召开。大会公布了“2023金边奖”评选结果,天翼云斩获“最佳智能边缘云服务商”“最佳边缘安全加速平台”两项大奖。天翼云边缘计算产品专家熊瑶、天翼云边缘安全产品专家杜茜参加会议并分别发表演讲,分享了天翼云边缘计算发展…

多所大学在用的网速测试网页工具librespeed/speedtest精简版

亲测比较准&#xff0c;所以精简自用并分享。 https://github.com/librespeed/speedtest/ 精简到15KB版本&#xff1a;https://download.csdn.net/download/YUJIANYUE/89049070

echarts 水球图

本文章出自: 作者:你不知道的巨蟹 文章:vue中如何使用基于 echarts 的 echarts-liquidfill 插件开发水球图前言 echarts4 官网:https://echarts.apache.org/v4/zh/option.html#series-scatter.coordinateSystemecharts5 官网:https://echarts.apache.org/echarts-liquidfi…

# Apache SeaTunnel 究竟是什么?

作者 | Shawn Gordon 翻译 | Debra Chen 原文链接 | What the Heck is Apache SeaTunnel? 我在2023年初开始注意到Apache SeaTunnel的相关讨论,一直低调地关注着。该项目始于2017年,最初名为Waterdrop,在Apache DolphinScheduler的创建者的贡献下发展起来,后者支持SeaTunn…

MybatisPlus多参数分页查询,黑马程序员SpringBoot3+Vue3教程第22集使用MP代替pagehelper

前言: 视频来源1:黑马程序员SpringBoot3+Vue3全套视频教程,springboot+vue企业级全栈开发从基础、实战到面试一套通关 视频来源2:黑马程序员最新MybatisPlus全套视频教程,4小时快速精通mybatis-plus框架 创作理由:网上MP实现分页查询功能的帖子易读性太差,具体实现看下面…

华为云使用指南02

5.​​使用GitLab进行团队及项目管理​​ GitLab旨在帮助团队进行项目开发协作&#xff0c;为软件开发和运营生命周期提供了一个完整的DevOps方案。GitLab功能包括&#xff1a;项目源码的管理、计划、创建、验证、集成、发布、配置、监视和保护应用程序等。该镜像基于CentOS操…

探索人工智能与强化学习:从基础原理到应用前景

人工智能(Artificial Intelligence,AI)是当今科技领域的热点话题,而强化学习(Reinforcement Learning,RL)作为其重要分支,在推动着智能系统向前迈进。本文将深入探讨AI与强化学习的基本原理、关键技术以及未来的应用前景,以期为读者提供全面的认识和理解。强化学习的基…

AI大模型学习

文章目录 每日一句正能量前言AI大模型学习的理论基础AI大模型的训练与优化AI大模型在特定领域的应用AI大模型学习的伦理与社会影响未来发展趋势与挑战后记 每日一句正能量 其实许多波折不过是成功道上的荆棘路&#xff0c;纵然今天不如意&#xff0c;但我们还有未来。 前言 随…

Haproxy2.8.1+Lua5.1.4部署,haproxy.cfg配置文件详解和演示

目录 一.快速安装lua和haproxy 二.配置haproxy的配置文件 三.配置haproxy的全局日志 四.测试负载均衡、监控和日志效果 五.server常用可选项 1.check 2.weight 3.backup 4.disabled 5.redirect prefix和redir 6.maxconn 六.调度算法 1.静态 2.动态 一.快速安装lu…

C# 高级文件操作与异步编程探索(初步)

文章目录 文本文件的读写探秘StreamReader 类深度剖析StreamWriter 类细节解读编码和中文乱码的解决方案 二进制文件的读写BinaryReader 类全面解析BinaryWriter 类深度探讨 异步编程与C#的未来方向同步与异步&#xff1a;本质解读Task 的神奇所在async/await 的魔法 在现代编程…

Linux下TCP/IP编程--TCP实战

之前尝试过windows下的简单TCP客户端服务器编写,这次尝试下一下Linux环境下的TCP 客户端代码 #include <stdio.h> #include <stdlib.h> #include <string.h> #include <unistd.h> #include <sys/socket.h> #include <netinet/in.h> #incl…

短视频文案提取的简单实现

​过春风十里,尽荠麦青青。春天总是让人舒坦,而今年的三月,也因为与媳妇结婚十年,显得格外不同。两人奢侈的请了一天假,瞒着孩子,重游西湖,去寻找13年前的冰棍店(给当时还是同事的她买了最贵的一个雪糕-8元),去寻找13年前卖红豆钥匙扣的大爷(她送我了一个绿豆的钥匙…