【tensorflow onnx】TensorFlow2导出ONNX及模型可视化教程
创始人
2024-05-30 20:03:37
0

文章目录

  • 1 背景介绍
  • 2 实验环境
  • 3 tf2onnx工具介绍
  • 4 代码实操
    • 4.1 TensorFlow2与ONNX模型导出
    • 4.2 ONNX正确性验证
    • 4.3 TensorFlow2与ONNX的一致性检查
    • 4.4 多输入的情况
    • 4.5 设定输入/输出节点
  • 5 ONNX模型可视化
  • 6 ir_version和opset_version修改
  • 7 ONNX输入输出维度修改
  • 8 致谢

原文来自于地平线开发者社区,未来会持续发布深度学习、板端部署的相关优质文章与视频,如果文章对您有帮助,麻烦给点个赞,如果您有兴趣一起学习,欢迎点个关注:寻找永不遗憾(CSDN用户名)

1 背景介绍

使用深度学习开源框架Pytorch训练完网络模型后,在部署之前通常需要进行格式转换,地平线工具链模型转换目前支持Caffe1.0和ONNX(opset_version=10/11 且 ir_version≤7)两种。ONNX(Open Neural Network Exchange)格式是一种常用的开源神经网络格式,被较多推理引擎支持,例如Pytorch、PaddlePaddle、TensorFlow等。本文将详细介绍如何将TensorFlow2得到的模型导出为ONNX格式。

2 实验环境

本教程的实验环境如下:

Python库Version
tensorflow-cpu2.11.0
tensorflow-intel2.11.0
tf2onnx1.13.0
protobuf3.20.2
onnx1.13.0
onnxruntime1.14.0

3 tf2onnx工具介绍

tf2onnx可以通过命令行的方式将TensorFlow/Keras的模型转换为ONNX,该工具的主要配置参数如下:

python -m tf2onnx.convert--saved-model          #以save-model方式保存的tf模型文件夹--output               #转换为ONNX格式的完整模型名称--opset                #默认为13,请手动配置10或11--inputs               #可选,用于指定导出的首节点--outputs              #可选,用于指定导出的尾节点

tf2onnx的更多详细介绍可以参考: https://github.com/onnx/tensorflow-onnx

4 代码实操

4.1 TensorFlow2与ONNX模型导出

以下代码展示了如何搭建一个简单分类模型以TensorFlow2的save-model方式保存并转换为ONNX格式。

import tensorflow as tf
import os
import onnxdef MyNet():input1 = tf.keras.layers.Input(shape=(7, 7, 3))x = tf.keras.layers.Conv2D(16, (3, 3),activation='relu',padding='same',name='conv1')(input1)x = tf.keras.layers.Conv2D(16, (3, 3),activation='relu',padding='same',name='conv2')(x)x = tf.keras.layers.Flatten(name='flatten')(x)x = tf.keras.layers.Dense(100, activation='relu', name='fc1')(x)output = tf.keras.layers.Dense(2, activation='softmax', name='predictions')(x)input_1 = input1model = tf.keras.models.Model(inputs=[input_1], outputs=output)return modelmodel = MyNet()#需要先使用model.save方法保存模型
model.save('model')
#调用tf2onnx将上一步保存的模型导出为ONNX
os.system("python -m tf2onnx.convert --saved-model model --output model.onnx --opset 11")

4.2 ONNX正确性验证

可以用以下代码验证ONNX模型的正确性,会检查模型的版本,图的结构,节点及输入输出。若输出为 Check: None 则表示无报错信息,模型导出正确。

import onnxonnx_model = onnx.load("./model.onnx")
check = onnx.checker.check_model(onnx_model)
print('Check: ', check)

4.3 TensorFlow2与ONNX的一致性检查

可以使用以下代码检查导出的ONNX模型和原始的PaddlePaddle模型是否有相同的计算结果。

import tensorflow as tf
import onnxruntime
import numpy as npinput1 = np.random.random((1, 7, 7, 3)).astype('float32')ort_sess = onnxruntime.InferenceSession("./model.onnx")
ort_inputs = {ort_sess.get_inputs()[0].name: input1}
ort_outs = ort_sess.run(None, ort_inputs)tf_model = tf.saved_model.load(export_dir="model")
tf_outs = tf_model(inputs=input1)print(ort_outs[0])
print(tf_outs.numpy())
np.testing.assert_allclose(tf_outs.numpy(), ort_outs[0], rtol=1e-03, atol=1e-05)
print("onnx model check finsh.")

4.4 多输入的情况

若您的模型存在多输入,则可参考下方代码以TensorFlow2的save-model方式保存并转换为ONNX格式。

import tensorflow as tf
import osdef MyNet():input1 = tf.keras.layers.Input(shape=(7, 7, 3))input2 = tf.keras.layers.Input(shape=(7, 7, 3))x = tf.keras.layers.Conv2D(16, (3, 3),activation='relu',padding='same',name='conv1')(input1)y = tf.keras.layers.Conv2D(16, (3, 3),activation='relu',padding='same',name='conv2')(input2)z = tf.keras.layers.Concatenate(axis=-1)([x, y])z = tf.keras.layers.Flatten(name='flatten')(z)z = tf.keras.layers.Dense(100, activation='relu', name='fc1')(z)output = tf.keras.layers.Dense(2, activation='softmax', name='predictions')(z)input_1 = input1input_2 = input2model = tf.keras.models.Model(inputs=[input_1,input_2], outputs=output)return modelmodel = MyNet()model.save('model')
os.system("python -m tf2onnx.convert --saved-model model --output model.onnx --opset 11")

