当前位置: 首页 > news >正文

损失函数与反向传播

1.损失函数的作用

  • 计算实际输出和目标之间的差距
  • 为我们更新输出提供一定的依据(反向传播)

2.介绍几种官方文档中的损失函数

损失函数只能处理float类型的张量。

  • L1Loss (MAE):
    在这里插入图片描述
    在这里插入图片描述
import torch
from torch.nn import L1Lossinputs=torch.tensor([1,2,3],dtype=torch.float32)
targets=torch.tensor([1,2,5],dtype=torch.float32)inputs=torch.reshape(inputs,(1,1,1,3))
targets=torch.reshape(targets,(1,1,1,3))loss=L1Loss()
result=loss(inputs,targets)print(result)
  • MSELoss:
    在这里插入图片描述
loss_mse=nn.MSELoss()
result_mse=loss_mse(inputs,targets)
  • CrossEntropyLoss:
    该Loss算法计算输入对数与目标对数之间的交叉熵损失,在训练 C 类分类问题时非常有用。
    在这里插入图片描述
x=torch.tensor([0.1,0.2,0.3])
y=torch.tensor([1])
x=torch.reshape(x,(1,3))
loss_cross=nn.CrossEntropyLoss()
result_cross=loss_cross(x,y)

3.在神经网络中使用Loss Function

import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoaderdataset=torchvision.datasets.CIFAR10("data",train=False,transform=torchvision.transforms.ToTensor(),download=True)#每个批次中加载的数据项数量
dataloader=DataLoader(dataset,batch_size=1)class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.model1=Sequential(Conv2d(3,32,5,padding=2),MaxPool2d(2),Conv2d(32,32,5,padding=2),MaxPool2d(2),Conv2d(32,64,5,padding=2),MaxPool2d(2),Flatten(),Linear(1024,64),Linear(64,10))def forward(self, x):x=self.model1(x)return xloss=nn.CrossEntropyLoss()
tudui=Tudui()
for data in dataloader:imgs,targets = dataoutputs =tudui(imgs)result_loss=loss(outputs,targets)print(result_loss)

在这里插入图片描述

4.grad梯度

result_loss.backward()

loss=nn.CrossEntropyLoss()
tudui=Tudui()
for data in dataloader:imgs,targets = dataoutputs =tudui(imgs)result_loss=loss(outputs,targets)result_loss.backward()print("ok")

Debug
在这里插入图片描述
在这里插入图片描述

优化器就是根据grad中的值进行优化loss

http://www.mrgr.cn/news/13796.html

相关文章:

  • 009 下一代网络技术:SDN与虚拟化
  • 未戴安全帽算法检测源码样本安防监控视频分析未戴安全帽检测算法应用场景
  • 树结构与递归学习笔记二
  • Spark-RDD迭代器管道计算
  • “Ruby宝石匣:解锁流行插件系统的奥秘“
  • 适合跑步运动的蓝牙耳机推荐?盘点开放式耳机排行榜10强
  • HTML静态网页成品作业(HTML+CSS)——宠物狗店网页(1个页面)
  • 【GPT教我学】字节对象和字符对象
  • 【电控笔记z26】串级PID单环位置PID
  • HSE软件组件有哪些?如何实现HSE与主机的通信(同步/异步)?如何使用HSE提供的安全服务?
  • 米哈游(原神)一面算法原题
  • shell循环结构之while循环
  • 深入探索Python的`multiprocessing`模块:实现并行处理的实用指南
  • 【初阶数据结构】顺序表和链表算法题(下)
  • ADB 获取屏幕坐标,并模拟滑动和点击屏幕
  • C++ 两线交点程序(Program for Point of Intersection of Two Lines)
  • 数据仓库系列 2:数据仓库的核心特点是什么?
  • 解决Selenium已安装,在pycharm导入时报错
  • 如何将十六进制的乱码转换成汉字
  • Java 输入与输出之 NIO【非阻塞式IO】【NIO核心原理】探索之【一】