# 前言

本篇博客提供简单的 Python 脚本代码,实现 onnx 模型转换编译,保存为 TVM 的 .so .params .json 文件 。

望长城内外,惟余莽莽;大河上下,顿失滔滔。
--------------- 教员
------ 大家好啊 我是 暮冬 Z 羡慕

# Python 脚本实现模型编译和保存

脚本中需要修改的就一些路径,很容易看明白,就不再过多介绍了。


import onnx
from tvm.contrib.download import download_testdata
from PIL import Image
import numpy as np
import tvm.relay as relay
import tvm
from tvm.contrib import graph_executor


# 图片
img_path = "../image/imagenet_cat.png"
# img_url = "https://s3.amazonaws.com/model-server/inputs/kitten.jpg"
# img_path = download_testdata(img_url, "../image/imagenet_cat.png", module="data")

# 重设大小为 224x224
resized_image = Image.open(img_path).resize((224, 224))
img_data = np.array(resized_image).astype("float32")

# 输入图像是 HWC 布局,而 ONNX 需要 CHW 输入,所以转换数组
img_data = np.transpose(img_data, (2, 0, 1))

# 根据 ImageNet 输入规范进行归一化
imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
imagenet_stddev = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
norm_img_data = (img_data / 255 - imagenet_mean) / imagenet_stddev


# 添加 batch 维度,期望 4 维输入:NCHW。
img_data = np.expand_dims(norm_img_data, axis=0)
# 保存为 bin 文件  
norm_img_data.astype("float32").tofile("../image/imagenet_cat.bin")


# 目标设备配置
target = 'llvm'  # 以CPU为例

input_name = "data"
shape_dict = {input_name: img_data.shape}

onnx_model = onnx.load("../model/simple.onnx")

mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)


# 运行相关
dev = tvm.device(str(target), 0)
module = graph_executor.GraphModule(lib["default"](dev))

# 保存库文件
lib_fname = "../lib/mod.so"
lib.export_library(lib_fname)

# 保存模型参数
params_fname = "../lib/mod.params"
with open(params_fname, "wb") as param_file:
    param_file.write(relay.save_param_dict(lib.get_params()))

# 保存JSON格式的计算图
json_fname = "../lib/mod.json"
with open(json_fname, "w") as json_file:
    json_file.write(lib.get_executor_config())

dtype = "float32"
module.set_input(input_name, img_data)
module.run()
output_shape = (1, 10)
tvm_output = module.get_output(0, tvm.nd.empty(output_shape)).numpy()

from scipy.special import softmax

# 下载标签列表
labels_url = "https://s3.amazonaws.com/onnx-model-zoo/synset.txt"
labels_path = download_testdata(labels_url, "synset.txt", module="data")

with open(labels_path, "r") as f:
    labels = [l.rstrip() for l in f]

# 打开输出文件并读取输出张量
scores = softmax(tvm_output)    #   直接输出模型结果
scores = np.squeeze(tvm_output)
ranks = np.argsort(scores)[::-1]
for rank in ranks[0:5]:
    print("class='%s' with probability=%f" % (labels[rank], scores[rank]))

# 后记

本博客目前以及可预期的将来都不会支持评论功能。各位大侠如若有指教和问题,可以在我的 github 项目 或随便一个项目下提出 issue,或者知乎 私信,并指明哪一篇博客,我看到一定及时回复,感激不尽!

Edited on

Give me a cup of [coffee]~( ̄▽ ̄)~*

XianMu WeChat Pay

WeChat Pay

XianMu Alipay

Alipay