当前位置: 首页 > news >正文

AOT源码解析4.4 -decoder生成预测mask并计算loss

3、生成ref_imgs的预测mask和loss

这一步在训练阶段调用

3.1 数据处理

在这里插入图片描述

图1,如图1所示,将enc_embs的最后一个比例的特征图和有ref_imgs相关的特征图得到的LSTT特征图相拼接作为输入

        curr_enc_embs = self.curr_enc_embscurr_lstt_embs = self.curr_lstt_output[0]pred_id_logits = self.AOT.decode_id_logits(curr_lstt_embs,curr_enc_embs)

3.2 Decoder结构

在这里插入图片描述

图2, decoder的操作步骤如图,该解码器将enc_embs各个比例的特征图结合到一起

  • Decoder结构
class FPNSegmentationHead(nn.Module):def __init__(self,in_dim,out_dim,decode_intermediate_input=True,hidden_dim=256,shortcut_dims=[24, 32, 96, 1280],align_corners=True):super().__init__()self.align_corners = align_cornersself.decode_intermediate_input = decode_intermediate_inputself.conv_in = ConvGN(in_dim, hidden_dim, 1)self.conv_16x = ConvGN(hidden_dim, hidden_dim, 3)self.conv_8x = ConvGN(hidden_dim, hidden_dim // 2, 3)self.conv_4x = ConvGN(hidden_dim // 2, hidden_dim // 2, 3)self.adapter_16x = nn.Conv2d(shortcut_dims[-2], hidden_dim, 1)self.adapter_8x = nn.Conv2d(shortcut_dims[-3], hidden_dim, 1)self.adapter_4x = nn.Conv2d(shortcut_dims[-4], hidden_dim // 2, 1)self.conv_out = nn.Conv2d(hidden_dim // 2, out_dim, 1)self._init_weight()def forward(self, inputs, shortcuts):if self.decode_intermediate_input:x = torch.cat(inputs, dim=1)else:x = inputs[-1]x = F.relu_(self.conv_in(x))s1 = self.adapter_16x(shortcuts[-2])x = F.relu_(self.conv_16x(self.adapter_16x(shortcuts[-2]) + x))x = F.interpolate(x,size=shortcuts[-3].size()[-2:],mode="bilinear",align_corners=self.align_corners)x = F.relu_(self.conv_8x(self.adapter_8x(shortcuts[-3]) + x))x = F.interpolate(x,size=shortcuts[-4].size()[-2:],mode="bilinear",align_corners=self.align_corners)x = F.relu_(self.conv_4x(self.adapter_4x(shortcuts[-4]) + x))x = self.conv_out(x)return x

3.3 计算loss

在这里插入图片描述

  • 对Decoder输出的结果按照对象数量进行分隔
        pred_id_logits = self.pred_id_logitspred_id_logits = F.interpolate(pred_id_logits,size=gt_mask.size()[-2:],mode="bilinear",align_corners=self.align_corners)label_list = []logit_list = []for batch_idx, obj_num in enumerate(self.obj_nums):now_label = gt_mask[batch_idx].long()now_logit = pred_id_logits[batch_idx, :(obj_num + 1)].unsqueeze(0)label_list.append(now_label.long())logit_list.append(now_logit)
  • 计算loss

在深度学习中,尤其是在图像相关的任务(如图像分割)中,我们通常有大量的像素需要预测。在这种情况下,可能并不是所有的像素对最终的任务都同样重要。
例如,模型可能已经能够很好地预测图像的大部分区域,但是对于一些难以区分的区域(如物体边缘或小物体)预测得不够好。这些难以预测的区域可能正是模型需要关注的重点。

为了使模型更加关注这些难以预测的区域,可以采用一种称为“硬例挖掘”(hard example mining)的技术。这种方法的基本思想是,不是对所有的像素平均地计算损失,而是只关注那些损失最大的像素。

通过这种方式,模型的训练可以更加集中在那些难以正确预测的像素上,从而提高模型的整体性能。具体来说,“top k percent pixels” 指的是按照损失值从高到低排序后,选取前 k 百分比的像素。例如,如果 k 设置为 50%,那么在损失计算中,只会考虑损失最大的前 50% 的像素。

在代码中,这通常是通过以下步骤实现的:

  • 计算所有像素的损失。
  • 根据损失值对像素进行排序。
  • 选择损失值最高的前 k 百分比的像素。
  • 只计算这些选定像素的损失,并将它们加起来作为最终的损失。
class CrossEntropyLoss(nn.Module):def __init__(self,top_k_percent_pixels=None,hard_example_mining_step=100000):super(CrossEntropyLoss, self).__init__()self.top_k_percent_pixels = top_k_percent_pixelsif top_k_percent_pixels is not None:assert (top_k_percent_pixels > 0 and top_k_percent_pixels < 1)self.hard_example_mining_step = hard_example_mining_step + 1e-5if self.top_k_percent_pixels is None:self.celoss = nn.CrossEntropyLoss(ignore_index=255,reduction='mean')else:self.celoss = nn.CrossEntropyLoss(ignore_index=255,reduction='none')def forward(self, dic_tmp, y, step):total_loss = []for i in range(len(dic_tmp)):pred_logits = dic_tmp[i]gts = y[i]if self.top_k_percent_pixels is None:final_loss = self.celoss(pred_logits, gts)else:# Only compute the loss for top k percent pixels.# First, compute the loss for all pixels. Note we do not put the loss# to loss_collection and set reduction = None to keep the shape.num_pixels = float(pred_logits.size(2) * pred_logits.size(3))pred_logits = pred_logits.view(-1, pred_logits.size(1),pred_logits.size(2) * pred_logits.size(3))gts = gts.view(-1, gts.size(1) * gts.size(2))pixel_losses = self.celoss(pred_logits, gts)if self.hard_example_mining_step == 0:top_k_pixels = int(self.top_k_percent_pixels * num_pixels)else:ratio = min(1.0,step / float(self.hard_example_mining_step))top_k_pixels = int((ratio * self.top_k_percent_pixels +(1.0 - ratio)) * num_pixels)top_k_loss, top_k_indices = torch.topk(pixel_losses,k=top_k_pixels,dim=1)final_loss = torch.mean(top_k_loss)final_loss = final_loss.unsqueeze(0)total_loss.append(final_loss)total_loss = torch.cat(total_loss, dim=0)return total_loss

http://www.mrgr.cn/news/36728.html

相关文章:

  • 《Linux从小白到高手》开篇:脱胎换骨之为什么要深度学习Linux?
  • 一七零、GORM值为0或者空字符串的时候不能被更新创建的五种解决办法
  • 【JavaEE初阶】深入解析死锁的产生和避免以及内存不可见问题
  • electron使用npm install出现下载失败的问题
  • bat脚本的命名方式导致一个脚本不能使用的不明原因的怪事
  • 18 vue3之自动引入ref插件深入使用v-model
  • OceanBase云数据库战略实施两年,受零售、支付、制造行业青睐
  • Vue3 获取验证码按钮,倒计时60s
  • 「点击即复制!」——超实用 JavaScript 实现技巧
  • OpenHarmony(鸿蒙南向)——平台驱动指南【HDMI】
  • TopOn对话游戏魔客:2024移动游戏广告应如何突破?
  • 【GIS】Leaflet:Web地图快速上手
  • 外包干了1个多月,技术明显退步了。。。。。
  • 我在 Thoughtworks 被裁前后的经历
  • 用户体验分享 | YashanDB V23.2.3安装部署
  • 知识图谱检索 Graph-Based Retriever:文本块到结构化数据的转换,解决语义检索捕获不了的长尾关系
  • 283. 移动0
  • PHP之 实现https ssl证书到期提醒,通过企微发送消息
  • 江协科技STM32学习- P19 TIM编码器接口
  • 22.4k star,好用、强大的链路监控软件,skywalking