神经网络的优化器

news/2024/5/21 20:53:55

神经网络的优化器是用于训练神经网络的一类算法,它们的核心目的是通过改变神经网络的权值参数来最小化或最大化一个损失函数。优化器对损失函数的搜索过程对于神经网络性能至关重要。

作用:

  1. 参数更新:优化器通过计算损失函数相对于权重参数的梯度来确定更新参数的方向和步长。

  2. 收敛加速:高效的优化算法可以加快训练过程中损失函数的收敛速度。

  3. 避免陷入局部最优:一些优化器特别设计了策略(如动量),以帮助模型跳出局部最小值,寻找到更全局的最优解。

  4. 适应性调整:许多优化器可以自适应地调整学习率,使得训练过程中对不同的数据或参数具有不同的调整策略。

常用优化器有以下几种:

  1. 梯度下降(SGD):最基本的优化策略,它使用固定的学习率更新所有的权重。存在批量梯度下降(使用整个数据集计算梯度)、随机梯度下降(每个样本更新一次权重)和小批量梯度下降(mini-batch,每个小批量数据更新一次权重)。

    import torch
    import torch.nn as nn
    import torch.optim as optim# 假设我们有一个简单的模型
    model = nn.Sequential(nn.Linear(10, 5),nn.ReLU(),nn.Linear(5, 1)
    )# 定义损失函数,这里使用均方误差
    loss_fn = nn.MSELoss()# 定义优化器,使用 SGD 并设置学习率
    optimizer = optim.SGD(model.parameters(), lr=0.01)# 假定一个输入和目标输出
    input = torch.randn(64, 10)
    target = torch.randn(64, 1)# 运行模型训练流程
    for epoch in range(100): # 假设总共训练 100 轮# 正向传播,计算预测值output = model(input)# 计算损失loss = loss_fn(output, target)# 梯度清零,这一步很重要,否则梯度会累加optimizer.zero_grad()# 反向传播,计算梯度loss.backward()# 根据梯度更新模型参数optimizer.step()# 记录、打印损失或者使用损失进行其他操作

  2. 带动量的SGD(Momentum):在传统的梯度下降算法基础上,SGD Momentum考虑了梯度的历史信息,帮助优化器在正确的方向上加速,并且抑制震荡。

  3. Adagrad:自适应地为每个参数分配不同的学习率,从而提高了在稀疏数据上的性能。对于出现次数少的特征,会给予更大的学习率。

  4. RMSprop:对Adagrad进行改进,通过使用滑动平均的方式来更新学习率,解决了其学习率不断减小可能会提前停止学习的问题。

  5. Adam(Adaptive Moment Estimation):结合Momentum和RMSprop的概念,在Momentum的基础上计算梯度的一阶矩估计和二阶矩估计,进而进行参数更新。

    作用:自适应学习率调整:Adam算法通过自适应地调整每个参数的学习率,使得对于不同的参数,学习率能够根据其梯度的大小进行动态调整。这样能够更快地收敛到最优解,同时减少了手动调整学习率的需求。动量优化:Adam算法利用动量的概念来加速优化过程。动量能够帮助算法在参数空间中跨越局部极小值,从而加速收敛过程,并且可以在参数更新时减少梯度方向上的震荡。参数更新:Adam算法使用指数加权移动平均来估计每个参数的一阶矩(梯度的均值)和二阶矩(梯度的方差),然后根据这些估计值来更新参数。
    import torch
    import torch.nn as nn
    import torch.optim as optim# 定义一个简单的神经网络
    class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(784, 256)self.fc2 = nn.Linear(256, 128)self.fc3 = nn.Linear(128, 10)def forward(self, x):x = torch.flatten(x, 1)x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = self.fc3(x)return x# 初始化模型和Adam优化器
    model = Net()
    optimizer = optim.Adam(model.parameters(), lr=0.001)# 定义损失函数
    criterion = nn.CrossEntropyLoss()# 训练过程示例
    for epoch in range(num_epochs):for inputs, targets in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()在这个示例中,我们首先定义了一个简单的神经网络模型(包含三个全连接层),然后初始化了Adam优化器,将模型的参数传递给优化器。在训练过程中,我们在每个迭代周期中执行了模型的前向传播、损失计算、反向传播以及参数更新的操作。通过调用optimizer.step()来实现参数更新,Adam优化器会根据当前梯度自适应地调整学习率,并更新模型参数。

  6. Nadam:结合了Adam和Nesterov动量的优化器,它在计算当前梯度前先往前走一小步,用来修正未来的梯度方向。

  7. AdaDelta:是对Adagrad的扩展,减少了学习率递减的激进程度。

