当前位置: 首页 > news >正文

回归树练习,泰坦尼克号幸存者的预测

回归树练习,泰坦尼克号幸存者的预测

数据集下载地址
https://download.csdn.net/download/AnalogElectronic/89846327

我们来看看train.csv文件,它包含了891个样本,每个样本代表一个乘客。这些样本的数据包括乘客的年龄(Age)、船票等级(Pclass)、性别(Sex)、登船港口(Embarked)、票价(Fare)等基本信息,以及最重要的生存状态(Survived)。这些特征提供了对乘客生存可能性的洞察,比如男性与女性的生存率差异、船票等级与生存机会的关系等。

##回归树练习,泰坦尼克号幸存者的预测
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt
data = pd.read_csv(r"I:\hadoop note\titanic_train.csv",index_col= 0)
data.head()

在这里插入图片描述
在这里插入图片描述


#删除缺失值过多的列,和观察判断来说和预测的y没有关系的列
data.drop(["Cabin","Name","Ticket"],inplace=True,axis=1)
#处理缺失值,对缺失值较多的列进行填补,有一些特征只确实一两个值,可以采取直接删除记录的方法
data["Age"] = data["Age"].fillna(data["Age"].mean())
data = data.dropna()
#将分类变量转换为数值型变量
#将二分类变量转换为数值型变量
#astype能够将一个pandas对象转换为某种类型,和apply(int(x))不同,astype可以将文本类转换为数字,用这个方式可以很便捷地将二分类特征转换为0~1
data["Sex"] = (data["Sex"]== "male").astype("int")
#将三分类变量转换为数值型变量
labels = data["Embarked"].unique().tolist()
data["Embarked"] = data["Embarked"].apply(lambda x: labels.index(x))
#查看处理后的数据集
data.head()

在这里插入图片描述

##提取X和Y,拆分训练集和测试集
X = data.iloc[:,data.columns != "Survived"]
y = data.iloc[:,data.columns == "Survived"]
from sklearn.model_selection import train_test_split
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X,y,test_size=0.3)
#修正测试集和训练集的索引
for i in [Xtrain, Xtest, Ytrain, Ytest]:i.index = range(i.shape[0])
#查看分好的训练集和测试集
Xtrain.head()

在这里插入图片描述

clf = DecisionTreeClassifier(random_state=25)
clf = clf.fit(Xtrain, Ytrain)
score_ = clf.score(Xtest, Ytest)
score_

在这里插入图片描述

##循环获取适合的max_depth
tr = []
te = []
for i in range(10):clf = DecisionTreeClassifier(random_state=25,max_depth=i+1 ,criterion="entropy" )clf = clf.fit(Xtrain, Ytrain)score_tr = clf.score(Xtrain,Ytrain)score_te = cross_val_score(clf,X,y,cv=10).mean()tr.append(score_tr)te.append(score_te)
print(max(te))
plt.plot(range(1,11),tr,color="red",label="train")
plt.plot(range(1,11),te,color="blue",label="test")
plt.xticks(range(1,11))
plt.legend()
plt.show()

0.8177860061287026
在这里插入图片描述

##交叉验证和网格搜索
import numpy as np
gini_thresholds = np.linspace(0,0.5,20)
parameters = {'splitter':('best','random'),'criterion':("gini","entropy"),"max_depth":[*range(1,10)],'min_samples_leaf':[*range(1,50,5)],'min_impurity_decrease':[*np.linspace(0,0.5,20)]}
clf = DecisionTreeClassifier(random_state=25)
GS = GridSearchCV(clf, parameters, cv=10)
GS.fit(Xtrain,Ytrain)
GS.best_params_

在这里插入图片描述

GS.best_score_

0.819969278033794


http://www.mrgr.cn/news/43059.html

相关文章:

  • LabVIEW程序怎么解决 Bug?
  • 如何快速学习K8s
  • 【Postman】接口测试工具使用
  • 物理学基础精解【53】
  • 第2篇:Windows权限维持----应急响应之权限维持篇
  • 递归实现单链表的尾插法
  • PMP--三模--解题--161-170
  • 【数据结构与算法】LeetCode:图论
  • 链表——双向链表
  • IntelliJ IDEA 2024.2 新特性概览
  • VTK+其他布尔运算库
  • 【游戏模组】重返德军总部2009高清重置MOD,建模和材质全部重置,并且支持光追效果,游戏画质大提升
  • 华为OD机试 - 核酸最快检测效率 - 动态规划、背包问题(Python/JS/C/C++ 2024 E卷 200分)
  • 如何写出更牛更系统的验证激励
  • 如何使用ssm实现果蔬商品管理系统的设计与实现+vue
  • 【微服务】负载均衡 - LoadBalancer(day4)
  • 【可答疑】基于51单片机的数字时钟(含仿真、代码、报告等)
  • 通过 Caddy2 部署 WebDAV 服务器
  • 利用 Python 爬虫采集 1688商品详情
  • JVM 内存区域划分