神经网络反向传播算法

news/2024/5/21 20:17:05

今天我们来看一下神经网络中的反向传播算法,之前介绍了梯度下降与正向传播~       神经网络的反向传播

专栏:💎实战PyTorch💎

反向传播算法(Back Propagation,简称BP)是一种用于训练神经网络的算法。 

反向传播算法是神经网络中非常重要的一个概念,它由Rumelhart、Hinton和Williams于1986年提出。这种算法基于梯度下降法来优化误差函数,利用了神经网络的层次结构来有效地计算梯度,从而更新网络中的权重和偏置。

基本工作流程:

  1. 通过正向传播得到误差,所谓正向传播指的是数据从输入到输出层,经过层层计算得到预测值,并利用损失函数得到预测值和真实值之前的误差。
  2. 通过反向传播把误差传递给模型的参数,从而对网络参数进行适当的调整,缩小预测值和真实值之间的误差。
  3. 反向传播算法是利用链式法则进行梯度求解,然后进行参数更新。对于复杂的复合函数,我们将其拆分为一系列的加减乘除或指数,对数,三角函数等初等函数,通过链式法则完成复合函数的求导。

我们通过一个例子来简单理解下 BP 算法进行网络参数更新的过程🧧:

如图我们在最下边输入两个维度的值进入神经网络:0.05、0.1 ,经过两个隐藏层(每层两个神经元),每个神经元有两个值,左边为输入值,右边是经过激活函数后的输出值;经过这个神经网络后的输出值为:m1、m2,实际值为0.01、0.99 🌠

设置的初始权重w1,w2,...w8分别为0.15、0.20、0.25、0.30、0.30、0.35、0.55、0.60

我们通过计算得到损失函数Error = 1/2 ((m1- target1)2 + (m2 - target2)2) = 0.2988

w5和w7均可以通过求三次导来求梯度,而w1,w3则不能直接通过L降序求导,我们需要求从L到m1,m1到o1,o1到k1,k1到h1,h1到w1:

由于w1是输出两个方向分别到o1和o2,所以是两个方向的梯度求和~

我们也发现所以激活函数都是要可微的~

其他的网络参数更新过程和上面的求导过程是一样的,这里就不过多赘述,我们直接看一下代码。

反向传播代码 

我们先来回顾一些Python中类的一些小细节:

🌈在Python中,使用super()函数可以调用父类的方法。这在子类中重写父类方法时非常有用,因为它允许你调用父类的实现,而不是完全覆盖它

class Parent:def __init__(self):print("Parent init")class Child(Parent):def __init__(self):super().__init__()print("Child init")c = Child()# 输出
Parent init
Child init

🌈当我们创建一个Child类的实例时,它会首先调用Parent类的__init__方法(通过super().__init__()),然后执行Child类的__init__方法,与类的__init__方法(构造方法)对应的类关闭时自动调用的方法是__del__方法。对象不再被使用时,Python解释器会自动调用这个方法。通常在这个方法中进行一些清理工作,比如释放资源、关闭文件等。

反向传播实现

import torch
import torch.nn as nn
import torch.optim as optimclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.linear1 = nn.Linear(2, 2)self.linear2 = nn.Linear(2, 2)# 网络参数初始化w1/w2/w3/w4self.linear1.weight.data = torch.tensor([[0.15, 0.20], [0.25, 0.30]])# w5/w6/w7/w8self.linear2.weight.data = torch.tensor([[0.40, 0.45], [0.50, 0.55]])# 截距bself.linear1.bias.data = torch.tensor([0.35, 0.35])self.linear2.bias.data = torch.tensor([0.60, 0.60])# 定义前向传播的行径def forward(self, x):x = self.linear1(x)x = torch.sigmoid(x)x = self.linear2(x)x = torch.sigmoid(x)return xif __name__ == '__main__':inputs = torch.tensor([[0.05, 0.10]])target = torch.tensor([[0.01, 0.99]])# 获得网络输出值net = Net()output = net(inputs)# print(output)  # tensor([[0.7514, 0.7729]], grad_fn=<SigmoidBackward>)# 计算误差loss = torch.sum((output - target) ** 2) / 2# print(loss)  # tensor(0.2984, grad_fn=<DivBackward0>)# 优化方法optimizer = optim.SGD(net.parameters(), lr=0.5)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 打印 w5、w7、w1 的梯度值print(net.linear1.weight.grad.data)# tensor([[0.0004, 0.0009],#         [0.0005, 0.0010]])print(net.linear2.weight.grad.data)# tensor([[ 0.0822,  0.0827],#         [-0.0226, -0.0227]])# 打印网络参数optimizer.step()print(net.state_dict())# OrderedDict([('linear1.weight', tensor([[0.1498, 0.1996], [0.2498, 0.2995]])),#              ('linear1.bias', tensor([0.3456, 0.3450])),#              ('linear2.weight', tensor([[0.3589, 0.4087], [0.5113, 0.5614]])),#              ('linear2.bias', tensor([0.5308, 0.6190]))])
  • optimizer.step() 相当于是将w和b所有参数更新一步的过程

