当前位置: 首页 > news >正文

Onnx使用预训练的 ResNet18 模型对输入图像进行分类,并将分类结果显示在图像上

目录

一、整体功能概述

二、函数分析

2.1 resnet() 函数:

2.2 pre_process(img_path) 函数:

2.3 loadOnnx(img_path) 函数:

三、代码执行流程


一、整体功能概述


这段代码实现了一个图像分类系统,使用预训练的 ResNet18 模型对输入图像进行分类,并将分类结果显示在图像上。它包括以下主要步骤:
读取一个包含类别名称和对应编号的文本文件,并将其存储在字典中。
定义了几个函数,包括模型导出函数 resnet()、图像预处理函数 pre_process() 和加载 ONNX 模型进行分类的函数 loadOnnx()。
在主程序中,指定输入图像路径,调用 loadOnnx() 函数对图像进行分类并显示结果。


二、函数分析


2.1 resnet() 函数:


使用 torchvision 中的预训练 ResNet18 模型,并设置为评估模式。
生成一个随机输入张量 x,并将模型导出为 ONNX 格式,保存为 models/resnet18.onnx 文件。

def resnet():model=models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)model.eval()x=torch.randn(1,3,224,224)torch.onnx.export(model,x,'models/resnet18.onnx',input_names=['input'],output_names=['output'])


2.2 pre_process(img_path) 函数:


读取输入图像 img_path。
调整图像大小为 224x224。
将图像颜色通道从 BGR 转换为 RGB。
对图像像素值进行归一化处理。
交换图像维度顺序,并增加一个维度。
返回预处理后的图像张量。

def pre_process(img_path):#h w c--->224,224,3#归一化#换轴#增加维度img=cv2.imread(img_path)scale_image=cv2.resize(img,dsize=(224,224))rgb_img=cv2.cvtColor(scale_image,cv2.COLOR_BGR2RGB)rgb_img=rgb_img/255rgb_img=np.transpose(rgb_img,(2,0,1))rgb_img=np.expand_dims(rgb_img,0).astype(np.float32)return rgb_img


2.3 loadOnnx(img_path) 函数:


创建一个 ONNX 推理会话,加载预导出的 ResNet18 ONNX 模型。

调用 pre_process() 函数对输入图像进行预处理。
准备输入数据并进行推理。
获取推理结果中概率最大的类别编号。
根据类别编号从字典中获取对应的类别名称,并进行翻译。
在输入图像上显示分类结果,并展示图像。

def loadOnnx(img_path):session=ort.InferenceSession(r'models\resnet18.onnx',providers=['CPUExecutionProvider'])img=pre_process(img_path)img_back=cv2.imread(img_path)intput_feed={'input':img}session_out=session.run(None,intput_feed)[0]out=np.argmax(session_out,axis=1)[0]res=str(out)# print(dict[res])ans=dict[res].split(',')[1].split(']')[0].strip()ans = translator.translate(ans)cv2.putText(img_back,ans,(100,100),fontFace=1,fontScale=2.0,color=(0,0,255),thickness=3,lineType=cv2.LINE_AA)cv2.imshow('win',img_back)cv2.waitKey(0)cv2.destroyAllWindows()print(ans)

完整代码如下

import cv2
import numpy as np
import torch
from torchvision import models
from torchvision.models import ResNet18_Weights
import onnxruntime as ort
from translate import Translator
translator=Translator(to_lang='Chinese')#翻译成中文
dict={}
with open('类别.txt','r',encoding='utf-8') as f:lines=f.readlines()for line in lines:name=line.split('\t')[0]value=line.split('\t')[1]dict[name]=value
# print(dict)
def resnet():model=models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)model.eval()x=torch.randn(1,3,224,224)torch.onnx.export(model,x,'models/resnet18.onnx',input_names=['input'],output_names=['output'])
def pre_process(img_path):#h w c--->224,224,3#归一化#换轴#增加维度img=cv2.imread(img_path)scale_image=cv2.resize(img,dsize=(224,224))rgb_img=cv2.cvtColor(scale_image,cv2.COLOR_BGR2RGB)rgb_img=rgb_img/255rgb_img=np.transpose(rgb_img,(2,0,1))rgb_img=np.expand_dims(rgb_img,0).astype(np.float32)return rgb_img#RGB
def loadOnnx(img_path):session=ort.InferenceSession(r'models\resnet18.onnx',providers=['CPUExecutionProvider'])img=pre_process(img_path)img_back=cv2.imread(img_path)intput_feed={'input':img}session_out=session.run(None,intput_feed)[0]out=np.argmax(session_out,axis=1)[0]res=str(out)# print(dict[res])ans=dict[res].split(',')[1].split(']')[0].strip()ans = translator.translate(ans)cv2.putText(img_back,ans,(100,100),fontFace=1,fontScale=2.0,color=(0,0,255),thickness=3,lineType=cv2.LINE_AA)cv2.imshow('win',img_back)cv2.waitKey(0)cv2.destroyAllWindows()print(ans)pass
if __name__ == '__main__':img_path='dog.png'# resnet()#导出模型loadOnnx(img_path)


三、代码执行流程


在 if __name__ == '__main__': 部分:
定义输入图像路径 img_path。
可以选择调用 resnet() 函数导出模型(注释状态,通常只在第一次运行或模型更新时使用)。
调用 loadOnnx(img_path) 函数对输入图像进行分类和显示结果。

 

 


http://www.mrgr.cn/news/6535.html

相关文章:

  • 代码随想录训练营 Day37打卡 动态规划 part05 完全背包理论基础 518. 零钱兑换II 377. 组合总和 Ⅳ 卡码70. 爬楼梯(进阶版)
  • Notification 分不同实例关闭
  • 什么是关键词难度?
  • RISC-V全志D1多媒体套件文章汇总
  • OCR识别行驶证(阿里云和百度云)
  • Axios 中的相关参数
  • 图论 最短路
  • webrtc ns 降噪之粉红噪声参数推导
  • 我们再次陷入软件危机
  • 提高实时多媒体传输效率的三大方法
  • io进程----标准io
  • 开源AI智能名片商城小程序在私域流量运营中的转化效率与ROI提升研究
  • DM8守护集群部署、数据同步验证、主备切换
  • PyQtGraph库的基本使用
  • 进程函数练习
  • Apache Doris安装部署
  • vue-cli搭建过程,elementUI搭建使用过程
  • Ubuntu下部署Hadoop集群+Hive(一)
  • 总结:Python语法
  • 喜报 | 麒麟信安“信创云桌面解决方案”在浙江省委党校应用实施,荣膺国家级示范案例