4.5 设定输入/输出节点

有时考虑到部署难度,我们不希望TensorFlow网络结构的前后处理部分也导入进ONNX模型。此时可以使用tf2onnx工具的inputs和outputs参数,指定导出的首尾节点,这样首节点之前和尾节点之后的部分都不会导入进ONNX模型。

5 ONNX模型可视化

导出成ONNX模型后,可以使用开源可视化工具Netron来查看网络结构及相关配置信息。Netron的使用方式主要分为两种,一种是使用在线网页版 https://netron.app/ ,另一种是下载安装程序 https://github.com/lutzroeder/netron 。此教程中模型的可视化效果为:

6 ir_version和opset_version修改

地平线工具链支持的ONNX模型需要满足 opset_version=10/11 且 ir_version≤7,当拿到的ONNX模型不满足这两个要求时,可以修改代码重新导出,或者尝试编写脚本直接修改ONNX模型的对应属性,第二种方式的示例代码如下:

import onnxmodel = onnx.load("./model.onnx")
model.ir_version = 6
model.opset_import[0].version = 11
onnx.save_model(model, "./model_version.onnx")

注意: 高版本向低版本切换时可能会出现问题,这里只是一种可尝试的解决方案。

7 ONNX输入输出维度修改

当发现使用tf2onnx工具保存的ONNX模型的输入输出节点出现异常值时,比如以下情况:

在这里插入图片描述

可以使用如下代码进行修改:

import onnxonnx_model = onnx.load("./model.onnx")
onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_value = 1
onnx_model.graph.output[0].type.tensor_type.shape.dim[0].dim_value = 1
onnx.save(onnx_model, './model_dim.onnx')

打开保存的ONNX模型文件,可以看到输入输出节点的维度已经正常:
在这里插入图片描述

至此,该ONNX模型已满足地平线工具链的转换条件。

8 致谢

原文来自于地平线开发者社区,未来会持续发布深度学习、板端部署的相关优质文章与视频,如果文章对您有帮助,麻烦给点个赞,如果您有兴趣一起学习,欢迎点个关注:寻找永不遗憾(CSDN用户名)

相关内容

热门资讯

小学一年级招生标语 小学一年级招生标语  以下是一篇关于小学一年级招生标语,供大家参考,希望可以帮助到大家,小学一年级招...
大学辅导员评语 大学辅导员评语模板  在我们毕业前,辅导员都会给我书写评语。以下是小编整理好的大学辅导员评语模板,欢...
工地安全质量宣传标语 工地安全质量宣传标语(精选105句)  在平时的学习、工作或生活中,说到标语,大家肯定都不陌生吧,标...
卖鞋口号 卖鞋口号大全  篇一:鞋类广告语大全  零约束,轻松自由.------李宁舒适装备. 李宁鞋类  蓝...
脱贫攻坚的宣传标语 脱贫攻坚的宣传标语大全  1、精准识别贫困人口,扶贫攻坚致富幸福。  2、产业扶到户,致富有门路。 ...
《聪明的兔子》说课稿 《聪明的兔子》说课稿  作为一位杰出的教职工,常常需要准备说课稿,通过说课稿可以很好地改正讲课缺点。...
简短的班主任综合评语 2020年简短的班主任综合评语合集39条  吴东浩:该生平时比较内向,学习态度端正,上课听讲专心,能...
二年级数学《简单的推理》优秀... 二年级数学《简单的推理》优秀评课稿(精选9篇)  姜银平老师《简单的推理》这一堂课注重对学生学习兴趣...
《月亮湾》说课稿 《月亮湾》说课稿  作为一位不辞辛劳的人民教师,通常需要准备好一份说课稿,借助说课稿可以有效提高教学...
《百合花开》说课稿 《百合花开》说课稿范文  一、 说教材  1、 教材简析:  我说课的内容是是冀教版语文第十一册第三...
《班级公约》说课稿 《班级公约》说课稿范文  一、主题分析:  《幼儿园教育纲要》中指出:“培养幼儿良好行为习惯,注重潜...
《像山那样思考》说课稿 《像山那样思考》说课稿  一、说教学理念  我将在钱梦龙老师的“三主”教学理念的指导下开展我的教学:...
小学科学说课稿 小学科学说课稿  一、说课稿简析教材  教材是进行教学的评判凭据,是学生获取知识的重要来源。  ①教...
三级跳远的加油稿 三级跳远的加油稿10篇三级跳远的加油稿1  人生的路,有坦途,也有坎坷,做过的岁月,有欢笑,也有苦涩...
六年级《山中访友》说课稿 六年级《山中访友》说课稿  作为一位不辞辛劳的人民教师,很有必要精心设计一份说课稿,写说课稿能有效帮...
小学语文《为人民服务》评课稿 小学语文《为人民服务》评课稿  这几天在福州参加首届英才杯“智慧·互动·成长”全国青年教师风采展示大...
新教师代表发言稿 新教师代表发言稿尊敬的各位领导、老师:下午好!我是蔡各庄小学新入职教师---赵煜。首先,我十分感谢局...
《打弹弓的盲童》 说课稿 《打弹弓的盲童》 说课稿范文  一、说教材:  课文讲述了一个失明男孩在妈妈的鼓励和提示下用弹弓打破...
教师节学生国旗下讲话稿   敬爱的老师们,亲爱的同学们:  大家好!再过几天,我们将迎来教师节。在今天的升旗仪式上,让我们代...
《运动的快慢》的初中物理说课... 关于《运动的快慢》的初中物理说课稿  一.说教材分析  1.教材的地位、作用、分析  本节课所要讲授...