Tracking the Loss During CO-DETR Training
Table of Contents
- Preface
- Locating the loss function
Preface
Following up on my earlier posts documenting the CO-DETR training and inference workflow and converting my own dataset to COCO format for CO-DETR training: the AP after training is not high, so I need to locate where the loss is computed and logged, save the best .pth checkpoint, and plot loss against epoch to see whether the loss is decreasing steadily or fluctuating sharply.
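Before patching any code, a quick look at the loss curve is already possible from the JSON log that mmcv's TextLoggerHook writes into the work directory by default. A minimal sketch, assuming the default JSON logger is enabled (the log file name below is an assumption; substitute your own run's `.log.json`):

```python
# Minimal sketch: plot the logged training loss from the mmcv JSON log.
# The log path is an assumption; use your own work_dir/*.log.json.
import json
import matplotlib.pyplot as plt

losses = []
with open('work_dirs/co_dino/20240101_000000.log.json') as f:
    for line in f:
        rec = json.loads(line)
        # the first line holds env/config metadata and has no 'mode' key
        if rec.get('mode') == 'train' and 'loss' in rec:
            losses.append(rec['loss'])

plt.plot(losses)
plt.xlabel('logged training iteration')
plt.ylabel('loss')
plt.savefig('loss_curve_from_log.png')
```

This only shows per-iteration values at the logging interval; the runner modification below records a clean per-epoch average and also saves the best checkpoint.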
Locating the loss function
In train.py, step into train_detector, then into the run called at runner.run(data_loaders, cfg.workflow); this lands in the class EpochBasedRunner(BaseRunner) (mmcv/runner/epoch_based_runner.py). Add the loss-recording code inside its train method and the plotting code at the end of its run method. Note that import matplotlib.pyplot as plt must be added at the top of that file; the tracking state (self.epoch_losses, self.best_loss) is initialized lazily in train. The full modified code:
```python
# Methods of EpochBasedRunner in mmcv/runner/epoch_based_runner.py.
# Add at the top of the file:
# import matplotlib.pyplot as plt

def train(self, data_loader, **kwargs):
    # Persist tracking state across epochs. (A local best_loss reset at
    # the start of every epoch would save a "best" checkpoint each time.)
    if not hasattr(self, 'epoch_losses'):
        self.epoch_losses = []
        self.best_loss = float('inf')
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    epoch_loss = 0.0  # accumulated loss for this epoch
    for i, data_batch in enumerate(self.data_loader):
        self.data_batch = data_batch
        self._inner_iter = i
        self.call_hook('before_train_iter')
        self.run_iter(data_batch, train_mode=True, **kwargs)
        self.call_hook('after_train_iter')
        del self.data_batch
        self._iter += 1
        # the per-iteration loss is recorded in outputs['log_vars']
        if 'loss' in self.outputs['log_vars']:
            epoch_loss += self.outputs['log_vars']['loss']
    # record the average loss of this epoch
    avg_loss = epoch_loss / len(self.data_loader)
    self.epoch_losses.append(avg_loss)
    if avg_loss < self.best_loss:
        self.best_loss = avg_loss
        self.save_checkpoint(self.work_dir, 'best_model.pth')
    self.call_hook('after_train_epoch')
    self._epoch += 1

@torch.no_grad()
def val(self, data_loader, **kwargs):
    self.model.eval()
    self.mode = 'val'
    self.data_loader = data_loader
    self.call_hook('before_val_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(self.data_loader):
        self.data_batch = data_batch
        self._inner_iter = i
        self.call_hook('before_val_iter')
        self.run_iter(data_batch, train_mode=False)
        self.call_hook('after_val_iter')
        del self.data_batch
    self.call_hook('after_val_epoch')

def run(self,
        data_loaders: List[DataLoader],
        workflow: List[Tuple[str, int]],
        max_epochs: Optional[int] = None,
        **kwargs) -> None:
    """Start running.

    Args:
        data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
            and validation.
        workflow (list[tuple]): A list of (phase, epochs) to specify the
            running order and epochs. E.g, [('train', 2), ('val', 1)] means
            running 2 epochs for training and 1 epoch for validation,
            iteratively.
    """
    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)
    assert len(data_loaders) == len(workflow)
    if max_epochs is not None:
        warnings.warn(
            'setting max_epochs in run is deprecated, '
            'please set max_epochs in runner_config', DeprecationWarning)
        self._max_epochs = max_epochs

    assert self._max_epochs is not None, (
        'max_epochs must be specified during instantiation')

    for i, flow in enumerate(workflow):
        mode, epochs = flow
        if mode == 'train':
            self._max_iters = self._max_epochs * len(data_loaders[i])
            break

    work_dir = self.work_dir if self.work_dir is not None else 'NONE'
    self.logger.info('Start running, host: %s, work_dir: %s',
                     get_host_info(), work_dir)
    self.logger.info('Hooks will be executed in the following order:\n%s',
                     self.get_hook_info())
    self.logger.info('workflow: %s, max: %d epochs', workflow,
                     self._max_epochs)
    self.call_hook('before_run')

    while self.epoch < self._max_epochs:
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if isinstance(mode, str):  # self.train()
                if not hasattr(self, mode):
                    raise ValueError(
                        f'runner has no method named "{mode}" to run an '
                        'epoch')
                epoch_runner = getattr(self, mode)
            else:
                raise TypeError(
                    'mode in workflow must be a str, but got {}'.format(
                        type(mode)))

            for _ in range(epochs):
                if mode == 'train' and self.epoch >= self._max_epochs:
                    break
                epoch_runner(data_loaders[i], **kwargs)

    time.sleep(1)  # wait for some hooks like loggers to finish
    self.call_hook('after_run')

    # plot the per-epoch training loss curve
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, len(self.epoch_losses) + 1), self.epoch_losses,
             marker='o', label='Training Loss')
    plt.title('Training Loss per Epoch')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid()
    plt.savefig(osp.join(self.work_dir, 'training_loss.png'))  # save figure
    plt.show()  # drop this (or use the Agg backend) on a headless server
```
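After training finishes, run() writes training_loss.png into the work directory, and best_model.pth holds the checkpoint with the lowest average epoch loss; a curve that decreases steadily with only small oscillations is what we want to see. A minimal sketch of loading that checkpoint for inference with mmdet's high-level API (the config path and test image are assumptions; substitute your own CO-DETR config):

```python
# Minimal sketch: run inference with the saved best checkpoint.
# Config and image paths are hypothetical; use your own.
from mmdet.apis import init_detector, inference_detector

config_file = 'projects/configs/co_dino/my_co_detr_config.py'  # hypothetical
checkpoint_file = 'work_dirs/co_dino/best_model.pth'
model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, 'demo/demo.jpg')
```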