当前位置: 首页 > news >正文

ONNX加载和保存模型

ONNX

ONNX(Open Neural Network Exchange)是一个开放的格式,用于表示机器学习模型。它使得不同框架之间的模型可以互操作,方便模型的迁移和部署。以下是一些关于 ONNX 的基本介绍和使用方法。

在这里插入图片描述

  1. 模型转换:ONNX 允许你将模型从一个深度学习框架(如 PyTorch、TensorFlow)转换为 ONNX 格式。
  2. 互操作性:ONNX 模型可以在支持 ONNX 的不同平台和工具之间共享。
  3. 优化:ONNX 提供了工具来优化模型,以提高推理性能。

将模型转换为 ONNX 格式

以下是将 PyTorch 模型转换为 ONNX 模型的步骤:

  1. 安装 ONNX

安装了 ONNX 和相关的转换工具:

pip install onnx
pip install onnxruntime  # 用于运行 ONNX 模型
pip install torch  # PyTorch
  1. 转换 PyTorch 模型

一个已训练的 PyTorch 模型,可以使用以下代码将其转换为 ONNX 格式:

import torch
import torch.onnx
import torchvision.models as models# 加载预训练的 PyTorch 模型
model = models.resnet18(pretrained=True)
model.eval()  # 设置模型为推理模式# 创建示例输入张量
dummy_input = torch.randn(1, 3, 224, 224)# 将模型导出为 ONNX 格式
torch.onnx.export(model, dummy_input, "resnet18.onnx", verbose=True)

在这个示例中,将一个预训练的 ResNet-18 模型转换为 ONNX 格式并保存为 resnet18.onnx 文件。

加载和运行 ONNX 模型

使用 ONNX Runtime 来加载和运行转换后的 ONNX 模型:

import onnx
import onnxruntime as ort
import numpy as np# 加载 ONNX 模型
onnx_model = onnx.load("resnet18.onnx")
onnx.checker.check_model(onnx_model)  # 检查模型是否有效# 创建 ONNX Runtime 会话
ort_session = ort.InferenceSession("resnet18.onnx")# 创建输入数据
dummy_input = np.random.randn(1, 3, 224, 224).astype(np.float32)# 运行模型
outputs = ort_session.run(None, {"input": dummy_input})
print(outputs[0])

检查和优化 ONNX 模型

ONNX 提供了一些工具来检查和优化模型:

1. 检查模型

使用 onnx.checker 来验证模型的有效性:

import onnxonnx_model = onnx.load("resnet18.onnx")
onnx.checker.check_model(onnx_model)

2. 优化模型

使用 onnx.optimizer 来优化模型:

import onnx
import onnx.optimizeronnx_model = onnx.load("resnet18.onnx")# 定义优化通道
passes = ["fuse_consecutive_transposes", "eliminate_deadend"]# 优化模型
optimized_model = onnx.optimizer.optimize(onnx_model, passes)# 保存优化后的模型
onnx.save(optimized_model, "resnet18_optimized.onnx")

其他常用工具和库

  • Netron:用于可视化 ONNX 模型的工具。可以下载并使用 Netron 打开 .onnx 文件进行模型可视化。
  • ONNX Model Zoo:ONNX 模型库,包含许多预训练的 ONNX 模型,可以直接下载和使用。

小结

ONNX 作为一个开放的模型格式,可以极大地提高模型在不同框架和平台之间的可移植性。通过学习如何将模型转换为 ONNX 格式,并使用 ONNX Runtime 进行推理和优化,你可以更高效地部署和管理你的机器学习模型。


只有一个元素的时候才能够使用item()转为scalar,无论是一个0维度张量,还是1维张量,还是2维度

x_t = torch.tensor([1.0])
x2_t =torch.tensor(1.0)
x4_t = torch.tensor([[[1.0]]])x_n = x_t.item()       # 1.0
x2_n = x2_t.item()     # 1.0
x3_n = x3_t.item()     # 1.0

http://www.mrgr.cn/news/14057.html

相关文章:

  • 【零知识证明】MiMC哈希函数电路
  • [米联客-XILINX-H3_CZ08_7100] FPGA程序设计基础实验连载-11 UART串口接收驱动设计
  • 【FPGA】HDMI参数信息汇总
  • 宠物空气净化器哪款更值得推荐?希喂和352哪款更好?
  • 35岁零基础能转型AI大模型吗?
  • CSS 终于在 2024 年增加了垂直居中功能
  • Qt调用外部exe并嵌入到Qt界面中(验证成功的成功)
  • 如何解决:Failed to start jenkins.service: Unit not found.
  • P1009 【深基4,例7】阶乘之和
  • Java对象属性比较工具类(可用)
  • 【中秋特惠】南卡Runner Pro5:送给家人的科技健康礼!
  • 不用async与await将异步函数改为同步函数
  • 【递归回溯之floodfill算法专题练习】
  • 了解CSS中的BFC
  • 华为设备默认密码
  • Lombok组件的使用
  • E29.【C语言】练习:sizeof和strlen的习题集(A)
  • matlab 将数组从左向右翻转
  • 电子电气架构 --- 车载网简史(上)
  • 迷雾大陆辅助:VMOS云手机助力升级装备系统秘籍!