025.模型导出:从PyTorch到ONNX/TorchScript的实战与踩坑手记

张开发
2026/4/13 18:21:43 15 分钟阅读

分享文章

025.模型导出:从PyTorch到ONNX/TorchScript的实战与踩坑手记
一、深夜报警模型在训练端跑得好好的部署端直接崩了上周三凌晨两点手机突然狂震——生产环境的目标检测服务挂了。日志里赫然一行RuntimeError: Expected tensor for argument #1 input to have the same dimension训练时用PyTorch 1.8部署环境却是C推理服务。问题就出在模型导出这一步训练脚本里动态尺寸输入跑得欢导出成TorchScript后死活不认变长输入。这个坑让我熬了个通宵也让我彻底明白模型导出不是点一下torch.onnx.export()就完事的玄学操作而是连接训练与部署的关键桥梁。今天咱们就聊聊PyTorch模型导出的那些实战细节特别是ONNX和TorchScript这两个主流格式。我会把踩过的坑、绕过的路都摊开来写你跟着做能省下不少调试时间。二、TorchScriptPyTorch亲儿子的序列化方案先看TorchScript这是PyTorch自家的部署格式兼容性最好。导出有两种方式追踪Tracing和脚本化Scripting。2.1 追踪模式简单但有限制importtorchfrommodels.yoloimportModel# 你的YOLO模型类modelModel(yolov5s.yaml)# 加载配置model.load_state_dict(torch.load(best.pt)[model])# 加载权重model.eval()# 关键在这里准备一个示例输入example_inputtorch.rand(1,3,640,640)# 固定尺寸# 追踪导出traced_scripttorch.jit.trace(model,example_input)traced_script.save(yolo_traced.pt)# 测试一下withtorch.no_grad():outputtraced_script(torch.rand(1,3,640,640))print(output.shape)# 正常# output2 traced_script(torch.rand(1, 3, 320, 320)) # 这个会报错尺寸必须和example_input一致踩坑点1追踪模式会记录下example_input这个具体张量在模型里的流动路径。如果你的模型有动态控制流比如if-else分支依赖输入值追踪只会记录当时走的那条路其他分支就丢了。踩坑点2输入尺寸被写死。上面代码里用(1,3,640,640)导出的模型推理时就必须是这个尺寸。想支持动态尺寸得用脚本化。2.2 脚本化模式支持动态逻辑# 在模型类里加装饰器classModel(torch.nn.Module):def__init__(self):super().__init__()# 你的层定义torch.jit.export# 显式标记要导出的方法defforward(self,x):# 你的前向逻辑ifx.mean()0:# 动态控制流追踪模式处理不了returnself.path_a(x)else:returnself.path_b(x)# 脚本化导出scripted_modeltorch.jit.script(model)scripted_model.save(yolo_scripted.pt)脚本化会真正解析Python代码所以能处理条件分支、循环。但代价是你的模型代码必须符合TorchScript的语法子集。这意味着不能有复杂的Python类型注解用List[Tensor]别用list不能调用外部Python函数除非也用torch.jit.script装饰列表推导、字典操作受限个人习惯我通常先用追踪模式快速验证如果模型简单且输入尺寸固定这就够了。遇到动态逻辑或需要多尺寸支持时再忍痛改代码适配脚本化。三、ONNX生态更广的开放格式ONNX的优势在于跨框架PyTorch导出可以用TensorRT、OpenVINO、ONNX Runtime等各种后端推理。但导出过程更像走钢丝平衡不好就掉坑里。3.1 基础导出一堆参数要看准importtorch.onnx# 还是那个模型和示例输入model.eval()example_inputtorch.rand(1,3,640,640)# 导出核心调用torch.onnx.export(model,example_input,yolo.onnx,export_paramsTrue,# 把模型参数也导进去opset_version13,# 这个很重要版本不对算子可能不支持do_constant_foldingTrue,# 常量折叠优化一般开着input_names[images],# 输入节点名后面推理用output_names[output],# 输出节点名dynamic_axes{images:{0:batch,2:height,3:width},# 动态轴支持变尺寸output:{0:batch}})关键参数解读opset_versionONNX算子集版本。YOLOv5/v7用的Focus层需要opset11某些新算子需要更高版本。先查清楚你的模型用了什么特殊算子。dynamic_axes这是实现动态尺寸的关键。上面配置表示第0维batch、第2维高、第3维宽可变。如果你训练时用了多尺度这里必须设对。3.2 验证别等到部署才发现问题importonnximportonnxruntimeasort# 1. 检查模型格式是否正确onnx_modelonnx.load(yolo.onnx)onnx.checker.check_model(onnx_model)# 语法检查print(onnx.helper.printable_graph(onnx_model.graph))# 看一眼计算图# 2. 用ONNX Runtime推理测试ort_sessionort.InferenceSession(yolo.onnx)ort_inputs{ort_session.get_inputs()[0].name:example_input.numpy()}ort_outputsort_session.run(None,ort_inputs)# 3. 和PyTorch结果对比withtorch.no_grad():torch_outputmodel(example_input)importnumpyasnpprint(输出差值:,np.max(np.abs(ort_outputs[0]-torch_output.numpy())))# 应该很小1e-7量级常见坑验证通过但推理结果不对大概率是预处理/后处理没对齐。PyTorch里可能做了归一化/255.0导出时如果没包含进计算图部署端就得自己补上。四、YOLO模型导出的特殊处理YOLO系列模型导出时有几个高频坑点4.1 后处理别丢# 错误示范只导出主干网络classDetector(torch.nn.Module):defforward(self,x):featuresself.backbone(x)# 这里少了检测头的解码、NMSreturnfeatures# 这样导出的模型输出是原始特征图不是最终检测框# 正确做法把后处理打包如果推理引擎支持classDetectorWithPostprocess(torch.nn.Module):defforward(self,x):predself.model(x)# 原始输出boxes,scores,classesself.non_max_suppression(pred)# 包含NMSreturnboxes,scores,classes但注意很多部署框架如TensorRT有自己优化过的NMS算子。我通常做法是导出时不带NMS在部署端用框架的NMS实现性能更好。4.2 动态尺寸与批处理# 如果你想支持批量推理和变尺寸dynamic_axes{images:{0:batch_size,2:height,3:width},output:{0:batch_size}}实际部署时如果用了TensorRT动态尺寸会显著增加引擎构建时间。生产环境如果尺寸固定最好导出固定尺寸模型。4.3 自定义算子处理YOLOv5的Focus层切片拼接在opset 11以下不支持要么升级opset_version到11或者把Focus层替换为等价的Conv层YOLOv5官方提供了替换脚本# 替换Focus层的技巧frommodels.commonimportFocusclassMyModel(nn.Module):def__init__(self):super().__init__()# 训练时用Focusself.focusFocus(...)defforward(self,x):ifself.training:returnself.focus(x)else:# 导出时用等效卷积returnself.equivalent_conv(x)训练和导出用不同路径这是常见技巧。五、调试当导出失败时怎么办看错误栈ONNX导出失败会打印不支持的算子或操作照着改。简化模型从单层开始导逐步增加复杂度定位问题层。用torch.onnx.export(verboseTrue)打印计算图看哪里断了。查算子支持表ONNX的算子文档和PyTorch的torch.onnx文档都列了支持情况。我电脑里有个“导出失败记录.md”里面记着torch.tensor.tolist()在脚本化里不行得用torch.unbind()torch.arange()的步长参数在ONNX opset 11前后语法变了LSTM的hidden_size和input_size参数顺序容易搞反六、经验性建议导出前先冻住模型model.eval()是必须的但别忘了还有torch.no_grad()上下文。BatchNorm和Dropout层在训练和评估模式行为不同。版本对齐要死磕PyTorch版本、ONNX版本、推理框架版本TensorRT等的兼容性矩阵先查清楚再动手。我吃过亏本地1.8导出的ONNX生产环境1.7解析不了。留个PyTorch备份ONNX/TorchScript模型导出后一定保留原始的PyTorch模型权重.pt文件。哪天导出格式不兼容了还能用新工具重新导。动态尺寸是双刃剑开发阶段图方便开了动态尺寸生产环境如果尺寸固定建议重新导出静态模型推理速度能快20%以上。验证要全面别只测一张图。准备个小型测试集10-20张覆盖各种尺寸、宽高比对比PyTorch和导出模型的mAP差异。我见过导出后精度掉5个点的原因是某个激活函数量化异常。文档随代码走在导出脚本里用注释写明“此模型用opset 14导出支持动态高宽但不支持批量可变”。三个月后你自己都记不清。模型导出这事就像给训练好的模型做一次“脱水处理”——去掉训练时的冗余保留推理必需的骨架。刚开始会觉得束手束脚多踩几次坑就摸出门道了。记住一次成功的导出始于训练代码时就考虑部署约束。下次写模型时不妨先想想“这代码能顺利导出吗” 这会省去你无数个调试的深夜。

更多文章