Unverified 提交 de9c25b3 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Use `export_formats()` in export.py (#6705)

* Use `export_formats()` in export.py * list fix
上级 a297efc3
...@@ -433,9 +433,12 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' ...@@ -433,9 +433,12 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
conf_thres=0.25 # TF.js NMS: confidence threshold conf_thres=0.25 # TF.js NMS: confidence threshold
): ):
t = time.time() t = time.time()
include = [x.lower() for x in include] include = [x.lower() for x in include] # to lowercase
tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs')) # TensorFlow exports formats = tuple(export_formats()['Argument'][1:]) # --include arguments
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) flags = [x in include for x in formats]
assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {formats}'
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags # export booleans
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights
# Load PyTorch model # Load PyTorch model
device = select_device(device) device = select_device(device)
...@@ -475,20 +478,19 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' ...@@ -475,20 +478,19 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
# Exports # Exports
f = [''] * 10 # exported filenames f = [''] * 10 # exported filenames
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
if 'torchscript' in include: if jit:
f[0] = export_torchscript(model, im, file, optimize) f[0] = export_torchscript(model, im, file, optimize)
if 'engine' in include: # TensorRT required before ONNX if engine: # TensorRT required before ONNX
f[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose) f[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose)
if ('onnx' in include) or ('openvino' in include): # OpenVINO requires ONNX if onnx or xml: # OpenVINO requires ONNX
f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify) f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)
if 'openvino' in include: if xml: # OpenVINO
f[3] = export_openvino(model, im, file) f[3] = export_openvino(model, im, file)
if 'coreml' in include: if coreml:
_, f[4] = export_coreml(model, im, file) _, f[4] = export_coreml(model, im, file)
# TensorFlow Exports # TensorFlow Exports
if any(tf_exports): if any((saved_model, pb, tflite, edgetpu, tfjs)):
pb, tflite, edgetpu, tfjs = tf_exports[1:]
if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707 if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow` check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow`
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.' assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论