分类模型评估:混淆矩阵与ROC曲线

news/2024/5/14 17:09:59

  • 1.混淆矩阵
  • 2.ROC曲线 & AUC指标

理解混淆矩阵和ROC曲线之前,先区分几个概念。对于分类问题,不论是多分类还是二分类,对于某个关注类来说,都可以看成是二分类问题,当前的这个关注类为正类,所有其他非关注类为负类。因为样本的真实值有正负两类,而模型的预测值也有正负两类,因此样本的真实值和模型的预测值之间产生了下面4种组合:

  • 真正例(True Positives/TP):在所有真实值为正类的样本中,模型预测值也为正类的样本数。
  • 假正例(False Positives/FP):在所有真实值为负类的样本中,模型预测值为正类的样本数。
  • 真负例(True Negatives/TN):所有真实值为负类的样本中,模型预测值也为负类的样本数。
  • 假负例(False Negatives/FN):所有真实值为正类的样本中,模型预测值为负类的样本数。

从上面几个定义可以知道:
1)样本总数 = TP+FP+TN+FN
2)所有真实值为正类的样本总数 = TP+FN
3)所有真实值为负类的样本总数 = TN+FP

1.混淆矩阵

使用sklearn自带的鸢尾花数据集,数据集里鸢尾花包含3个分类。

import numpy as np
from sklearn.datasets import load_iris
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix# 获取特征值与目标值
data = load_iris()
X, y = data['data'], data['target']# 自带的数据集分类准确率为1,为了后面更好的基于混淆矩阵验证相关指标的计算,为训练集添加均值0,标准差2的高斯噪声
np.random.seed(42)
noise = np.random.normal(0, 2, (len(X), len(X[0])))
X += noise# 特征值归一化到区间[-1,1]
scaler = MinMaxScaler(feature_range=(-1, 1))
X_scaled = scaler.fit_transform(X)# 划分训练集与测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)# 创建逻辑回归模型、训练并预测
model = LogisticRegression(multi_class='multinomial', max_iter=1000)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)# 获取模型混淆矩阵、分类报告、准确率
print(f"混淆矩阵:\n{confusion_matrix(y_test, y_pred)}")
print(f"分类报告:\n{classification_report(y_test, y_pred)}")
print(f"准确率:\n{accuracy_score(y_test, y_pred)}")

output:
在这里插入图片描述

混淆矩阵中,横向表示真实值,纵向表示预测值。比如第一个位置7,表示实际类别为0且预测类别为0的样本有7个。基于此混淆矩阵,可以衍生下面相关指标:
1)准确率(accuracy):准确率表示模型对一个样本类别预测正确的可能性,是相对整体来说的。计算方式为所有预测正确的样本(斜对角线之和)/ 样本总数,本例中accuracy=(7+4+8)/30=0.63…
2)精确率(precision):精确率是针对某个具体关注类来说的,精确率关注的是,对于所有预测值为该类的样本中,真实值也属于该类的样本所占的比例,计算公式为 T P T P + F P \frac{TP}{TP+FP} TP+FPTP。比如对于类别0,模型预测的该类样本数=(7+0+2)=9,而真实值为该类的样本数为7,那么类别0的precision=7/9=0.77…。精确率反映了“模型找的对不对”
3)召回率(recall):同样召回率也是针对某个具体关注类来说的,关注的是所有真实值为该类的样本中,模型能正确预测为该类的样本所占的比例,计算公式 T P T P + F N \frac{TP}{TP+FN} TP+FNTP。还是拿类别0来说,真实值为0的样本总数=(7+3+0)=10,模型能正确预测为该类的样本总数为7,所以类别0的召回率=7/10=0.7。召回率代表了“模型找的全不全”
4)F1-score:F1分数是精确率与召回率的调和平均数,计算公式为 2 ∗ p r e c i s i o n ∗ r e c a l l p r e c i s i o n + r e c a l l \frac{2*precision*recall}{precision+recall} precision+recall2precisionrecall。对于类别0的F1-score= 2 ∗ 0.78 ∗ 0.7 0.78 + 0.7 \frac{2*0.78*0.7}{0.78+0.7} 0.78+0.720.780.7=0.737…,F1分数用来表示模型在关注的类上识别正类的综合表现,最大值1表示分类效果最好完全正确,最小值0表示分类效果最差完全错误。

