当前位置: 首页 > news >正文

用SAM2和Cutie模型目标追踪

一、数据集

视频:每个视频文件夹以图片帧的形式存储

box:给出每个视频第一帧要追踪的物体的box

二、将数据格式转换成SAM2所需要的格式

主要是将box转换成mask的格式,下面这个代码就是将box转换成mask的代码,具体转换原理如下:

box作为prompt输入进入SAM2,SAM2生成第一帧的mask,在mask上选五个positive point作为新的prompt,结合bbox作为新的prompt再过一遍SAM2 tracking

import os
import json
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import numpy as np
from PIL import Image# 初始化SAM2预测器
checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))# 图像和bbox目录路径
image_base_path = "/data/WebDAV/VidOR/video_frames_extract"  #图像帧路径
bbox_base_path = "/data/WebDAV/VidOR/all_training_bbox"  #bbox的 路径
output_mask_path = "/data/WebDAV/VidOR/mask_final" #
output_mask_path_initial = "/data/WebDAV/VidOR/mask_initial" #box作为prompt输入SAM2的结果# 创建保存mask的文件夹
os.makedirs(output_mask_path, exist_ok=True)
os.makedirs(output_mask_path_initial, exist_ok=True)# # 从bbox区域采样点(这里随机采样一些点作为示例)
# def sample_points_from_bbox(bbox, num_points=5):
#     xmin, ymin, xmax, ymax = bbox
#     x_coords = np.random.randint(xmin, xmax, size=num_points)
#     y_coords = np.random.randint(ymin, ymax, size=num_points)
#     return np.column_stack((x_coords, y_coords))# 从生成的掩码中采样正点
def sample_points_from_mask(mask, num_points=5):# 获取mask中所有正点的位置 (非零点)y_coords, x_coords = np.where(mask > 0)# 如果正点数量不足,取所有正点if len(x_coords) < num_points:sampled_indices = np.arange(len(x_coords))else:sampled_indices = np.random.choice(len(x_coords), num_points, replace=False)sampled_points = np.column_stack((x_coords[sampled_indices], y_coords[sampled_indices]))return sampled_points# 遍历video_frames_extract目录中的所有子文件夹
for video_segment in os.listdir(image_base_path):# 提取video_id和帧信息video_id, start_frame, end_frame = video_segment.split('_')if int(start_frame) == 0:start_frame = str(int(start_frame) + 1)# 获取视频段的第一帧路径first_frame_path = os.path.join(image_base_path, video_segment, "{:04d}.png".format(int(start_frame)))if not os.path.exists(first_frame_path):print(f"帧 {first_frame_path} 不存在,跳过该视频段.")continue# 打开第一帧图像image = Image.open(first_frame_path).convert("RGB")image_array = np.array(image)# 找到对应的bbox json文件json_file_path = os.path.join(bbox_base_path, f"{video_id}.json")if not os.path.exists(json_file_path):print(f"对应的bbox文件 {json_file_path} 不存在,跳过该视频段.")continue# 读取json文件获取bboxwith open(json_file_path, "r") as f:data = json.load(f)# 查找对应开始帧的bbox(假设trajectories中的索引与帧一一对应)frame_index = int(start_frame) - 1  # 帧索引从0开始,所以减去1if frame_index >= len(data['trajectories']):print(f"帧 {start_frame} 超出 {video_id} 的轨迹范围.")continueframe_bboxes = data['trajectories'][frame_index]# 对每个object生成maskwith torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):predictor.set_image(image_array)for obj in frame_bboxes:tid = obj['tid']bbox = obj['bbox']# Step 1: 使用bbox生成初步的SAM2 maskbbox_array = np.array([bbox['xmin'], bbox['ymin'], bbox['xmax'], bbox['ymax']])masks, _, _ = predictor.predict(box=bbox_array, multimask_output=False)# # 转换bbox为np数组# bbox_array = np.array([bbox['xmin'], bbox['ymin'], bbox['xmax'], bbox['ymax']])# 将mask转换为uint8类型initial_mask = (masks[0] * 255).astype(np.uint8)# 保存初步的maskinitial_mask_pil = Image.fromarray(initial_mask)initial_mask_filename = os.path.join(output_mask_path_initial, f"{video_id}_{start_frame}_{end_frame}_{tid}.png")initial_mask_pil.save(initial_mask_filename)# print(f"保存初步的mask: {initial_mask_filename}")# Step 2: 从生成的初步mask中采样正点sampled_points = sample_points_from_mask(initial_mask)point_labels = np.ones(len(sampled_points))  # 所有采样点都是正点# Step 3: 使用采样点、mask和bbox作为prompt,再次输入SAM2进行掩码生成final_masks, _, _ = predictor.predict(# mask_input=masks,box=bbox_array,point_coords=sampled_points,point_labels=point_labels,multimask_output=False)# # 从bbox区域采样正点# point_coords = sample_points_from_bbox(bbox_array, num_points=5)# point_labels = np.ones(len(point_coords))  # 所有点都标记为正点(标签为1)# # 生成mask# masks, _, _ = predictor.predict(box=bbox_array, multimask_output=False)# 将最终生成的mask保存final_mask = (final_masks[0] * 255).astype(np.uint8)final_mask_pil = Image.fromarray(final_mask)final_mask_filename = os.path.join(output_mask_path, f"{video_id}_{start_frame}_{end_frame}_{tid}.png")final_mask_pil.save(final_mask_filename)# print(f"保存最终的mask: {final_mask_filename}")

三、接下来将视频帧和mask输入进入SAM2进行推理就可以了

四、Cutie

在推理的时候主要有三个参数影响推理结果

1、size图像的大小,我的所有图像都是1080*1920大小的,但是我将其设置成800,再往后效果都是一样的,而且推理占的显存还是挺大的

2、max_mem_frames:30

3、min_mem_frames:28

2和3个参数是我在保证size一定的时候最大的值了,我用的显卡是80G的

还有cutie应该是有三个权重,但是mega那个权重是效果最好的

五、总结

总的来说,sam2的large权重是要比cutie效果好很多的。当然以上只是对我的数据集来说,具体效果还要根据实际情况来定


http://www.mrgr.cn/news/49424.html

相关文章:

  • 程序地址空间 -- 详解
  • Download Vmware Fusion (free for person)
  • Dev-Cpp 5.11 安装教程【保姆级】
  • vue3之生命周期钩子
  • 2024-10-13 NO.1 Quest3 激活教程
  • 千万级的大表,是如何产生的?
  • 【数据脱敏方案】不使用 AOP + 注解,使用 SpringBoot+YAML 实现
  • Mysql—高可用集群故障迁移
  • 【python入门到精通专题】8.装饰器
  • DVWA CSRF 漏洞实践报告
  • HL7协议简介及其在STM32上的解析实现
  • QT开发--QT SQL模块
  • Maven项目打包为jar的几种方式
  • 代码随想录 | Day32 | 回溯算法:排列问题
  • 电子电气架构---软件定义汽车,产业变革
  • iOS_图片加载优化
  • 考研C语言程序设计_语法相关(持续更新)
  • 2024系统架构师---试题二论软件维护方法及其应用
  • HTML(七)表格
  • AVL树实现