当前位置: 首页 > news >正文

onnx和tensorrt使用过程中的一些代码梯子

tensorrt使用的一些脚本

  • 打印trt输入和输出的尺寸
  • 打印trt模型的推理速度
  • 打印onnx输入和输出的尺寸
  • 将包含多个输入和输出的onnx模型转化为trt
  • 合并两个onnx模型,并行合并

打印trt输入和输出的尺寸

# -*- coding: utf-8 -*-
# @Time    : 2024/8/23 19:56
# @Author  : sjh
# @Site    : 
# @File    : 打印trt输入和输出的尺寸.py
# @Comment :
import tensorrt as trt# 加载 TensorRT 引擎
TRT_LOGGER = trt.Logger(trt.Logger.INFO)def load_engine(engine_file_path):with open(engine_file_path, "rb") as f:runtime = trt.Runtime(TRT_LOGGER)engine = runtime.deserialize_cuda_engine(f.read())return engine# 打印引擎输入输出的详细信息
def print_engine_info(engine):print("Engine has {} bindings.".format(engine.num_bindings))for i in range(engine.num_bindings):binding_name = engine.get_binding_name(i)binding_shape = engine.get_binding_shape(i)binding_dtype = engine.get_binding_dtype(i)is_input = engine.binding_is_input(i)if is_input:print(f"Input {i}: Name = {binding_name}, Shape = {binding_shape}, DType = {binding_dtype}")else:print(f"Output {i}: Name = {binding_name}, Shape = {binding_shape}, DType = {binding_dtype}")# 加载并打印 combined_1.engine 的信息
engine = load_engine("combined_1.engine")
print_engine_info(engine)

打印trt模型的推理速度

import tensorrt as trt
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit
import timeTRT_LOGGER = trt.Logger(trt.Logger.ERROR)# 加载 TensorRT 引擎
def load_engine(engine_file_path):with open(engine_file_path, "rb") as f:runtime = trt.Runtime(TRT_LOGGER)engine = runtime.deserialize_cuda_engine(f.read())return engine# 进行推理并测量推理时间
def infer_and_measure_speed(engine, input_data_list, iterations=1000):context = engine.create_execution_context()# 获取输入和输出的数量num_bindings = engine.num_bindingsnum_inputs = sum([engine.binding_is_input(i) for i in range(num_bindings)])num_outputs = num_bindings - num_inputs# 获取输入和输出的名称input_names = [engine.get_tensor_name(i) for i in range(num_inputs)]output_names = [engine.get_tensor_name(i) for i in range(num_inputs, num_bindings)]# 设置每个输入的形状for i, input_data in enumerate(input_data_list):context.set_binding_shape(i, input_data.shape)# 打印调试信息# for i, input_name in enumerate(input_names):#     print(f"Input Tensor {i + 1} Name: {input_name}, Shape: {input_data_list[i].shape}")# for i, output_name in enumerate(output_names):#     print(f"Output Tensor {i + 1} Nam

http://www.mrgr.cn/news/50859.html

相关文章:

  • 单链表算法题(一)(超详细版)
  • 基于SpringBoot+Vue+MySQL的养老保险管理系统
  • C1. Adjust The Presentation (Easy Version) 双指针
  • 除毛好、噪音小的宠物空气净化器推荐?希喂、有哈、美的性能对比
  • 性能与体验登顶,海马云电脑重新定义行业,领跑未来工作与娱乐方式
  • 使用 Spring 框架构建 MVC 应用程序:初学者教程
  • MySQL基础(一)
  • 道路车辆功能安全 ISO 26262标准(4-3)—系统级产品开发
  • PHP 函数 func_num_args() 的作用
  • 编程练习7 5G网络建设
  • 初识Linux
  • DB-GPT 安装
  • 基于Leaflet的高德AOI数据在天地图底图可视化纠偏实践
  • 视觉的边界填充、数值计算和腐蚀操作
  • jeston nano配置虚拟环境记录
  • 每日OJ题_WY3小易的升级之路_数学模拟_C++_Java
  • 离宝安羊台山登山口最近的停车场探寻
  • 港大和字节提出长视频生成模型Loong,可生成具有一致外观、大运动动态和自然场景过渡的分钟级长视频。
  • 百度地图怎么上传店铺定位?
  • RK3568平台开发系列讲解(调试篇)嵌入式必备技能:万用表使用指南