【机器学习】分类算法 - KNN算法(K-近邻算法)KNeighborsClassifier

news/2024/5/17 12:28:13

「作者主页」:士别三日wyx
「作者简介」:CSDN top100、阿里云博客专家、华为云享专家、网络安全领域优质创作者
「推荐专栏」:零基础快速入门人工智能《机器学习入门到精通》

K-近邻算法

  • 1、什么是K-近邻算法?
  • 2、K-近邻算法API
  • 3、K-近邻算法实际应用
    • 3.1、获取数据集
    • 3.2、划分数据集
    • 3.3、特征标准化
    • 3.4、KNN处理并评估

1、什么是K-近邻算法?

K-近邻算法的核心思想是根据「邻居」「推断」你的类别。

K-近邻算法的思路其实很简单,比如我在北京市,想知道自己在北京的哪个区。K-近邻算法就会找到和我距离最近的‘邻居’,邻居在朝阳区,就认为我大概率也在朝阳区。

在这里插入图片描述

其中 K 是邻居个数的意思

  • 邻居个数「太少」,容易受到异常值的影响
  • 邻居个数「太多」,容易受到样本不均衡的影响。

2、K-近邻算法API

sklearn.neighbors.KNeighborsClassifier( n_neighbors=5, algorithm=‘auto’ ) 是实现K-近邻算法的API

  • n_neighbors:(可选,int)指定邻居(K)数量,默认值 5
  • algorithm:(可选,{ ‘auto’,‘ball_tree’,‘kd_tree’,‘brute’})计算最近邻居的算法,默认值 ‘auto’。

算法解析

  • brute:蛮力搜索,也就是线性扫描,训练集越大,消耗的时间越多。
  • kd_tree:构造kd树(也就是二叉树)存储数据以便对其进行快速检索,以中值切分构造的树,每个结点是一个超矩形,在维数小于20时效率高
  • ball_tree:用来解决kd树高维失效的问题,以质心C和半径r分割样本空间,每个节点是一个超球体。
  • auto:自动决定最合适的算法

函数

  • KNeighborsClassifier.fit( x_train, y_train):接收训练集特征 和 训练集目标
  • KNeighborsClassifier.predict(x_test):接收测试集特征,返回数据的类标签。
  • KNeighborsClassifier.score(x_test, y_test):接收测试集特征 和 测试集目标,返回准确率。
  • KNeighborsClassifier.get_params():获取接收的参数(就是 n_neighbors 和 algorithm 这种参数)
  • KNeighborsClassifier.set_params():设置参数
  • KNeighborsClassifier.kneighbors():返回每个相邻点的索引和距离
  • KNeighborsClassifier.kneighbors_graph():返回每个相邻点的权重

3、K-近邻算法实际应用

3.1、获取数据集

这里使用sklearn自带的鸢尾花「数据集」,它是分类最常用的分类试验数据集。

from sklearn import datasets# 1、获取数据集(实例化)
iris = datasets.load_iris()print(iris.data)

输出:

