当前位置: 首页 > news >正文

scene graph generation 计算mean recall数据的过程:

这里写目录标题

  • 前言:
  • 计算mean recall的详细过程
    • 1. **准备数据**:
    • 2. **计算每个类别的recall**:
    • 具体代码片段
      • 准备groundtruth数据
      • 准备预测数据
      • 计算recall
      • 计算mean recall

前言:

计算流程这里参考maskrcnn_benchmark/data/datasets/evaluation/vg/sgg_eval.py这个scene graph generation benchmark的github官网来完成相关的任务。

计算mean recall的详细过程

以下是如何利用预测三元组和groundtruth三元组计算mean recall的详细过程:

1. 准备数据

使用如下两个变量来分别保存groundtruth三元组和predicate的三元组

  • prepare_gt方法会处理groundtruth三元组数据,并将其存储在一个字典中。
  • prepare_pred方法会处理预测三元组数据,并将其存储在一个字典中。

2. 计算每个类别的recall

  • calculate_recall方法会遍历所有的groundtruth和predicate数据计算每个关系类别的recall
  • 对于每个类别,计算公式为:
    在这里插入图片描述
    其中,TP是True Positives,FN是False Negatives。(这句的意思看这句代码就理解了,即:float(len(match)) / float(gt_rels.shape[0])也就是说,(正确匹配的三元组)/所有groundtruth三元组
  1. 计算mean recall
    • calculate_mean_recall方法会计算所有类别的平均recall。

    • 首先,计算每个类别的recall

    • 然后,计算所有类别recall的平均值
      在这里插入图片描述

      其中,Recall_i是第i个关系类别的recall,N是类别的总数。

具体代码片段

以下是一些关键代码片段的解释:(这些片段是从github文件中专门拿出来的)

准备groundtruth数据

def prepare_gt(self):for gt in self.gts:gt_entry = {}gt_entry['relations'] = gt['relations']gt_entry['boxes'] = gt['boxes']gt_entry['labels'] = gt['labels']self.gt_entries.append(gt_entry)

准备预测数据

def prepare_pred(self):for pred in self.preds:pred_entry = {}pred_entry['relations'] = pred['relations']pred_entry['boxes'] = pred['boxes']pred_entry['labels'] = pred['labels']self.pred_entries.append(pred_entry)

计算recall

注意这里的TP和FN,就是上面公式了的TP和FN

def calculate_recall(self):for i, gt_entry in enumerate(self.gt_entries):pred_entry = self.pred_entries[i]for rel in gt_entry['relations']:gt_rel = (rel[0], rel[1], rel[2])if gt_rel in pred_entry['relations']:self.tp[rel[2]] += 1else:self.fn[rel[2]] += 1

计算mean recall

def calculate_mean_recall(self):recalls = []for i in range(self.num_rel_classes):if self.tp[i] + self.fn[i] > 0:recalls.append(self.tp[i] / (self.tp[i] + self.fn[i]))mean_recall = sum(recalls) / len(recalls)return mean_recall

通过这些方法,你可以计算出每个类别的recall,并进一步计算出mean recall。希望这些解释对你理解这段代码有所帮助!如果有更多问题,请随时问我。


http://www.mrgr.cn/news/24855.html

相关文章:

  • Java重修笔记 第五十四天 坦克大战(三)事件处理机制
  • 手写Promise
  • 贪心算法day29|134. 加油站(理解有难度)、135. 分发糖果、860. 柠檬水找零、406. 根据身高重建队列
  • 基于 PyTorch 和 TensorFlow 的口罩检测与人脸识别系统
  • 在 PyTorch 中,除了 pad_sequence 还有哪些其他处理序列数据的函数?时间序列数据 预处理
  • 什么是 PD 电压诱骗?
  • R语言统计分析——功效分析2(t检验,ANOVA)
  • 【 html+css 绚丽Loading 】000047 玄武流转盘
  • [综述笔记]Federated learning for medical image analysis: A survey
  • 二分思想与相关例题(上)
  • 可解释性人工智能(eXplainable Artificial Intelligence,XAI)
  • 无敌C++大王养成篇一
  • FreeRTOS学习(2)延时函数的封装
  • 初识Linux · 进程(2)
  • 利士策分享,如何制定合理的工作时长:寻找生活与工作的平衡点
  • 【C#生态园】提升C#开发效率:掌握这六款单元测试利器
  • 【OJ】关于顺序表的经典题目(移除数组中指定元素的值、数组去重、合并两个有序的数组)
  • 基于SpringBoot+Vue+MySQL的考研互助交流平台
  • 力扣sql五十题——连接
  • Codeforces Round 971 (Div. 4)——C题题解