利用pytorch两层线性网络对titanic数据集进行分类(kaggle)

news/2024/5/21 1:39:51

利用pytorch两层线性网络对titanic数据集进行分类

最近在看pytorch的入门课程,做了一下在kaggle网站上的作业,用的是titanic数据集,因为想搭一下神经网络,所以数据加载部分简单的把训练集和测试集中有缺失值的列还有含有字符串的列去除了,加入了DataLoader模块,其实这个数据集很小,用不到,本人还没入门,小白一枚。

import torch 
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
from torchvision import datasets
from torchvision import transforms
import pandas as pdclass titanicDataset(Dataset):def __init__(self,filepath):xy=np.loadtxt(filepath,delimiter=',',skiprows=1,usecols=[1,2,7,8],dtype=np.float32)self.len=xy.shape[0]# print(self.len)self.y_data=torch.from_numpy(xy[:,[0]])self.x_data=torch.from_numpy(xy[:,1:])def __getitem__(self,index):#获取索引元素 return self.x_data[index],self.y_data[index]def __len__(self):return self.len
dataset=titanicDataset('./pytorch/dataset/titanic/train.csv')
train_loader=DataLoader(dataset=dataset,batch_size=32,shuffle=True,num_workers=0)# print(dataset.x_data,dataset.y_data)
test_loader=DataLoader(dataset=np.loadtxt('./pytorch/dataset/titanic/test.csv',delimiter=',',skiprows=1,usecols=[1,6,7],dtype=np.float32),batch_size=32,shuffle=False,num_workers=0)
print(next(iter(test_loader)))class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()# self.linear1=torch.nn.Linear(4,3)self.linear2=torch.nn.Linear(3,2)self.linear3=torch.nn.Linear(2,1)self.sigmoid=torch.nn.Sigmoid()def forward(self,x):# x=self.sigmoid(self.linear1(x))x=self.sigmoid(self.linear2(x))x=self.sigmoid(self.linear3(x))return x
model=Model()
criterion=torch.nn.BCELoss(size_average=True)
optimizer=torch.optim.SGD(model.parameters(),lr=0.1,momentum=0.9)
for epoch in range(10000):acc_num=0for i,data in enumerate(train_loader,0):#1.Prepare datainputs,labels=data# print(inputs.shape[0])#2.Forwardy_pred=model(inputs)loss=criterion(y_pred,labels)# print(epoch,i,loss.item())#3.Backwardoptimizer.zero_grad()loss.backward()#4.Updateoptimizer.step()y_pred_label=torch.where(y_pred>0.5,torch.tensor([1.0]),torch.tensor([0.0]))acc_num+=torch.eq(y_pred_label,labels).sum().item()# print(acc_num,len(dataset),len(train_loader.dataset))acc=acc_num/len(dataset)
print(acc)
# print(test_loader)
# print(test_loader.dataset.shape)
out = model(torch.tensor(test_loader.dataset))
y_pred = torch.where(out>0.5,torch.tensor([1.0]),torch.tensor([0.0]))[:,0]
print(y_pred)
print(pd.Series(y_pred))
id=pd.read_csv('./pytorch/dataset/titanic/test.csv',usecols=['PassengerId']).iloc[:,0]
# print(type(id))pd.DataFrame({'PassengerId':id,'Survived':pd.Series(y_pred,dtype=int)}).to_csv('pred.csv',index=None)
a=pd.DataFrame([id,pd.Series(y_pred)])
print(a)
# print(y_pred[-10:])# for x in test_loader:
#     print(x.shape)
#     out = model(x)
#     y_pred = torch.where(out>0.5,torch.tensor([1.0]),torch.tensor([0.0]))
# print(y_pred)

http://www.mrgr.cn/p/68400160

相关文章

莫队(板子)

莫队 参考博客 玄学暴力区间操作算法PPT解释的很清楚啦~, 导致我没什么可写的 \(qwq\) 把所有询问离线下来后排序(左端点按块,右端点升序),然后从一个小区间通过左右端点的移动扩大区间,更新答案。 复杂度主要在区间扩展,也就是左右指针的移动,对于莫队所有的优化几乎都是…

HTML4(三):表单

文章目录 表单1. 基本结构2. 常用表单控件2.1 文本输入框2.2 密码输入框2.3 单选框2.4 复选框2.5 隐藏域2.6 提交按钮2.7 重置按钮2.8 普通按钮2.9 文本域2.10 下拉框2.11 示例 3. 禁用表单控件4. lable标签5. fieldset与legend标签6. 总结 表单 概念:一种包含交互…

更优雅的使用Gson解析Json

Gson背靠Google这棵大树,拥有广泛的社区支持和相对丰富的文档资源,同时因其简单直观的API,一直以来基本稳坐Android开发序列化的头把交椅(直到Google宣布kotlin成为Android开发的首选语言)。本文对Gson的使用及主要流程做下分析。 Gson的基本使用 Gson依赖 kotlin 复制代…

树和二叉树的定义和基本术语

文章目录 前言一、树的定义二、树的基本术语三、二叉树的定义总结 前言 T_T此专栏用于记录数据结构及算法的(痛苦)学习历程,便于日后复习(这种事情不要啊)。所用教材为《数据结构 C语言版 第2版》严蔚敏。 一、树的定义…

【Vue】vue中将 html 或者 md 导出为 word 文档

原博主 xh-htmlword文档 感谢这位大佬的封装优化和分享,亲测有用!可以去看大佬👇的说明! 前端HTML转word文档,绝对有效!!! 安装 npm install xh-htmlword导入 import handleEx…

