当前位置: 首页 > news >正文

使用预训练的 ONNX 格式的 YOLOv8n 模型进行目标检测,并在图像上绘制检测结果

目录

__init__方法:

pre_process方法:

run方法:

filter_boxes方法:

view_img方法:


​​​​​​​__init__方法:

    • 初始化类的实例时,创建一个onnxruntime的推理会话,加载名为yolov8n.onnx的模型,并指定使用 CPU 进行推理。
  1. pre_process方法:

    • 接受一个图像路径作为参数。
    • 读取图像并将其从 BGR 颜色空间转换为 RGB 颜色空间。
    • 计算图像的最大边长,创建一个全零的新图像,大小为最大边长的正方形,将原始图像复制到新图像中。
    • 将新图像调整为640x640的大小并归一化,然后增加一个维度并交换维度,以满足模型输入的要求。
    • 计算图像的缩放比例并返回预处理后的图像和缩放比例。
 def pre_process(self,img_path):img=cv2.imread(img_path)img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)max_edge=max(img.shape)h,w,c=img.shapeimg_back=np.zeros((max_edge,max_edge,3),dtype=np.float32)img_back[:h,:w]=imgimg_scale=cv2.resize(img_back,(640,640))/255img_scale=np.expand_dims(img_scale,axis=0)#升维度(1,640,640,3)img_scale=img_scale.transpose(0,3,1,2)#交换维度scale=max_edge/640return img_scale,scale
  1. run方法:

    • 接受一个图像路径作为参数。
    • 调用pre_process方法对图像进行预处理,得到预处理后的图像和缩放比例。
    • 使用预处理后的图像进行模型推理,得到输出结果。
    • 将输出结果传递给filter_boxes方法进行进一步处理。
 def run(self,img_path):img_process,scale=self.pre_process(img_path)input_name=self.session._inputs_meta[0].namesession_out=self.session.run(None,{input_name:img_process})[0][0]#(84,8400)session_out=session_out.transpose(1,0)#8400,84self.filter_boxes(session_out,scale)
  1. filter_boxes方法:

    • 接受模型输出结果和缩放比例作为参数。
    • 遍历模型输出的每一行,提取边界框信息(中心坐标、宽、高)和类别信息。
    • 根据边界框信息计算边界框的四个顶点坐标,并找到最大置信度的类别索引和置信度值。
    • 如果置信度大于 0.6,则将边界框信息、类别索引和置信度值分别添加到对应的列表中。
    • 调用view_img方法显示图像和检测结果。
    def filter_boxes(self,session_out,scale):#cx,cy,w,h,cls(80)boxes=[]confs=[]classes=[]rows=session_out.shape[0]for row in range(rows):infos = session_out[row]cx,cy,w,h=infos[:4]x1=(cx-w//2)*scaley1=(cy-h//2)*scalex2=(cx+w//2)*scaley2=(cy+h//2)*scalecls=infos[4:]idx=np.argmax(cls)conf=cls[idx]if conf>0.6:confs.append(conf)boxes.append((x1,y1,x2,y2))classes.append(idx)self.view_img(img_path,boxes,classes,confs)
  1. view_img方法:

    • 接受图像路径、边界框列表、类别列表和置信度列表作为参数。
    • 读取图像。
    • 遍历边界框列表,对于每个边界框,绘制在图像上,并打印类别和置信度信息。
    • 显示处理后的图像,并等待用户按下任意键退出程序,关闭所有窗口。
    def view_img(self,img_path,boxes,classes,confs):img=cv2.imread(img_path)size=len(boxes)for i in range(size):cls=classes[i]conf=confs[i]x1,y1,x2,y2=boxes[i]x1,y1,x2,y2=int(x1),int(y1),int(x2),int(y2)cv2.rectangle(img,(x1,y1),(x2,y2),color=(0,0,255),thickness=3,lineType=cv2.LINE_AA)print(f'cls={cls},conf={conf}')cv2.imshow('win', img)cv2.waitKey(0)cv2.destroyAllWindows()

 

还可以添加一个nms

 


http://www.mrgr.cn/news/5288.html

相关文章:

  • Linux离线安装fontconfig
  • 数据可视化大屏模板-美化图表
  • 数据库系统 第22节 事务隔离级别
  • 信刻光盘摆渡系统安全合规实现跨网数据单向导入/导出
  • 2024音频剪辑指南:探索四大高效工具!
  • 虚幻反射-
  • JavaSocket 网络编程之 UDP
  • 图像处理之:Video Processing Subsystem(三)
  • 身份证识别、护照OCR、python身份证四要素实名认证API
  • gpt-2语言模型训练
  • 物联网设备心跳源码-SAAS本地化及未来之窗行业应用跨平台架构
  • 标准库标头 <string_view> (C++17)学习
  • 5步掌握Python Django结合K-means算法进行豆瓣书籍可视化分析
  • LabVIEW深度监测系统
  • 数据结构--单链表
  • 多功能秒达工具箱全开源源码,可自部署且完全开源的中文工具箱
  • 投资伦敦银一般看什么点位做单?
  • sqlite3基本操作/数据库编程
  • uniapp中 使用 VUE3 组合式API 怎么接收上一个页面传递的参数
  • XSS-games