不同的优化器可能会对神经网络的训练效果产生较大影响,因此在实际应用中,我们通常会根据具体问题来选择最合适的优化器。实际选择时,往往需要进行试验,并通过验证集的性能来调整选择。

有人研究过几大优化器在一些经典任务上的表现。如下是在图像分类任务上,不同优化器的迭代次数和ACC间关系。


http://www.mrgr.cn/p/80361838

相关文章

华为OD机试 - 跳格子3 - 动态规划(Java 2024 C卷 200分)

华为OD机试 2024C卷题库疯狂收录中,刷题点这里 专栏导读 本专栏收录于《华为OD机试(JAVA)真题(A卷B卷C卷)》。 刷的越多,抽中的概率越大,每一题都有详细的答题思路、详细的代码注释、样例测试…

SpringBoot 打包所有依赖

SpringBoot 项目打包的时候可以通过插件 spring-boot-maven-plugin 来 repackage 项目,使得打的包中包含所有依赖,可以直接运行。例如: <plugins><plugin><groupId>org.springframework.boot</groupId><artifactId>spring-boot-maven-plugin&…

简单解决version GLIBC_2.34 not found,version GLIBC_2.25 not found

简单解决version GLIBC_2.34 not found,version GLIBC_2.25 not found 无需手动下载安装包编译 前言 很多博客都是要手动下载安装包进行编译升级,但这样很容易导致系统崩溃,本博文提供一个简单的方法,参考自博客1,博客2. 检查版本 strings /usr/lib64/libc.so.6 |grep GLI…

敏捷之Scrum开发

目录 一、什么是 Scrum 1.1 Scrum 的定义 二、Scrum 迭代开发过程 2.1 迭代开发过程说明 2.1.1 开发方法 2.1.1.1 增量模型 2.1.1.1.1 定义 2.1.1.1.2 模型方法说明 2.1.1.2 迭代模型 2.1.1.2.1 定义 2.1.1.2.2 模型方法说明 2.1.2 迭代过程 2.1.2.1 产品需求Produ…

