当前位置: 首页 > news >正文

总结一下 KNN、K-means 和 SVM【附代码实现】

小小总结一下 KNN、K-means 和 SVM 及其 Python 实现

好久没更新了,最近准备秋招,在机器学习中感觉经常被问的几个算法:K近邻算法(K-Nearest Neighbors, KNN)、K均值聚类算法(K-means)以及支持向量机(Support Vector Machine, SVM)。给自己做个总结笔记,并贴出来,如果有误欢迎指出。


1. K近邻算法 (KNN)

什么是 KNN?

K近邻算法是一种监督学习算法,常用于分类和回归问题。它的核心思想是:给定一个新样本,找出在特征空间中距离它最近的K个已知样本,然后基于这些邻居的信息进行分类或预测。
在这里插入图片描述

工作原理:
  1. 选择参数K(即最近邻居的数量)。
  2. 计算新样本与训练数据集中所有样本的距离(通常使用欧几里得距离)。
  3. 选择距离最近的K个样本。
  4. 根据这K个邻居的类别,进行投票以决定新样本的类别(在分类问题中)。
KNN的优缺点:
  • 优点:简单易理解,适合小数据集,直观。
  • 缺点:随着数据量的增加,计算量也增加;对特征的尺度和噪声敏感。
Python实现KNN
import numpy as np
from collections import Counter# KNN算法实现
class KNN:def __init__(self, k=3):self.k = kdef fit(self, X_train, y_train):self.X_train = X_trainself.y_train = y_traindef predict(self, X_test):predictions = [self._predict(x) for x in X_test]return np.array(predictions)def _predict(self, x):# 计算欧几里得距离distances = [np.sqrt(np.sum((x - x_train)**2)) for x_train in self.X_train]# 选择K个最近邻k_indices = np.argsort(distances)[:self.k]k_nearest_labels = [self.y_train[i] for i in k_indices]# 进行投票most_common = Counter(k_nearest_labels).most_common(1)return most_common[0][0]# 示例数据
X_train = np.array([[1, 2], [2, 3], [3, 4], [6, 7], [7, 8]])
y_train = np.array([0, 0, 0, 1, 1])
X_test = np.array([[2, 3], [5, 6]])# KNN实例化并预测
knn = KNN(k=3)
knn.fit(X_train, y_train)
predictions = knn.predict(X_test)
print(predictions)

2. K均值算法 (K-means)

什么是 K-means?

K-means是一种无监督学习算法,常用于聚类问题。它通过将数据集划分为K个簇,使得同一簇中的数据点彼此之间更加相似,而与其他簇的数据点差异更大。
在这里插入图片描述

工作原理:
  1. 选择K个初始质心(可以随机选取数据点作为质心)。
  2. 将每个数据点分配给最近的质心,形成K个簇。
  3. 计算每个簇的质心。
  4. 重复步骤2和3,直到质心不再改变(或变化微小)。
K-means的优缺点:
  • 优点:简单快速,适合大数据集。
  • 缺点:对初始质心选择敏感,可能陷入局部最优,K值需要人为设定。
Python实现K-means
from sklearn.cluster import KMeans
import numpy as np# 示例数据
X = np.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8], [1, 0.6], [9, 11]])# K-means聚类
kmeans = KMeans(n_clusters=2)
kmeans.fit(X)# 输出结果
print("簇的质心:", kmeans.cluster_centers_)
print("簇的分配:", kmeans.labels_)##############################################################
# 以上是调包的代码,具体实现如下
##############################################################
import numpy as npclass KMeans:def __init__(self, k=2, max_iters=100):self.k = k  # 簇的数量self.max_iters = max_iters  # 最大迭代次数self.centroids = None  # 质心# 计算欧几里得距离def _euclidean_distance(self, a, b):return np.sqrt(np.sum((a - b) ** 2))# 随机初始化质心def _initialize_centroids(self, X):np.random.seed(42)indices = np.random.choice(X.shape[0], self.k, replace=False)return X[indices]# 更新每个簇的质心为该簇所有点的均值def _compute_centroids(self, clusters, X):centroids = np.zeros((self.k, X.shape[1]))for idx, cluster in enumerate(clusters):centroids[idx] = np.mean(X[cluster], axis=0)return centroids# 将每个数据点分配给最近的质心def _assign_clusters(self, X, centroids):clusters = [[] for _ in range(self.k)]for idx, point in enumerate(X):distances = [self._euclidean_distance(point, centroid) for centroid in centroids]nearest_centroid_idx = np.argmin(distances)clusters[nearest_centroid_idx].append(idx)return clusters# 拟合K-means模型def fit(self, X):self.centroids = self._initialize_centroids(X)for _ in range(self.max_iters):clusters = self._assign_clusters(X, self.centroids)previous_centroids = self.centroidsself.centroids = self._compute_centroids(clusters, X)# 如果质心没有变化,提前终止if np.all(previous_centroids == self.centroids):break# 预测每个点所属的簇def predict(self, X):return [np.argmin([self._euclidean_distance(x, centroid) for centroid in self.centroids]) for x in X]# 示例数据
X = np.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8], [1, 0.6], [9, 11]])# K-means实例化并训练
kmeans = KMeans(k=2)
kmeans.fit(X)# 输出结果
print("质心:", kmeans.centroids)
predictions = kmeans.predict(X)
print("簇的分配:", predictions)