[[5.1 3.5 1.4 0.2][4.9 3.  1.4 0.2][4.7 3.2 1.3 0.2]

从打印的数据集可以看到,鸢尾花数据集有4个「属性」,这里解释一下属性的含义

  • sepal length:萼片长度(厘米)
  • sepal width:萼片宽度(厘米)
  • petal length:花瓣长度(厘米)
  • petal width:花瓣宽度(厘米)

3.2、划分数据集

接下来对鸢尾花的特征值(iris.data)和目标值(iris.target)进行「划分」,测试集为60%,训练集为40%。

from sklearn import datasets
from sklearn import model_selection# 1、获取数据集
iris = datasets.load_iris()
# 2、划分数据集
x_train, x_test, y_train, y_test = model_selection.train_test_split(iris.data, iris.target, random_state=6)
print('训练集特征值:', len(x_train))
print('测试集特征值:',len(x_test))
print('训练集目标值:',len(y_train))
print('测试集目标值:',len(y_test))

输出:

训练集特征值: 112
测试集特征值: 38
训练集目标值: 112
测试集目标值: 38

从打印结果可以看到,测试集的样本数是38,训练集的样本数是112,划分比例符合预期。


3.3、特征标准化

接下来,对训练集和测试集的特征值进行「标准化」处理(训练集和测试集所做的处理必须完全「相同」)。

from sklearn import datasets
from sklearn import model_selection
from sklearn import preprocessing# 1、获取数据集
iris = datasets.load_iris()
# 2、划分数据集
# x_train:训练集特征,x_test:测试集特征,y_train:训练集目标,y_test:测试集目标
x_train, x_test, y_train, y_test = model_selection.train_test_split(iris.data, iris.target, random_state=6)
# 3、特征标准化
ss = preprocessing.StandardScaler()
x_train = ss.fit_transform(x_train)
x_test = ss.fit_transform(x_test)
print(x_train)

输出:

[[-0.18295405 -0.192639    0.25280554 -0.00578113][-1.02176094  0.51091214 -1.32647368 -1.30075363][-0.90193138  0.97994624 -1.32647368 -1.17125638]

从打印结果可以看到,特征值发生了相应的变化。


3.4、KNN处理并评估

接下来,将训练集特征 和 训练集目标 传给 KNN,然后评估处理结果的「准确率」

from sklearn import datasets
from sklearn import model_selection
from sklearn import preprocessing
from sklearn import neighbors# 1、获取数据集
iris = datasets.load_iris()
# 2、划分数据集
# x_train:训练集特征,x_test:测试集特征,y_train:训练集目标,y_test:测试集目标
x_train, x_test, y_train, y_test = model_selection.train_test_split(iris.data, iris.target, random_state=6)
# 3、特征标准化
ss = preprocessing.StandardScaler()
x_train = ss.fit_transform(x_train)
x_test = ss.fit_transform(x_test)
# 4、KNN算法处理
knn = neighbors.KNeighborsClassifier(n_neighbors=2)
knn.fit(x_train, y_train)
# 5、评估结果
y_predict = knn.predict(x_test)
print('真实值和预测值对比:', y_predict == y_test)
score = knn.score(x_test, y_test)
print('准确率:', score)

输出:

真实值和预测值对比: [ True  True  True  True  True  True False  True  True  True False  TrueTrue  True  True False  True  True  True  True  True  True  True  TrueTrue  True  True  True  True  True  True  True  True  True False  TrueTrue  True]
准确率: 0.8947368421052632

从输出结果可以很容易看出,准确率是89%;真实值和预测值对比的结果中,True越多,表示准确率越高。


http://www.mrgr.cn/p/05388600

相关文章

centos7安装mysql数据库详细教程及常见问题解决

mysql数据库详细安装步骤 1.在root身份下输入执行命令: yum -y update 2.检查是否已经安装MySQL,输入以下命令并执行: mysql -v 如出现-bash: mysql: command not found 则说明没有安装mysql 也可以输入rpm -qa | grep -i mysql 查看是否已…

Unity XML3——XML序列化

一、XML 序列化 ​ 序列化:把对象转化为可传输的字节序列过程称为序列化,就是把想要存储的内容转换为字节序列用于存储或传递 ​ 反序列化:把字节序列还原为对象的过程称为反序列化,就是把存储或收到的字节序列信息解析读取出来…

java+springboot+mysql疫情物资管理系统

项目介绍: 使用javaspringbootmysql开发的疫情物资管理系统,系统包含超级管理员,系统管理员、员工角色,功能如下: 超级管理员:管理员管理;部门管理;职位管理;员工管理&…

zore-shot,迁移学习和多模态学习

1.zore-shot 定义:在ZSL中,某一类别在训练样本中未出现,但是我们知道这个类别的特征,然后通过语料知识库,便可以将这个类别识别出来。概括来说,就是已知描述,对未知类别(未在训练集中…

Python 教程之标准库概览

概要 Python 标准库非常庞大,所提供的组件涉及范围十分广泛,使用标准库我们可以让您轻松地完成各种任务。 以下是一些 Python3 标准库中的模块: 「os 模块」 os 模块提供了许多与操作系统交互的函数,例如创建、移动和删除文件和…

Debeizum 增量快照

在Debeizum1.6版本发布之后,成功推出了Incremental Snapshot(增量快照)的功能,同时取代了原有的实验性的Parallel Snapshot(并行快照)。在本篇博客中,我将介绍全新快照方式的原理,以…

系统架构设计师-软件架构设计(5)

目录 一、构件与中间件技术 1、软件复用 2、构件与中间件技术的概念 3、构件的复用 3.1 检索与提取构件 3.2 理解与评价构件 3.3 修改构件 3.4 组装构件 4、中间件 4.1 采用中间件技术的优点: 4.2 中间件的分类: 5、构件标准 5.1 CORBA(公共…

day43-Feedback Ui Design(反馈ui设计)

50 天学习 50 个项目 - HTMLCSS and JavaScript day43-Feedback Ui Design&#xff08;反馈ui设计&#xff09; 效果 index.html <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport&q…

CPU密集型和IO密集型任务的权衡:如何找到最佳平衡点

关于作者&#xff1a;CSDN内容合伙人、技术专家&#xff0c; 从零开始做日活千万级APP。 专注于分享各领域原创系列文章 &#xff0c;擅长java后端、移动开发、人工智能等&#xff0c;希望大家多多支持。 目录 一、导读二、概览三、CPU密集型与IO密集型3.1、CPU密集型3.2、I/O密…

【已解决】windows7添加打印机报错:加载Tcp Mib库时的错误,无法加载标准TCP/IP端口的向导页

windows7 添加打印机的时候&#xff0c;输入完打印机的IP地址后&#xff0c;点击下一步&#xff0c;报错&#xff1a; 加载Tcp Mib库时的错误&#xff0c;无法加载标准TCP/IP端口的向导页 解决办法&#xff1a; 复制以下的代码到新建文本文档.txt中&#xff0c;然后修改文本文…

搭建测试平台开发(一):Django基本配置与项目创建

一、安装Django最新版本 1 pip install django 二、创建Django项目 首先进入要存放项目的目录&#xff0c;再执行创建项目的命令 1 django-admin startproject testplatform 三、Django项目目录详解 1 testplatform 2 ├── testplatform  # 项目的容器 3 │ ├──…

清洁机器人规划控制方案

清洁机器人规划控制方案 作者联系方式Forrest709335543qq.com 文章目录 清洁机器人规划控制方案方案简介方案设计模块链路坐标变换算法框架 功能设计定点自主导航固定路线清洁区域覆盖清洁贴边沿墙清洁自主返航回充 仿真测试仿真测试准备定点自主导航测试固定路线清洁测试区域…

ER系列路由器多网段划分设置指南

ER系列路由器多网段划分设置指南 - TP-LINK 服务支持 TP-LINK ER系列路由器支持划分多网段&#xff0c;可以针对不同的LAN接口划分网段&#xff0c;即每一个或多个LAN接口对应一个网段&#xff1b;也可以通过一个LAN接口与支持划分802.1Q VLAN的交换机进行对接&#xff0c;实现…

微信小程序导入微信地址

获取用户收货地址。调起用户编辑收货地址原生界面&#xff0c;并在编辑完成后返回用户选择的地址。 1&#xff1a;原生微信小程序接口使用API&#xff1a;wx.chooseAddress(OBJECT) wx.chooseAddress({success (res) {console.log(res.userName)console.log(res.postalCode)c…

Day02-作业(JavaScriptVue)

作业1&#xff1a;实现5秒之后&#xff0c;当前页面直接跳转到官网首页&#xff08;首页地址&#xff1a;https://www.itcast.cn&#xff09; 提示&#xff1a; 5秒之后&#xff0c;才触发某一个动作 素材&#xff1a; <!DOCTYPE html> <html lang"en"&…

【AGI】Copilot AI编程辅助工具安装教程

1. 基础激活教程 GitHub和OpenAI联合为程序员们送上了编程神器——GitHub Copilot。 但是&#xff0c;Copilot目前不提供公开使用&#xff0c;需要注册账号通过审核&#xff0c;我也提交了申请&#xff1a;这里第一期记录下&#xff0c;开启教程&#xff0c;欢迎大佬们来讨论…

【Linux】更换jdk版本

目录 一、前言二、查看jdk版本号1、项目中的版本号&#xff08;pom.xml&#xff09;2、服务器中的版本号 三、更换jdk版本1、创建java文件夹2、下载并解压JDK安装包①、下载jdk安装包②、移动到创建好的/usr/local/java路径下③、解压jdk安装包 四、删除原来的jdk版本1、删除原…

删除 iptables 中的规则

查看规则编号 要删除 iptables 中的规则&#xff0c;可以使用以下命令&#xff1a; 查看 iptables 中的规则&#xff0c;找到要删除的规则的编号&#xff1a; iptables -L --line-numbers删除指定编号的规则&#xff1a; iptables -D [chain] [rule-number]其中&#xff0c;…

设计模式再探——代理模式

目录 一、背景介绍二、思路&方案三、过程1.代理模式简介2.代理模式的类图3.代理模式代码4.代理模式还可以优化的地方5.代理模式的项目实战&#xff0c;优化后(只加了泛型方式&#xff0c;使用CGLIB的代理) 四、总结五、升华 一、背景介绍 最近在做产品过程中对于日志的统一…

c# 此程序集中已使用了资源标识符

严重性 代码 说明 项目 文件 行 禁止显示状态 错误 CS1508 此程序集中已使用了资源标识符“BMap.NET.WindowsForm.BMapControl.resources” BMap.NET.WindowsForm D:\MySource\Decompile\BMap.NET.WindowsForm\CSC 1 活动 运行程序时&a…