简单解决version `GLIBC_2

简单解决version GLIBC_2.34 not found,version GLIBC_2.25 not found 无需手动下载安装包编译 前言 很多博客都是要手动下载安装包进行编译升级,但这样很容易导致系统崩溃,本博文提供一个简单的方法,参考自博客1,博客2. 检查版本 strings /usr/lib64/libc.so.6 |grep GLI…

HTML:认识HTML及基本语法

目录 1. HTML介绍 2. 关于软件选择和安装 3. HTML的基本语法 1. HTML介绍 HyperText Markup Language 简称HTML&#xff0c;意为&#xff1a;超文本标记语言 超文本&#xff1a;是指页面内可以包含的图片&#xff0c;链接&#xff0c;声音&#xff0c;视频等内容 标记&am…

初三奥赛模拟测试5

初三奥赛模拟测试5点击查看快读快写代码 #include <cstdio>using namespace std; // orz laofudasuan // modifiednamespace io {const int SIZE = (1 << 21) + 1;char ibuf[SIZE], *iS, *iT, obuf[SIZE], *oS = obuf, *oT = oS + SIZE - 1, c, qu[55]; int f, qr;…

栈_单向链表

利用单向链表设计一个栈,实现“后进先出”的功能 ​ 栈内存自顶向下进行递增,其实栈和顺序表以及链式表都一样,都属于线性结构,存储的数据的逻辑关系也是一对一的。 ​ 栈的一端是封闭的,数据的插入与删除只能在栈的另一端进行,也就是栈遵循“*后进先出*”的原则。也被成…

【国产NI替代】NI-9219 100 S/s/ch,4通道C系列通用模拟输入模块

100 S/s/ch&#xff0c;4通道C系列通用模拟输入模块 NI-9219专为多功能测试而设计。NI-9219可用于测量来自多种传感器&#xff08;如应变计&#xff0c;电阻温度检测器(RTD)&#xff0c;热电偶&#xff0c;测压元件和其他有源传感器等&#xff09;的信号&#xff0c;以及制作1…

VScode 无法连接云服务器

试了很多方法&#xff0c;比如更换VScode版本&#xff0c;卸载重装&#xff0c;删除配置文件 重启电脑&#xff0c;都无法成功。最后重置电脑后才连接上&#xff0c;但是重启服务器后又出现该问题。 方法一&#xff1a;修改环境 方法二&#xff1a;把vscode卸载干净重下

SQL Sever无法连接服务器

SQL Sever无法连接服务器&#xff0c;报错证书链是由不受信任的颁发机构颁发的 解决方法&#xff1a;不用ssl方式连接 1、点击弹框中按钮“选项” 2、连接安全加密选择可选 3、不勾选“信任服务器证书” 4、点击“连接”&#xff0c;可连接成功

vue 脚手架 创建vue3项目

创建项目 命令&#xff1a;vue create vue-element-plus 选择配置模式&#xff1a;手动选择模式 (上下键回车) 选择配置&#xff08;上下键空格回车&#xff09; 选择代码规范、规则检查和格式化方式: 选择语法检查方式 lint on save (保存就检查) 代码文件中有代码不符合 l…

【排课小工具】面向对象分析探索领域模型

用户向系统中输入课表模板、课程信息以及教师责任信息,系统以某种格式输出每个班级的课表。该用例中的主要参与者包括用户以及系统,除了上述两个主要参与者外,我们从该用例中抽取出可能有价值的名词:课表模板、课程、教师职责、班级以及课表。现在我们只知道下面图示的关系…

【Qt 专栏】Qt Creator 的 git 配置 上传到gitee

1.进入到Qt项目文件夹内,打开 “Git Bash Here” 2.初始化,在“Git Bash Here”中输入 git init 3.加入所有文件,在“Git Bash Here”中输入 git add . (需要注意,git add 后面还有一个点) 4.添加备注,git commit -m "备份" 5.推送本地仓库到gitee(需要事…

前端发起网络请求的几种常见方式(XMLHttpRequest、FetchApi、jQueryAjax、Axios)

摘要 前端发起网络请求的几种常见方式包括&#xff1a; XMLHttpRequest (XHR)&#xff1a; 这是最传统和最常见的方式之一。它允许客户端与服务器进行异步通信。XHR API 提供了一个在后台发送 HTTP 请求和接收响应的机制&#xff0c;使得页面能够在不刷新的情况下更新部分内容…

数字旅游:通过科技赋能,创新旅游服务模式,提供智能化、个性化的旅游服务,满足游客多元化、个性化的旅游需求

目录 一、数字旅游的概念与内涵 二、科技赋能数字旅游的创新实践 1、大数据技术的应用 2、人工智能技术的应用 3、物联网技术的应用 4、云计算技术的应用 三、智能化、个性化旅游服务的实现路径 1、提升旅游服务的智能化水平 2、实现旅游服务的个性化定制 四、数字旅…

报错“Please indicate a valid Swagger or OpenAPI version field”

报错“Please indicate a valid Swagger or OpenAPI version field” 报错信息Please indicate a valid Swagger or OpenAPI version field. Supported version fields are swagger: "2.0" and those that match openapi: 3.0.n (for example, openapi: 3.0.0). 原因…

安卓获取SHA

1&#xff1a;安卓通过签名key获取SHA 方式有两种&#xff0c; 1、电脑上来存在eclipse的用户或正在使用此开发工具的用户就简单了&#xff0c;直接利用eclipse 走打包流程&#xff0c;再打包的时候选择相应的签名&#xff0c;那么在当前面板的下面便会出现签名的相关信息。 2、…

【C++】封装哈希表 unordered_map和unordered_set容器

目录​​​​​​​ 一、unordered系列关联式容器 1、unordered_map 2、unordered_map的接口 3、unordered_set 二、哈希表的改造 三、哈希表的迭代器 1、const 迭代器 2、 operator 3、begin()/end() ​ 4、实现map[]运算符重载 四、封装 unordered_map 和 unordered_se…

visual studio2022,开发CMake项目添加rabbitmq库,连接到远程计算机并进行开发于调试

1.打开visual studio installer 。安装“用于 Windows 的 C CMake 工具” 2.新建CMake项目 3.点击VS的“工具”—>"选项“—>“跨平台”—>”连接管理器“,添加远程计算机。用来将VS编辑的代码传到服务器进行编译–连接—运行&#xff08;调试&#xff09;。 …