🌈关于nn.Linear的使用

import torch
import torch.nn.functional as F
import torch.nn as nn# 均匀分布随机初始化linear = nn.Linear(5, 3)
# 从0-1均匀分布产生参数
nn.init.uniform_(linear.weight)
print(linear.weight.data)

nn.Linear是PyTorch中用于创建线性层的类,也被称为全连接层。它的主要作用是将输入数据与权重矩阵相乘并加上偏置,然后通常会通过一个非线性激活函数进行转换。 

  1. 在函数内部,创建一个线性层,输入维度为5,输出维度为3;
  2. 使用nn.init.uniform_()函数对线性层的权重进行均匀分布随机初始化;
  3. 打印线性层的权重数据。

http://www.mrgr.cn/p/58507166

相关文章

WebAuthn 无密码身份认证

文章目录 WebAuthn简介工作原理组成部分架构实现注册认证应用场景案例演示 WebAuthn简介 WebAuthn&#xff0c;全称 Web Authentication&#xff0c;是由 FIDO 联盟&#xff08;Fast IDentity Online Alliance&#xff09;和 W3C&#xff08;World Wide Web Consortium&#x…

多线程事务怎么回滚

1、背景介绍 1&#xff0c;最近有一个大数据量插入的操作入库的业务场景&#xff0c;需要先做一些其他修改操作&#xff0c;然后在执行插入操作&#xff0c;由于插入数据可能会很多&#xff0c;用到多线程去拆分数据并行处理来提高响应时间&#xff0c;如果有一个线程执行失败…

【网络知识系列】-- 换个角度理解计算机网络

换个角度理解计算机网络,搭建计网知识框架 所谓换个角度,就是从三层物理设备(物理层、数据链路层、网络层)开始,串联起整个网络的工作原理 可能有些小伙伴看见物理设备天生就犯困,反手就准备关闭文章,且慢!本文只是简单的介绍这几个设备的功能,并不会涉及复杂的底层硬…

C#学习笔记-字段、属性、索引器

字段字段表示与对象或者类型(类或结构体)关联的变量(成员变量),为对象或类型存储数据。与对象关联的字段称为“实例字段”,隶属于某个对象。与类型关联的字段称为“静态字段”,表示某一个类型当前的状态。静态字段使用 static 关键字修饰。字段在没有显示初始化的情况下…

力扣-203. 移除链表元素

1.题目 题目地址(203. 移除链表元素 - 力扣(LeetCode)) https://leetcode.cn/problems/remove-linked-list-elements/ 题目描述 给你一个链表的头节点 head 和一个整数 val ,请你删除链表中所有满足 Node.val == val 的节点,并返回 新的头节点 。示例 1:输入:head = [1,2…

SpringCloud学习笔记(二)Ribbon负载均衡、Nacos注册中心、Nacos与Eureka的区别

文章目录 4 Ribbon负载均衡4.1 负载均衡原理4.2 源码解读4.3 负载均衡策略4.3.1 内置的负载均衡策略4.3.2 自定义负载均衡策略4.3.2.1 方式一&#xff1a;定义IRule4.3.2.2 方式二&#xff1a;配置文件 4.4 饥饿加载 5 Nacos注册中心5.1 认识和安装Nacos5.2 服务注册到Nacos5.3…

【云原生】Docker 实践(三):使用 Dockerfile 文件构建镜像

Docker 实践&#xff08;三&#xff09;&#xff1a;使用 Dockerfile 文件构建镜像 1.使用 Dockerfile 文件构建镜像2.Dockerfile 文件详解 1.使用 Dockerfile 文件构建镜像 Dockerfile 是一个文本文件&#xff0c;其中包含了一条条的指令&#xff0c;每一条指令都用于构建镜像…

网络接收全流程

