当前位置: 首页 > news >正文

神经网络分类任务

代码:
 


from pathlib import Path
import requests
#导入 Path 类用于处理文件路径,requests 库用于从网络下载文件。
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"PATH.mkdir(parents=True, exist_ok=True)# 这行代码调用PATH对象的mkdir方法来创建目录。mkdir是创建文件夹的命令
# parents = True表示如果父目录(这里是data目录)不存在,则先创建父目录。
# exist_ok = True表示如果目录已经存在,则不会引发异常。这个设置在脚本可能多次运行并且不需要重复创建已有目录的情况下非常有用。
URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"if not (PATH / FILENAME).exists():#如果文件或目录存在,exists() 方法返回 True;如果不存在,则返回 Falsecontent = requests.get(URL + FILENAME).content(PATH / FILENAME).open("wb").write(content)
import pickle#这行代码导入pickle模块,pickle模块用于序列化和反序列化 Python 对象。在这里是为了后续从下载的文件中加载数据。
import gzip#这行代码导入gzip模块,因为下载的文件是.gz压缩格式的,需要使用gzip模块来解压缩
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
#这行代码使用gzip.open以二进制读取模式(rb)打开PATH路径下的FILENAME文件(先解压缩)。as_posix()方法是将Path对象转换为适合gzip.open使用的字符串形式。然后使用with语句来确保文件在使用后正确关闭,将打开的文件对象赋值给f((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
#在with语句块内,这行代码使用pickle.load从文件对象f中加载数据。数据被加载并解包为((x_train, y_train), (x_valid, y_valid), _)的形式,其中_表示可能存在的第三个元组中的未使用部分(可能是测试集数据,但这里没有使用它的名称)。
#这里使用encoding="latin-1"是因为文件可能是以这种编码方式存储数据的
print(x_train.shape)#784是mnist数据集每个样本的像素点个数
import torch
x_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid)
)#这里使用了 Python 的 map 函数,将 torch.tensor 函数应用到 x_train、y_train、x_valid、y_valid 这四个对象上,
# 将它们转换为 PyTorch 张量。这样做是为了在 PyTorch 框架下进行后续的深度学习操作,如模型训练等。
n, c = x_train.shape
x_train, x_train.shape, y_train.min(), y_train.max()
#这行代码计算了一些关于训练集数据的信息,包括 x_train 本身、x_train 的形状、y_train 中的最小值和最大值。不过,这里只是计算了这些值,没有对结果进行任何处理或保存。
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())
import torch.nn.functional as F
# weights = torch.randn([784, 10], dtype = torch.float, requires_grad = True)
# bias = torch.zeros(10, dtype = torch.float, requires_grad = True)
loss_func = F.cross_entropy
def model(xb):return xb.mm(weights) + bias
torch.randn([784, 10], dtype = torch.float,  requires_grad = True)
bs = 64
xb = x_train[0:bs]  # a mini-batch from x
yb = y_train[0:bs]
weights = torch.randn([784, 10], dtype = torch.float,  requires_grad = True)
bs = 64
bias = torch.zeros(10, requires_grad=True)
print(loss_func(model(xb), yb))
from torch import nnclass Mnist_NN(nn.Module):def __init__(self):super().__init__()self.hidden1 = nn.Linear(784, 128)self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256, 10)def forward(self, x):x = F.relu(self.hidden1(x))x = F.relu(self.hidden2(x))x = self.out(x)return x
net = Mnist_NN()
print(net)
from matplotlib import pyplot
import numpy as nppyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)
for name, parameter in net.named_parameters():print(name, parameter,parameter.size())
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoadertrain_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)
def get_data(train_ds, valid_ds, bs):return (DataLoader(train_ds, batch_size=bs, shuffle=True),DataLoader(valid_ds, batch_size=bs * 2),)
import numpy as npdef fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))
from torch import optim
def get_model():model = Mnist_NN()return model, optim.SGD(model.parameters(), lr=0.001)
def loss_batch(model, loss_func, xb, yb, opt=None):loss = loss_func(model(xb), yb)if opt is not None:loss.backward()opt.step()opt.zero_grad()return loss.item(), len(xb)
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(25, model, loss_func, opt, train_dl, valid_dl)

损失到0.37/0.36左右就差不多了


http://www.mrgr.cn/news/42527.html

相关文章:

  • webGL入门(六)图形旋转
  • 快速点特征直方图 (FPFH) 描述符 和 点特征直方图 (PFH) 描述符 的差异
  • Redis篇(面试题 - 连环16炮)(持续更新迭代)
  • 数据链路层(以太网简介)
  • 24年最新大众点评数据
  • 【深度学习基础模型】回声状态网络(Echo State Networks, ESN)详细理解并附实现代码。
  • Vue2如何在网页实现文字的逐个显现
  • 69.【C语言】动态内存管理(重点)(2)
  • 【60天备战2024年11月软考高级系统架构设计师——第36天:系统安全设计——数据加密】
  • 【微服务】负载均衡 - LoadBalance(day4)
  • 我与世界的联系---读书
  • 【MySQL】Ubuntu环境下MySQL的安装与卸载
  • IEC104规约的秘密之六----配置参数k,w
  • 数据库管理-第247期 23ai:全球分布式数据库-Schema对象(20241004)
  • 基于Springboot+Vue的在线项目管理与任务分配中的应用 (含源码数据库)
  • 2024软件测试面试大全(含答案+文档)
  • 2024最新软件测试面试八股文
  • 基础算法--枚举
  • 第18场小白入门赛(蓝桥杯)
  • TryHackMe 第6天 | Web Fundamentals (一)