基于PSO优化的PV光伏发电系统simulink建模与仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 5.完整工程文件 1.课题概述 基于PSO优化的PV光伏发电系统simulink建模与仿真。其中PSO采用matlab编程实现,通过simulink的函数嵌入模块,将matlab调用进simulink中。 2.系统仿真结…

websevere服务器从零搭建到上线(四)|muduo网络库的基本原理和使用

文章目录 muduo源码编译安装muduo框架讲解muduo库编写服务器代码示例代码解析用户连接的创建和断开回调函数用户读写事件回调 使用vscode编译程序配置c_cpp_properties.json配置tasks.json配置launch.json编译 总结 muduo源码编译安装 muduo依赖Boost库,所以我们应…

麒麟系统

问题描述 Nginx最新版 Nginx 1.25.0解决方案 开放防火墙端口 开启端口:sudo firewall-cmd --zone=public --add-port=8080/tcp --permanent 关闭端口:sudo firewall-cmd --zone=public --remove-port=8080/tcp --permanent 端口生效:firewall-cmd --reload

C#中Linq的去重方式Distinct详解

一、首先创建一个控制台应用程序,添加一个Person对象 using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks;namespace Compare {public class Person{public string Name { get; set; }public int Age { ge…

华为ensp中BFD和OSPF联动(原理及配置命令)

作者主页:点击! ENSP专栏:点击! 创作时间:2024年5月6日20点26分 BFD通常指的是双向转发检测。BFD是一个旨在快速检测通信链路故障的网络协议,提供了低开销、短延迟的链路故障检测机制。它主要用于监测两个…

B/S模式的web通信

这里写目录标题 目标实现的目标 服务器代码(采用epoll实现服务器)整体框架main函数init_listen_fd函数(负责对lfd初始化的那一系列操作)epoll_run函数 一级目录二级目录二级目录二级目录 目标 实现的目标 我们要实现,…

数字集成电路 NMOS工作区

MOSFET是一个四端器件(栅极、源极、漏极、衬底)。 衬底一般连接到一个直流电源端:NMOS的衬底接地GND,PMOS的衬底接高电平VDD。(为了使得MOS管中的PN结零偏或反偏,尽管如此,二极管的结电容也会对电路产生影响)(PN结正偏不仅会形成通路,也会导致结电容急剧增大 C=ES/D) N…

Python-----容器的介绍以及操作

1.列表和元组 1.列表是什么, 元组是什么: 编程中, 经常需要使用变量, 来保存/表示数据. 如果代码中需要表示的数据个数比较少, 我们直接创建多个变量即可. 但是有的时候, 代码中需要表示的数据特别多, 甚至也不知道要表示多少个数据. 这个时候, 就需要用到列表 列表…

力扣138. 随机链表的复制

Problem: 138. 随机链表的复制 文章目录 题目描述思路及解法复杂度Code 题目描述 思路及解法 1.创建Map集合Map<Node, Node> map;创建指针cur指向head&#xff1b; 2.遍历链表将cur作为键&#xff0c;new Node(cur.val)作为值&#xff0c;存入map集合&#xff1b; 3.再次…

FPGA+炬力ARM实现VR视频播放器方案

FPGA炬力ARM方案&#xff0c;单个视频源信号&#xff0c;同时驱动两个LCD屏显示&#xff0c;实现3D 沉浸式播放 客户应用&#xff1a;VR视频播放器 主要功能&#xff1a; 1.支持多种格式视频文件播放 2.支持2D/3D 效果实时切换播放 3.支持TF卡/U盘文件播放 4.支持定制化配置…

一招MAX降低10倍,现在它是我的了

一.背景 性能优化是一场永无止境的旅程。 到家门店系统,作为到家核心基础服务之一,门店C端接口有着调用量高,性能要求高的特点。 C端服务经过演进,核心接口先查询本地缓存,如果本地缓存没有命中,再查询Redis。本地缓存命中率99%,服务性能比较平稳。随着门店数据越来越多…

【京东云新品发布月刊】2024年4月产品动态

京东云4月产品动态:1.【言犀AI虚拟主播】"采销东哥"数字人是怎样练成的?“大家好,好久不见,我是你们的老朋友东哥……”面对众网友喊话开直播,刘强东以新的形式与大家见面。4月16日下午6点18分,由京东云言犀打造的“采销东哥”AI数字人开启直播首秀,同时亮相京…

【C++】CentOS环境搭建-编译安装Boost库(附CMAKE编译文件)

【C】环境搭建-编译安装Boost库 Boost库简介Boost库安装通过YUM安装&#xff08;版本较低 V1.53.0&#xff09;通过编译安装&#xff08;官网最新版本1.85.0&#xff09;1.安装相关依赖2.查询官网下载最新安装包并解压3.编译Boost4.安装Boost库到系统路径 Boost库验证 Boost库简…

谷歌Gmail邮箱开启SMTP/IMAP服务流程

本篇专门定向讲解谷歌Gmail邮箱,如何开通SMTP协议的流程,在讲篇幅前,我需要你确定3件事:1.你已经有谷歌账号了2.你很清楚自己为什么想要开通SMTP服务3.你已经掌握一定的基础知识,能够达到翻出了谷歌Gmail邮箱开启SMTP/IMAP服务流程如果你没法“翻出去”,接下来的内容就可…

python教程9-第三方模块安装

https://pypi.python.org/pypi 是python的开源模块库。 收录了⾃全世界python开发者贡献的模块,⼏乎涵盖了你想⽤python做的任何事情。 事实上每个python开发 者,只要注册⼀个账号就可以往这个平台上传你⾃⼰的模块,这样全世界的开发者都可以容易的下载并使⽤你的模块。 下载…