网卡简介 网卡是一块通信硬件。属于数据链路层。用户可以通过电缆或无线相互连接。每一个网卡都有一个独一无二的MAC地址(48位),它被写在卡上的一块ROM中。IEEE负责为网卡销售商分配唯一的MAC地址。 可以在终端运行sudo lshw -C network来查看网卡型号 可以在/lib/modules/$(u…

中科院突破:TalkingGaussian技术实现3D人脸动态无失真,高效同步嘴唇运动!

DeepVisionary 每日深度学习前沿科技推送&顶会论文分享&#xff0c;与你一起了解前沿深度学习信息&#xff01; 引言&#xff1a;探索高质量3D对话头像的新方法 在数字媒体和虚拟互动领域&#xff0c;高质量的3D对话头像技术正变得日益重要。这种技术能够在虚拟现实、电影…

Linux操作系统·进程管理

一、什么是进程 1.作业和进程的概念 Linux是一个多用户多任务的操作系统。多用户是指多个用户可以在同一时间使用计算机系统&#xff1b;多任务是指Linux可以同时执行几个任务&#xff0c;它可以在还未执行完一个任务时又执行另一项任务。为了完成这些任务&#xff0c;系统上…

《痞子衡嵌入式半月刊》 第 99 期

痞子衡嵌入式半月刊: 第 99 期这里分享嵌入式领域有用有趣的项目/工具以及一些热点新闻,农历年分二十四节气,希望在每个交节之日准时发布一期。 本期刊是开源项目(GitHub: JayHeng/pzh-mcu-bi-weekly),欢迎提交 issue,投稿或推荐你知道的嵌入式那些事儿。 上期回顾 :《…

Codeforces Round 942 Div.2 题解

ds 这么聪明的。蹭个热度,挽救一下 cnblogs 蒸蒸日上的阅读量。Q: 你是手速狗吗? A: 我觉得我是。2A 因为选的 \(w\) 一定可以让它合法,一次操作可以看作 \(a\) 数组向右平移一位。枚举操作次数后暴力判断即可。 #include <bits/stdc++.h>void work() {int n;std::cin…

linux下调试串口设备

USB转串口常用CH34x芯片,该芯片有linux下的驱动。 在默认情况下,大部分linux发行版都包含了CH34x的驱动,唯一缺点就是版本比较久。 可以先插上开发板, 一般是挂载到/dev/ttyCH341USB0文件下,如果该文件不存在,有两种可能,一种是驱动版本太久,可以下载官方的驱动文件,然…

Kafka 生产者应用解析

目录 1、生产者消息发送流程 1.1、发送原理 2、异步发送 API 2.1、普通异步发送 2.2、带回调函数的异步发送 3、同步发送 API 4、生产者分区 4.1、分区的优势 4.2、生产者发送消息的分区策略 示例1&#xff1a;将数据发往指定 partition 示例2&#xff1a;有 key 的…

Windows系统下将MySQL数据库表内的数据全量导入Elasticsearch

目录 下载安装Logstash 配置Logstash配置文件 运行配置文件 查看导入结果 使用Logstash将sql数据导入Elasticsearch 下载安装Logstash 官网地址 选择Windows系统&#xff0c;需下载与安装的Elasticsearch相同版本的&#xff0c;下载完成后解压安装包。 配置Logstash配…

xhs全参xs,xt,xscommon逆向分析

声明 本文章中所有内容仅供学习交流,抓包内容、敏感网址、数据接口均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关,若有侵权,请联系我立即删除! 目标网站 aHR0cHM6Ly93d3cueGlhb2hvbmdzaHUuY29tL2V4cGxvcmUvNjYyNDcxYzkwMDAwMDAwMDA0M…

做大模型产品,如何设计prompt?

做GenAI产品&#xff0c;除了要设计好的AI任务流程&#xff0c;合理的拆分业务以外&#xff0c;最重要的就是写好prompt&#xff0c;管理好prompt&#xff0c;持续迭代prompt。 prompt一般有两种形式&#xff1a;结构化prompt和对话式prompt。 结构化prompt的优点是通过规范的…

【记录】Python3| 将 PDF 转换成 HTML/XML(✅⭐⭐⭐⭐pdf2htmlEX)

本文将会被汇总至 【记录】Python3&#xff5c;2024年 PDF 转 XML 或 HTML 的第三方库的使用方式、测评过程以及对比结果&#xff08;汇总&#xff09;&#xff0c;更多其他工具请访问该文章查看。 文章目录 pdf2htmlEX 使用体验与评估1 安装指南2 测试代码3 测试结果3.1 转 HT…

BSP视频教程第30期:UDS ISO14229统一诊断服务CAN总线专题,常用诊断执行流程精讲,干货分享,图文并茂(2024-04-30)

视频教程汇总帖:https://www.armbbs.cn/forum.php?mod=viewthread&tid=110519 【前言】 1、继前面分享了CANopen和J1939的专题后,这次继续为大家分享UDS专题视频第1期。 2、统一诊断服务(Unified Diagnostic Services,简称UDS)是车用电子的通信协议,是电子控制器EC…

Reverse Card (Hard Version)

事情是这样的,我验了这一场 CF。显然我玩原神玩多了有一个很奇怪的、不能过的算法,哦,当然,在我本机可以过。为了展现自己的智慧糖,我写一下。 出题人是先发给我了一个限制都是 \(n\) 的,因此只有这个。\(n,m\) 改改就是了。 要求 \(1\le a\le n,1\le b\le n\) 满足\(a+b…