当前位置: 首页 > news >正文

mask和class_conf_mask的作用

代码 

import torch
import torch.nn.functional as F
import numpy as np# 模拟logits数据(随机生成)
batch_size = 4  # 假设有4个样本
class_num = 5  # 假设有5个类别# 模拟教师模型输出的logits,维度为 [batch_size, class_num]
logits_teacher_weak = torch.randn(batch_size, class_num)# 打印logits_teacher_weak的值
print("logits_teacher_weak:\n", logits_teacher_weak)# 将logits通过softmax转换为概率分布
pred_teacher_weak = F.softmax(logits_teacher_weak.detach(), dim=1)# 打印转换后的概率分布
print("\npred_teacher_weak (softmax output):\n", pred_teacher_weak)# 计算每个样本的最大概率值(即置信度)和伪标签
confidence, pseudo_labels = pred_teacher_weak.max(dim=1)# 打印每个样本的最大概率值(置信度)和伪标签
print("\nconfidence (max probabilities):\n", confidence)
print("\npseudo_labels (pseudo labels based on max probabilities):\n", pseudo_labels)# 分离置信度,确保它不参与梯度计算
confidence = confidence.detach()# 计算50%分位数(即中位数),作为阈值
conf_thresh = np.percentile(confidence.cpu().numpy().flatten(), 50
)# 打印置信度阈值
print("\nconf_thresh (confidence threshold, 50th percentile):\n", conf_thresh)# 生成mask,标记置信度小于等于阈值的样本
mask = confidence.le(conf_thresh).bool()# 打印mask
print("\nmask (samples with confidence <= threshold):\n", mask)# ============== 类别置信度部分 ==============
# 计算每个类别的置信度(所有样本在该类别上的预测概率之和)
class_confidence = torch.sum(pred_teacher_weak, dim=0)# 打印类别置信度
print("\nclass_confidence (sum of probabilities across samples for each class):\n", class_confidence)# 分离类别置信度,确保它不参与梯度计算
class_confidence = class_confidence.detach()# 计算类别置信度的50%分位数作为阈值
class_confidence_thresh = np.percentile(class_confidence.cpu().numpy().flatten(), 50
)# 打印类别置信度阈值
print("\nclass_confidence_thresh (class confidence threshold, 50th percentile):\n", class_confidence_thresh)# 生成class_conf_mask,标记类别置信度小于等于阈值的类别
class_conf_mask = class_confidence.le(class_confidence_thresh).bool()# 打印类别掩码
print("\nclass_conf_mask (classes with confidence <= threshold):\n", class_conf_mask)

逐行解释1

  • logits_teacher_weak = torch.randn(batch_size, class_num)

    • 随机生成一组教师模型的logits输出,模拟批次中每个样本对各个类别的预测。
    • 输出:一个形状为 [batch_size, class_num] 的tensor,表示教师模型在弱增强数据上的logits。
  • pred_teacher_weak = F.softmax(logits_teacher_weak.detach(), dim=1)

    • 使用 softmax 将logits转换为概率分布,dim=1 表示在类别维度上进行softmax操作。
    • 输出:每个样本在每个类别上的预测概率。
  • confidence, pseudo_labels = pred_teacher_weak.max(dim=1)

    • 找到每个样本预测概率的最大值(即置信度)以及对应的类别(伪标签)。
    • 输出
      • confidence:每个样本的最大概率(置信度)。
      • pseudo_labels:每个样本的伪标签(概率最大的类别)。
  • confidence = confidence.detach()

    • confidence 与计算图分离,确保它不参与梯度计算(为了防止反向传播时对置信度进行梯度更新)。
    • 输出:与之前相同的置信度,但不参与梯度计算。
  • conf_thresh = np.percentile(confidence.cpu().numpy().flatten(), 50)

    • 计算置信度的50%分位数(即中位数),作为后续生成掩码的阈值。
    • 输出:表示所有样本置信度的中位数。
  • mask = confidence.le(conf_thresh).bool()

    • 生成一个布尔掩码,标记置信度小于等于阈值的样本。le 表示小于等于操作。
    • 输出:一个布尔向量,True 表示该样本的置信度低于或等于阈值。

输出示例1

假设我们随机生成的 logits_teacher_weak 如下:

logits_teacher_weak:tensor([[ 1.5074, -0.3623, -0.3050,  0.7985,  1.3345],[-0.2311, -0.7265,  0.3124, -0.3955, -0.5713],[-0.0377,  0.0672, -0.6561, -0.3366, -1.2700],[ 0.3643, -0.2974,  0.4942, -1.3482,  0.1723]])