分类报告直接提供了每个分类下的精确率、召回率、F1分数等指标。最下面两行的macro avgweighted avg分别表示对每个指标的算术平均和加权平均,最后一列的support表示对应的样本数量。

2.ROC曲线 & AUC指标

ROC:Receiver Operating Characteristic。
AUC:Area Under the [ROC] Curve,ROC曲线下的面积。

ROC曲线的绘制中需要用到两个指标:

  • 真正率(True Positive Rate/TPR):在所有真实类别为正类的样本中,模型正确识别为正类的样本所占的比例,也就是把正类样本识别成正类样本的概率。反映了模型识别正类的能力,可以看成是模型在识别正类样本时的收获能力,计算公式 T P T P + F N \frac{TP}{TP+FN} TP+FNTP
  • 假正率(False Positive Rate/FPR):在所有真实类别为负类的样本中,模型错误识别为正类的样本所占的比例,即把负类样本识别为正类样本的概率。反映了模型识别为正类样本时的错误程度,可以理解成模型在识别正类样本时付出的代价,计算公式 F P T N + F P \frac{FP}{TN+FP} TN+FPFP

大多数分类模型都是通过计算出每个样本属于正类的概率,和属于正类的概率阈值进行比较来对样本进行分类的。正类的概率>=阈值,判定为正类,反之判定为负类。

ROC曲线是由不同概率阈值下真正率(y轴)和假正率(x轴)对应的一系列点所构成的曲线,x轴从左到右判定为正类的概率阈值从1到0逐渐递减。ROC曲线用来描述二分类模型预测效果,对于多分类问题,是将关注类视为正类,其他类视为负类。

ROC曲线的具体绘制过程可以理解为:

  1. 对于测试集中的每个样本,利用分类器预测其为正类的概率值。
  2. 将这些概率值按照从大到小的顺序排列,作为阈值。
  3. 对于每个阈值,分别计算真正率和假正率,对应坐标轴上的一个点。
  4. 连接这些点。

从真正率和假正率的计算,可以看出曲线越往右,判定为正类的概率阈值越低,那么就有更多的样本被归类到正类当中,因为分母是不变的,分子(TP/FP)随着正类样本增多都会逐渐增大,因此ROC的曲线走势应该是一个从(0, 0)到(1, 1)逐渐上升的曲线。

同时,因为x轴代表了在识别正类时付出的代价,y轴代表了在识别正类时的收获,因此当x值越小,y值越大,即曲线越靠近左上角(0, 1),说明模型的分类效果越好。

而AUC,是ROC曲线下的面积,它衡量的是模型在所有概率阈值下识别正类时“收获”与“代价”的比重,因此AUC值越大越好,值域范围[0, 1]。
AUC=0.5:模型不具有分类效果,相当于盲猜。
AUC<0.5:分类效果最差,不如盲猜。
AUC>0.5:有一定的分类效果,值越接近1分类效果越好。

