当前位置: 首页 > news >正文

【mmsegmentation】Loss模块(进阶)自定义自己的LOSS

1、定义自己的loss

driving\models\losses\shuai_loss.py

import torch
from torch import nn
from mmseg.models import LOSSES@LOSSES.register_module()
class ShuaiLoss(nn.Module):def __init__(self,loss_weight=1.0):super().__init__()self.ce_loss = nn.CrossEntropyLoss()self.loss_weight = loss_weightdef forward(self,input,target,device_id='cpu',sample_ratio=1.0):loss = {}if len(target)==0:loss["cls_cost"] = torch.tensor(0.0,dtype=torch.float32,device=device_id)else:loss["cls_cost"] = self.ce_loss(input,target)loss["total_road_cls_loss"] = loss["cls_cost"] * self.loss_weight * sample_ratio # + other losses, if havereturn loss

看下LOSSES注册表(@LOSSES.register_module())
在这里插入图片描述

  • 可以看到ShuaiLoss可以被注册到LOSSES
  • 其实,这里的LOSSES是BACKBONES NECKS HEADS LOSSES SEGMENTORS的总和

2、调用Shuai_loss

if __name__ == "__main__":print("call shuai_loss:")from mmseg.models import build_loss# 1.配置 dictloss = dict(type='ShuaiLoss',loss_weight=1.0,loss_name='loss_shuai')# 从注册器中构建shuai_loss = build_loss(loss)# 使用shuai losspred = torch.Tensor([[0, 2, 3, 0], [0,2,3,0]])   # [2,4]target = torch.Tensor([[1, 1, 1, 0], [1,1,1,1]]) # [2,4]loss = shuai_loss(pred, target)print("loss:",loss)

在这里插入图片描述


http://www.mrgr.cn/news/41271.html

相关文章:

  • 【前沿 热点 顶会】NIPS/NeurIPS 2024中与强化学习有关的论文
  • 小程序echarts不滑动问题
  • 【STM32 Blue Pill编程实例】-SSD1306 OLED显示屏(I2C)
  • CSP-J模拟赛(2)补题报告
  • java中创建不可变集合
  • Pandas数据类型
  • 第2篇:Linux日志分析----应急响应之日志分析篇
  • 模版and初识vector
  • Java hashcode设计与实现
  • 听说这是MATLAB基础?
  • 开源黑科技!Fish Speech TTS模型完美支持8种语言
  • 数组与链表
  • 计算机网络(十) —— IP协议详解,理解运营商和全球网络
  • csp-j模拟二补题报告
  • 如何解决 Photoshop 中的“暂存盘已满”错误
  • 磁编码器磁铁要求和安装要求
  • 分散加载文件 scatter files
  • 大数据算法的思维
  • 开源AI智能名片链动2+1模式S2B2C商城小程序源码与工业4.0的融合发展:机遇与挑战
  • C++基础补充(02)C++其他控制语句break continue goto等