经过softmax转换后,得到概率分布:

pred_teacher_weak (softmax output):tensor([[0.4005, 0.0641, 0.0678, 0.1913, 0.2763],[0.2145, 0.1300, 0.3662, 0.1841, 0.1051],[0.2463, 0.2726, 0.1460, 0.1970, 0.1381],[0.2883, 0.1476, 0.3273, 0.0457, 0.1910]])

置信度和伪标签:

confidence (max probabilities):tensor([0.4005, 0.3662, 0.2726, 0.3273])pseudo_labels (pseudo labels based on max probabilities):tensor([0, 2, 1, 2])

置信度的阈值(50%分位数):

conf_thresh (confidence threshold, 50th percentile):0.3368

生成的mask:

mask (samples with confidence <= threshold):tensor([False, False,  True,  True])

逐行解释2

  1. class_confidence = torch.sum(pred_teacher_weak, dim=0)

    • 计算每个类别的总置信度,即将所有样本在该类别上的预测概率进行求和。结果是一个长度为 class_num 的向量,表示每个类别的总置信度。
    • 输出:每个类别的置信度总和。
  2. class_confidence = class_confidence.detach()

    • 将类别置信度与计算图分离,确保它不参与梯度计算。
  3. class_confidence_thresh = np.percentile(class_confidence.cpu().numpy().flatten(), 50)

    • 计算类别置信度的50%分位数,作为后续生成 class_conf_mask 的阈值。这个阈值用于区分置信度高和低的类别。
  4. class_conf_mask = class_confidence.le(class_confidence_thresh).bool()

    • 生成类别掩码 class_conf_mask,标记置信度低于或等于阈值的类别。le 表示小于等于操作。

输出示例2

假设我们随机生成的 logits_teacher_weak 如下:

logits_teacher_weak:tensor([[ 0.6654, -0.3170, -0.3315,  0.5557, -0.0610],[ 0.1992, -0.4481, -0.5696,  0.3045, -0.5566],[-0.2331,  1.5880,  1.1310,  1.1659,  0.0431],[-0.6817, -0.4727, -0.1713, -0.3666,  0.1745]])

经过 softmax 转换后:

pred_teacher_weak (softmax output):tensor([[0.3366, 0.1341, 0.1322, 0.3018, 0.0953],[0.2830, 0.1609, 0.1422, 0.3155, 0.0985],[0.0408, 0.3931, 0.2492, 0.2565, 0.0604],[0.1835, 0.2246, 0.3039, 0.2501, 0.0379]])

类别置信度和阈值如下:

class_confidence (sum of probabilities across samples for each class):tensor([0.8439, 0.9127, 0.8275, 1.1239, 0.2919])class_confidence_thresh (class confidence threshold, 50th percentile):0.8439

生成的 class_conf_mask

class_conf_mask (classes with confidence <= threshold):tensor([ True, False,  True, False,  True])

这个 class_conf_mask 表示第1、第3和第5类的置信度小于等于阈值,标记为 True


http://www.mrgr.cn/news/19644.html

相关文章:

  • 加密技术.
  • PHP7 的内核结构
  • 鸿蒙(API 12 Beta6版)图形【使用Drawing实现图形绘制与显示 (C/C++)】方舟2D图形服务
  • 开发经销商有哪些渠道和方法?不得不看的思路!
  • 智能提醒助理系列-AIGC模型如何对接公众号2-扣子
  • 前端项目开发之prettier安装和使用
  • 移动端视频编辑SDK解决方案,AI语音识别添加字幕
  • 微信小程序跳转到另一个微信小程序
  • 冒泡排序及qsort函数
  • React学习day05-Redux-概念、作用、安装、使用、action传参
  • 二叉搜索树【C++】
  • Leetcode Day21组合总和
  • 鸿蒙正则校验无效 - Harmony
  • 如何使用 ef core 的 code first(fluent api)模式实现自定义类型转换器?
  • 开源网安引领AIGC+开发安全,智能防护铸就软件安全新高度
  • CCSI: 用于无数据类别增量学习的持续类别特定印象|文献速递--基于深度学习的医学影像病灶分割
  • VS按F11不进函数调试
  • 在线Ascii码对照表,Ascii转换对照表
  • gradle和maven相比有什么相同点和区别?
  • PIM