下面还是以鸢尾花的数据集为例,通过一个demo对ROC和AUC进行计算和绘制。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve, auc# 获取特征值与目标值
data = load_iris()
X, y = data['data'], data['target']# 仅使用两个类别:0 & 1
X = X[y != 2]
y = y[y != 2]# 训练集添加噪声
np.random.seed(42)
noise = np.random.normal(0, 2, (len(X), len(X[0])))
X += noise# 归一化
scaler = MinMaxScaler(feature_range=(-1, 1))
X_scaled = scaler.fit_transform(X)# 划分数据集、创建逻辑回归模型、训练
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
model = LogisticRegression(multi_class='multinomial', max_iter=1000)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)# 获取每个样本预测为正类样本的概率,[:, 0]是负类样本的概率
y_pred_prob = model.predict_proba(X_test)[:, 1]# 计算FPR、TPR和AUC值
fpr, tpr, thresholds = roc_curve(y_test, y_pred_prob)
roc_auc = auc(fpr, tpr)# 绘制ROC曲线
plt.figure()
plt.plot(fpr, tpr, color='green', lw=1, label=f'ROC Curve (AUC={roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='red', lw=1, linestyle='--')
plt.xlim([0.0, 1])
plt.ylim([0.0, 1.05])
plt.xlabel('FPR')
plt.ylabel('TPR')
plt.title('ROC Curve')
plt.legend(loc="lower right")
plt.show()

output:
在这里插入图片描述
绿线代表ROC曲线,红线相当于盲猜,绿线在红线上方距离红线越远模型分类效果越好。


http://www.mrgr.cn/p/88836665

相关文章

LineageOS刷机教程

一、刷机准备工作 在 https://download.lineageos.org.cn/ 这个网址里找到自己的手机型号,也可以用Ctrl+F搜索手机型号关键字找到后点击下载下载页面左侧找到自己的手机型号(不知道上面点击下载,进来的二级页面还得再找一次)下载好lineageOS的刷机包,boot.img文件,recove…

Linux基础篇:解析Linux命令执行的基本原理

Linux 命令是一组可在 Linux 操作系统中使用的指令&#xff0c;用于执行特定的任务&#xff0c;例如管理文件和目录、安装和配置软件、网络管理等。这些命令通常在终端或控制台中输入&#xff0c;并以文本形式显示输出结果。 Linux 命令通常以一个或多个单词的简短缩写或单词…

C++王牌结构hash:哈希表闭散列的实现与应用

一、哈希概念 顺序结构以及平衡树中&#xff0c;元素关键码与其存储位置之间没有对应的关系&#xff0c;因此在查找一个元素 时&#xff0c;必须要经过关键码的多次比较。顺序查找时间复杂度为O(N)&#xff0c;平衡树中为树的高度&#xff0c;即O(log n)&#xff0c;搜索的效率…

Unity AI Navigation自动寻路

目录 前言一、Unity中AI Navigation是什么&#xff1f;二、使用步骤1.安装AI Navigation2.创建模型和材质3.编写向目标移动的脚本4.NavMeshLink桥接组件5.NavMeshObstacle组件6.NavMeshModifler组件 三、效果总结 前言 Unity是一款强大的游戏开发引擎&#xff0c;而人工智能&a…

TCP三次握手、四次挥手出现意外情况时,如何保证稳定可靠?

TCP 作为一个靠谱的协议,在传输数据的前后,需要在双端之间建立连接,并在双端各自维护连接的状态。TCP 并没有什么特别之处,在面对多变的网络情况,也只能通过不断的重传和各种算法来保证可靠性。建立连接前,TCP 会通过三次握手来保证双端状态正确,然后就可以正常传输数据…

八大技术趋势案例(人工智能物联网)

科技巨变,未来已来,八大技术趋势引领数字化时代。信息技术的迅猛发展,深刻改变了我们的生活、工作和生产方式。人工智能、物联网、云计算、大数据、虚拟现实、增强现实、区块链、量子计算等新兴技术在各行各业得到广泛应用,为各个领域带来了新的活力和变革。 为了更好地了解…

手机可以搜索到wifi,但电脑搜索不到WiFi无线网络的解决方法

手机可以搜索到wifi,但电脑搜索不到WiFi无线网络的解决方法:(重点先尝试一下后面的 原因二) 手机可以搜索到wifi,但电脑搜索不到WiFi无线网络的解决方法 一般来说,手机如果可以搜索到路由器Wi-Fi无线信号,并且可以连上上网,说明路由器自身和设置肯定是没有任何问题的,这…

数据可视化-ECharts Html项目实战(7)

在之前的文章中&#xff0c;我们学习了如何设置漏斗图、仪表盘。想了解的朋友可以查看这篇文章。同时&#xff0c;希望我的文章能帮助到你&#xff0c;如果觉得我的文章写的不错&#xff0c;请留下你宝贵的点赞&#xff0c;谢谢 数据可视化-ECharts Html项目实战&#xff08;6…

学习JavaEE的日子 Day32 线程池

Day32 线程池 1.引入 一个线程完成一项任务所需时间为&#xff1a; 创建线程时间 - Time1线程中执行任务的时间 - Time2销毁线程时间 - Time3 2.为什么需要线程池(重要) 线程池技术正是关注如何缩短或调整Time1和Time3的时间&#xff0c;从而提高程序的性能。项目中可以把Time…

DataX-Oracle新增writeMode支持update

目录 前言 第一步下载源码 第二步修改源码 1、Oraclewriter 2、WriterUtil 2.1、修改getWriteTemplate方法 2.2、新增onMergeIntoDoString与getStrings方法 3、CommonRdbmsWriter 3.1、修改startWriteWithConnection 3.2、修改doBatchInsert 3.3、修改fillPreparedStatem…

Flutter 拦截系统键盘,显示自定义键盘

一、这里记录下在开发过程中&#xff0c;下单的时候输入金额需要使用自定义的数字键盘 参考链接: https://juejin.cn/post/7166046328609308685 效果图 二、屏蔽系统键盘 怎样才能够在输入框获取焦点的时候&#xff0c;不让系统键盘弹出呢&#xff1f;同时又显示我们自定义的…

【SpringBoot从入门到精通】01_SpringBoot概述

一、Spring与SpringBoot 1.1 Spring Spring 是一款目前主流的 Java EE 轻量级开源框架&#xff0c;是 Java 世界最为成功的框架之一。Spring 由“Spring 之父”Rod Johnson(罗宾约翰逊) 提出并创立&#xff0c;其目的是用于简化 Java 企业级应用的开发难度和开发周期。 广义…

新网站收录时间是多久,新建网站多久被百度收录

对于新建的网站而言&#xff0c;被搜索引擎收录是非常重要的一步&#xff0c;它标志着网站的正式上线和对外开放。然而&#xff0c;新网站被搜索引擎收录需要一定的时间&#xff0c;而且时间长短受多种因素影响。本文将探讨新网站收录需要多长时间&#xff0c;以及新建网站多久…

Prometheus +Grafana +node_exporter可视化监控Linux + windows虚机

1、介绍 待补充 2、架构图 Prometheus &#xff1a;主要是负责存储、抓取、聚合、查询方面。 node_exporter &#xff1a;主要是负责采集物理机、中间件的信息。 3、搭建过程 配置要求&#xff1a;1台主服务器 n台从服务器 &#xff08;被监控的linux或windows虚机&am…

【分布式】——CAPBASE理论

CAP&BASE理论 ⭐⭐⭐⭐⭐⭐ Github主页&#x1f449;https://github.com/A-BigTree 笔记链接&#x1f449;https://github.com/A-BigTree/tree-learning-notes ⭐⭐⭐⭐⭐⭐ Spring专栏&#x1f449;https://blog.csdn.net/weixin_53580595/category_12279588.html Sprin…

数仓 - [04] 数仓分层

数仓分层是一种将数据仓库按照不同的层级进行组织和管理的方法。每个层级具有不同的功能和目的,旨在支持数据分析、报告和决策等不同的业务需求。一、数仓分层的意义和目的数仓分层的主要目的是实现数据的高效访问和分析,提高数据的可用性和性能,同时提供更好的灵活性和可扩…

TheMoon 恶意软件短时间感染 6,000 台华硕路由器以获取代理服务

文章目录 针对华硕路由器Faceless代理服务预防措施 一种名为"TheMoon"的新变种恶意软件僵尸网络已经被发现正在侵入全球88个国家数千台过时的小型办公室与家庭办公室(SOHO)路由器以及物联网设备。 "TheMoon"与“Faceless”代理服务有关联&#xff0c;该服务…

Hybrid-PSC:基于对比学习的混合网络,解决长尾图片分类 | CVPR 2021

论文提出新颖的混合网络用于解决长尾图片分类问题,该网络由用于图像特征学习的对比学习分支和用于分类器学习的交叉熵分支组成,在训练过程逐步将训练权重调整至分类器学习,达到更好的特征得出更好的分类器的思想。另外,为了节省内存消耗,论文提出原型有监督对比学习。从实…

axios发送get请求但参数中有数组导致请求路径多出了“[]“的处理办法

一、情况 使用axios发送get请求携带了数组参数时&#xff0c;请求路径中就会多出[]字符&#xff0c;而在后端也会报错 二、解决办法 1、安装qs 当前项目的命令行中安装 npm install qs2、引入qs库(使用qs库来将参数对象转换为字符串) // 全局 import qs from qs Vue.proto…

STM32使用USART发送数据包指令点亮板载LED灯

电路连接&#xff1a; 连接显示屏模块&#xff0c;显示屏的SCL在B10&#xff0c;SDA在B11。 程序目的&#xff1a; 发送LED_ON指令打开板载LED灯&#xff0c;发送LED_OFF关闭板载LED灯&#xff0c;与上一个博客不同&#xff0c;这个实际上是实现串口收发文本数据包。 …