3. 支持向量机 (SVM)

什么是 SVM?

SVM是一种强大的监督学习算法,常用于分类问题。它的目标是找到一个最优的超平面,能够最大化区分不同类别的数据点。SVM不仅仅适用于线性可分问题,还可以通过引入核函数来处理非线性问题。
在这里插入图片描述

工作原理:
  1. 对于线性分类问题,SVM找到一个最大化分类间距的超平面。
  2. 对于非线性分类问题,SVM通过将数据映射到高维空间,使得数据在高维空间中线性可分。
  3. 核函数是SVM的核心之一,常见的核函数有线性核、RBF核等。
SVM的优缺点:
  • 优点:对高维数据效果好,能够处理非线性数据。
  • 缺点:对数据量较大的情况下计算开销较高,难以解释。
Python实现SVM
from sklearn import svm
import numpy as np# 示例数据
X = np.array([[1, 2], [2, 3], [3, 4], [6, 7], [7, 8]])
y = np.array([0, 0, 0, 1, 1])# SVM分类器
clf = svm.SVC(kernel='linear')
clf.fit(X, y)# 测试数据预测
X_test = np.array([[2, 3], [5, 6]])
predictions = clf.predict(X_test)
print("SVM预测结果:", predictions)##############################################################
# 以上是调包的代码,具体实现如下
##############################################################
import numpy as npclass SVM:def __init__(self, learning_rate=0.001, lambda_param=0.01, n_iters=1000):self.lr = learning_rateself.lambda_param = lambda_param  # 正则化参数self.n_iters = n_itersself.w = None  # 权重self.b = None  # 偏差# 拟合SVM模型def fit(self, X, y):n_samples, n_features = X.shapey_ = np.where(y <= 0, -1, 1)  # 将标签转化为-1和1# 初始化权重和偏差self.w = np.zeros(n_features)self.b = 0# 梯度下降for _ in range(self.n_iters):for idx, x_i in enumerate(X):condition = y_[idx] * (np.dot(x_i, self.w) - self.b) >= 1if condition:# 如果分类正确,只需要最小化权重的正则项self.w -= self.lr * (2 * self.lambda_param * self.w)else:# 如果分类错误,更新权重和偏差self.w -= self.lr * (2 * self.lambda_param * self.w - np.dot(x_i, y_[idx]))self.b -= self.lr * y_[idx]# 预测数据类别def predict(self, X):linear_output = np.dot(X, self.w) - self.breturn np.sign(linear_output)# 示例数据
X = np.array([[1, 2], [2, 3], [3, 4], [6, 7], [7, 8]])
y = np.array([0, 0, 0, 1, 1])  # 标签:0变为-1,1保持不变# SVM实例化并训练
svm = SVM()
svm.fit(X, y)# 预测测试数据
X_test = np.array([[2, 3], [5, 6]])
predictions = svm.predict(X_test)
print("SVM预测结果:", predictions)

总结

  • KNN 是一种基于距离的简单分类方法,适合小规模数据集。
  • K-means 是一种聚类算法,能够将数据集分割成多个簇。
  • SVM 是一种强大的分类算法,适用于线性和非线性问题。

http://www.mrgr.cn/news/42154.html

相关文章:

  • 杀疯啦!yolov11+strongsort的目标跟踪实现
  • 华为OD机试 - 密室逃生游戏(Python/JS/C/C++ 2024 E卷 100分)
  • 更美观的HTTP性能监测工具:httpstat
  • 【自然语言处理】(1) --语言转换方法
  • 小错误(输入数据)牛客 14683 储物点的距离
  • Oracle中MONTHS_BETWEEN()函数详解
  • 【笔记】选择题笔记408
  • PADS自动导出Gerber文件 —— 6层板
  • C/C++/EasyX ——入门图形编程(2)
  • leetcode134:加油站
  • 关于Mybatis框架操作时注意的细节,常见的错误!(博主亲生体会的细节!)
  • 秋天,相遇最美校园
  • 挖矿病毒记录 WinRing0x64.sys
  • 基于深度学习的视频内容理解
  • 【一文理解】conda install pip install 区别
  • Flutter与原生代码通信
  • 网络基础知识总结(二)
  • 游览器输入URL并Enter时都发生了什么 面试完美回答
  • index索引
  • Mybatis框架梳理