当前位置: 首页 > news >正文

基于Python的机器学习系列(17):梯度提升回归(Gradient Boosting Regression)

简介

        梯度提升(Gradient Boosting)是一种强大的集成学习方法,类似于AdaBoost,但与其不同的是,梯度提升通过在每一步添加新的预测器来减少前一步预测器的残差。这种方法通过逐步改进模型,能够有效提高预测准确性。

梯度提升回归的工作原理

        在梯度提升回归中,我们逐步添加预测器来修正模型的残差。以下是梯度提升的基本步骤:

  1. 初始化模型:选择一个初始预测器 h0(x),计算该预测器的预测值。
  2. 计算残差:计算每个样本的残差,残差是实际值与当前预测值之间的差异。
  3. 训练新预测器:用计算得到的残差作为目标,训练一个新的预测器 h1(x)。
  4. 更新模型:将新预测器的预测结果加到现有模型中。
  5. 重复步骤:重复上述步骤,逐步添加更多的预测器,以减少残差。

目标函数与残差

        在回归问题中,我们希望通过添加新的预测器来最小化残差。具体来说,对于每个样本 (x(i),y(i)),我们计算预测器的残差:

        我们希望新的预测器 h1(x)能够进一步减少这个残差:

        通过这样的方式,我们可以不断改进模型的预测能力。

梯度提升回归的损失函数

        在回归中,我们通常使用均方误差(MSE)作为损失函数:

        我们的目标是通过每一步最小化残差,从而最小化整体损失函数。

代码示例

        下面的代码示例展示了如何使用sklearn中的GradientBoostingRegressor实现梯度提升回归:        

from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error# 生成数据集
X, y = make_regression(n_samples=500, noise=0.2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 创建和训练模型
gbr = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)
gbr.fit(X_train, y_train)# 进行预测和评估
y_pred = gbr.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print(f"均方误差: {mse:.2f}")

结语

        与之前讨论的决策树、Bagging、随机森林相比,梯度提升回归通过逐步优化模型的残差来提升预测性能。决策树和Bagging方法通过集成多个模型来减少方差,而随机森林进一步通过随机特征选择来去相关性。梯度提升则通过序列化的方式不断改进模型,强调对残差的逐步修正。每种方法都有其独特的优势和适用场景,选择合适的模型可以显著提高预测的准确性。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!


http://www.mrgr.cn/news/16077.html

相关文章:

  • 如何构建基于Java SpringBoot的医疗器械管理系统?四步详解从需求分析到系统部署,集成Vue.js提升用户体验,内含MySQL数据库管理技巧。
  • 基于STM32开发的智能电力监控与管理系统
  • 为什么生成设备号过后,还要去板子mknod /dev/led c 11 0来生成设备文件呢?
  • 【设计模式】代理模式
  • 为什么在JDBC中使用PreparedStatement?
  • 深度学习100问28:什么是RNNLM(RNN语言模型)
  • 浪潮信息携区域ISP伙伴,共创AI应用新生态
  • 企业级web服务实战 (模拟)(一
  • 【C++ 面试 - STL】每日 3 题(二)
  • 通过 TS-Mixer 实现股票价格预测
  • 【大数据分析工具】使用Hadoop、Spark进行大数据分析
  • 07.整合Pinia
  • CSRF漏洞的预防
  • ssh---配置密钥对验证
  • 蓝牙对象交换协议(OBEX) - 常见的opcode介绍
  • Python知识点:如何使用SQLAlchemy进行ORM(对象关系映射)
  • 如何扩展 WSL 2 虚拟硬盘的大小
  • 鸿蒙OS试题(2)
  • 技术风暴中的应急策略:开发团队如何应对突发故障与危机
  • 深度学习的基础_多层感知机的手动实现