Tracking the Loss in CO-DETR Training

Contents

  • Preface
  • Locating the loss-related code


Preface

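Following my earlier posts on the CO-DETR training and inference workflow, I converted my own dataset to COCO format so that CO-DETR can be trained on it. Because the AP after training is low, I now want to find where the loss value is available during training, save the best .pth checkpoint, and plot loss against epoch to see whether the loss decreases steadily or fluctuates heavily.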
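Whether the loss is "decreasing steadily" or "fluctuating heavily" can be judged by eye from the plot, but it can also be put into numbers. Below is a minimal sketch, assuming you already have a list of per-epoch average training losses (such as the epoch_losses list collected by the runner code in this post). The function name summarize_loss_trend and its three indicators are hypothetical heuristics chosen for illustration, not a standard metric and not part of mmcv or MMDetection.

```python
from statistics import mean, pstdev
from typing import Dict, Sequence


def summarize_loss_trend(epoch_losses: Sequence[float]) -> Dict[str, float]:
    """Hypothetical helper: describe how per-epoch losses evolve.

    decrease_ratio: share of epoch-to-epoch steps where the loss went down
                    (1.0 means it fell every epoch).
    rel_jitter:     population std of those step changes divided by the
                    mean loss (larger means a bumpier curve).
    total_drop:     first-epoch loss minus last-epoch loss.
    """
    if len(epoch_losses) < 2:
        raise ValueError("need at least two epochs to judge a trend")
    # Change in average loss from each epoch to the next
    deltas = [b - a for a, b in zip(epoch_losses, epoch_losses[1:])]
    return {
        "decrease_ratio": sum(d < 0 for d in deltas) / len(deltas),
        "rel_jitter": pstdev(deltas) / mean(epoch_losses),
        "total_drop": epoch_losses[0] - epoch_losses[-1],
    }


if __name__ == "__main__":
    print(summarize_loss_trend([2.0, 1.5, 1.2, 1.0]))  # falls every epoch
    print(summarize_loss_trend([2.0, 1.0, 2.0, 1.0]))  # oscillates
```

For the steadily falling list the decrease_ratio is 1.0, while the oscillating list gives about 0.67 together with a much larger rel_jitter.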


Locating the loss-related code

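In train.py, jump into train_detector, then into the run method called by runner.run(data_loaders, cfg.workflow); this leads to class EpochBasedRunner(BaseRunner). Add code that records the loss to its train method, and plotting code at the end of its run method. The code is as follows: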

    # At the top of the file that defines EpochBasedRunner
    # (mmcv/runner/epoch_based_runner.py in mmcv 1.x, which already has `import os.path as osp`):
    import matplotlib
    matplotlib.use('Agg')  # non-GUI backend: the plot is written to a file, so it also works on a headless server
    import matplotlib.pyplot as plt

    # Inside class EpochBasedRunner(BaseRunner):
    def train(self, data_loader, **kwargs):
        # Create the loss history and the best-loss record once, on the runner itself,
        # so they persist across epochs (a local best_loss would be reset every epoch)
        if not hasattr(self, 'epoch_losses'):
            self.epoch_losses = []
            self.best_loss = float('inf')
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        epoch_loss = 0.0  # sum of the iteration losses in this epoch
        for i, data_batch in enumerate(self.data_loader):
            self.data_batch = data_batch
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            del self.data_batch
            self._iter += 1
            # run_iter stores the output of the model's train_step in self.outputs;
            # log_vars['loss'] is the total loss of this iteration as a Python float
            if 'loss' in self.outputs['log_vars']:
                epoch_loss += self.outputs['log_vars']['loss']
        # Record the average loss of this epoch
        avg_loss = epoch_loss / len(data_loader)
        self.epoch_losses.append(avg_loss)
        if avg_loss < self.best_loss:
            self.best_loss = avg_loss
            if self.rank == 0:  # in distributed training, only the main process writes the file
                # create_symlink=False keeps latest.pth pointing at the regular epoch checkpoint
                self.save_checkpoint(
                    self.work_dir, filename_tmpl='best_model.pth', create_symlink=False)
        self.call_hook('after_train_epoch')
        self._epoch += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self.data_batch = data_batch
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')
            del self.data_batch
        self.call_hook('after_val_epoch')

    def run(self,
            data_loaders: List[DataLoader],
            workflow: List[Tuple[str, int]],
            max_epochs: Optional[int] = None,
            **kwargs) -> None:
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_epochs is not None:
            warnings.warn(
                'setting max_epochs in run is deprecated, '
                'please set max_epochs in runner_config', DeprecationWarning)
            self._max_epochs = max_epochs

        assert self._max_epochs is not None, (
            'max_epochs must be specified during instantiation')

        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('Hooks will be executed in the following order:\n%s',
                         self.get_hook_info())
        self.logger.info('workflow: %s, max: %d epochs', workflow,
                         self._max_epochs)
        self.call_hook('before_run')

        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))

                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

        # Plot the training loss per epoch (main process only)
        if self.rank == 0 and getattr(self, 'epoch_losses', None):
            plt.figure(figsize=(10, 5))
            plt.plot(range(1, len(self.epoch_losses) + 1), self.epoch_losses,
                     marker='o', label='Training Loss')
            plt.title('Training Loss per Epoch')
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.legend()
            plt.grid()
            plt.savefig(osp.join(self.work_dir, 'training_loss.png'))  # save the figure
            plt.close()  # release the figure; plt.show() would block or do nothing on a server
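One detail to watch in the train method: if best_loss is a local variable inside train, it is reset to infinity at the start of every epoch, so every epoch's average counts as a new best and best_model.pth is rewritten each time. Keeping the best value on an object that outlives the epoch avoids this. The same bookkeeping can be written as a small standalone class that is easy to test without mmcv. This is a minimal sketch; the class name EpochLossTracker and its methods are hypothetical and not part of mmcv or CO-DETR.

```python
import math
from typing import List


class EpochLossTracker:
    """Hypothetical helper: average per-iteration losses per epoch and
    remember the best (lowest) epoch average across all epochs."""

    def __init__(self) -> None:
        self.epoch_losses: List[float] = []  # one average per finished epoch, for plotting
        self.best_loss = math.inf            # kept on the object, so it persists across epochs
        self._loss_sum = 0.0
        self._num_iters = 0

    def add_iter_loss(self, loss: float) -> None:
        """Record the loss of one training iteration."""
        self._loss_sum += loss
        self._num_iters += 1

    def end_epoch(self) -> bool:
        """Close the current epoch; return True if its average loss is a new best."""
        if self._num_iters == 0:
            raise ValueError("no iteration losses were recorded in this epoch")
        avg_loss = self._loss_sum / self._num_iters
        self.epoch_losses.append(avg_loss)
        self._loss_sum, self._num_iters = 0.0, 0  # reset for the next epoch
        if avg_loss < self.best_loss:
            self.best_loss = avg_loss
            return True
        return False


if __name__ == "__main__":
    tracker = EpochLossTracker()
    for iter_losses in ([1.0, 0.5], [1.0, 0.5], [0.5, 0.25]):
        for loss in iter_losses:
            tracker.add_iter_loss(loss)
        print(tracker.end_epoch(), tracker.epoch_losses[-1])
```

In the runner, add_iter_loss would be called with self.outputs['log_vars']['loss'] after each iteration, and best_model.pth saved whenever end_epoch() returns True; the second epoch above ties the best average, so it does not trigger a save. Also keep in mind that the epoch with the lowest training loss is not necessarily the one with the highest validation AP.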

http://www.mrgr.cn/news/39186.html
