Unverified commit d0fa0042 authored by: Glenn Jocher's avatar Glenn Jocher, committed by: GitHub

New `@try_export` decorator (#9096)

* New export decorator * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * New export decorator * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Cleanup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * rename fcn to func * rename to @try_export Co-authored-by: 's avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent eab35f66
...@@ -67,8 +67,8 @@ if platform.system() != 'Windows': ...@@ -67,8 +67,8 @@ if platform.system() != 'Windows':
from models.experimental import attempt_load from models.experimental import attempt_load
from models.yolo import Detect from models.yolo import Detect
from utils.dataloaders import LoadImages from utils.dataloaders import LoadImages
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, check_yaml, from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
colorstr, file_size, print_args, url2file) check_yaml, colorstr, file_size, get_default_args, print_args, url2file)
from utils.torch_utils import select_device, smart_inference_mode from utils.torch_utils import select_device, smart_inference_mode
...@@ -89,9 +89,27 @@ def export_formats(): ...@@ -89,9 +89,27 @@ def export_formats():
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU']) return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
def try_export(inner_func):
    # YOLOv5 export decorator, i.e. @try_export
    # Wraps an export function that returns (file_path, model_or_None); logs timed
    # success/failure and swallows exceptions so remaining export formats can proceed.
    import functools  # local import, consistent with this file's deferred-import style

    inner_args = get_default_args(inner_func)  # defaults of the wrapped exporter (supplies 'prefix')

    @functools.wraps(inner_func)  # preserve the exporter's __name__/__doc__ for logging/introspection
    def outer_func(*args, **kwargs):
        # Honor an explicitly-passed prefix; fall back to the exporter's default
        prefix = kwargs.get('prefix', inner_args['prefix'])
        try:
            with Profile() as dt:  # dt.t holds elapsed seconds once the block exits
                f, model = inner_func(*args, **kwargs)
            LOGGER.info(f'{prefix} export success ✅ {dt.t:.1f}s, saved as {f} ({file_size(f):.1f} MB)')
            return f, model
        except Exception as e:
            # Best-effort policy: report and return (None, None) rather than propagate
            LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
            return None, None

    return outer_func
@try_export
def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')): def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
# YOLOv5 TorchScript model export # YOLOv5 TorchScript model export
try:
LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...') LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
f = file.with_suffix('.torchscript') f = file.with_suffix('.torchscript')
...@@ -102,16 +120,12 @@ def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:' ...@@ -102,16 +120,12 @@ def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:'
optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files) optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
else: else:
ts.save(str(f), _extra_files=extra_files) ts.save(str(f), _extra_files=extra_files)
return f, None
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return f
except Exception as e:
LOGGER.info(f'{prefix} export failure: {e}')
@try_export
def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')): def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
# YOLOv5 ONNX export # YOLOv5 ONNX export
try:
check_requirements(('onnx',)) check_requirements(('onnx',))
import onnx import onnx
...@@ -162,15 +176,12 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst ...@@ -162,15 +176,12 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst
onnx.save(model_onnx, f) onnx.save(model_onnx, f)
except Exception as e: except Exception as e:
LOGGER.info(f'{prefix} simplifier failure: {e}') LOGGER.info(f'{prefix} simplifier failure: {e}')
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') return f, model_onnx
return f
except Exception as e:
LOGGER.info(f'{prefix} export failure: {e}')
@try_export
def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')): def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
# YOLOv5 OpenVINO export # YOLOv5 OpenVINO export
try:
check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/ check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/
import openvino.inference_engine as ie import openvino.inference_engine as ie
...@@ -181,16 +192,12 @@ def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')): ...@@ -181,16 +192,12 @@ def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
subprocess.check_output(cmd.split()) # export subprocess.check_output(cmd.split()) # export
with open(Path(f) / file.with_suffix('.yaml').name, 'w') as g: with open(Path(f) / file.with_suffix('.yaml').name, 'w') as g:
yaml.dump({'stride': int(max(model.stride)), 'names': model.names}, g) # add metadata.yaml yaml.dump({'stride': int(max(model.stride)), 'names': model.names}, g) # add metadata.yaml
return f, None
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return f
except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')
@try_export
def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')): def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
# YOLOv5 CoreML export # YOLOv5 CoreML export
try:
check_requirements(('coremltools',)) check_requirements(('coremltools',))
import coremltools as ct import coremltools as ct
...@@ -208,18 +215,12 @@ def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')): ...@@ -208,18 +215,12 @@ def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
else: else:
print(f'{prefix} quantization only supported on macOS, skipping...') print(f'{prefix} quantization only supported on macOS, skipping...')
ct_model.save(f) ct_model.save(f)
return f, ct_model
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return ct_model, f
except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')
return None, None
@try_export
def export_engine(model, im, file, train, half, dynamic, simplify, workspace=4, verbose=False): def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
# YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
prefix = colorstr('TensorRT:')
try:
assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`' assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
try: try:
import tensorrt as trt import tensorrt as trt
...@@ -231,11 +232,11 @@ def export_engine(model, im, file, train, half, dynamic, simplify, workspace=4, ...@@ -231,11 +232,11 @@ def export_engine(model, im, file, train, half, dynamic, simplify, workspace=4,
if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012 if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
grid = model.model[-1].anchor_grid grid = model.model[-1].anchor_grid
model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid] model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
export_onnx(model, im, file, 12, train, dynamic, simplify) # opset 12 export_onnx(model, im, file, 12, False, dynamic, simplify) # opset 12
model.model[-1].anchor_grid = grid model.model[-1].anchor_grid = grid
else: # TensorRT >= 8 else: # TensorRT >= 8
check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0 check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0
export_onnx(model, im, file, 13, train, dynamic, simplify) # opset 13 export_onnx(model, im, file, 13, False, dynamic, simplify) # opset 13
onnx = file.with_suffix('.onnx') onnx = file.with_suffix('.onnx')
LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...') LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
...@@ -277,12 +278,10 @@ def export_engine(model, im, file, train, half, dynamic, simplify, workspace=4, ...@@ -277,12 +278,10 @@ def export_engine(model, im, file, train, half, dynamic, simplify, workspace=4,
config.set_flag(trt.BuilderFlag.FP16) config.set_flag(trt.BuilderFlag.FP16)
with builder.build_engine(network, config) as engine, open(f, 'wb') as t: with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
t.write(engine.serialize()) t.write(engine.serialize())
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') return f, None
return f
except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')
@try_export
def export_saved_model(model, def export_saved_model(model,
im, im,
file, file,
...@@ -296,11 +295,10 @@ def export_saved_model(model, ...@@ -296,11 +295,10 @@ def export_saved_model(model,
keras=False, keras=False,
prefix=colorstr('TensorFlow SavedModel:')): prefix=colorstr('TensorFlow SavedModel:')):
# YOLOv5 TensorFlow SavedModel export # YOLOv5 TensorFlow SavedModel export
try:
import tensorflow as tf import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
from models.tf import TFDetect, TFModel from models.tf import TFModel
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
f = str(file).replace('.pt', '_saved_model') f = str(file).replace('.pt', '_saved_model')
...@@ -326,18 +324,14 @@ def export_saved_model(model, ...@@ -326,18 +324,14 @@ def export_saved_model(model,
tfm.__call__(im) tfm.__call__(im)
tf.saved_model.save(tfm, tf.saved_model.save(tfm,
f, f,
options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if check_version(
if check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions()) tf.__version__, '2.6') else tf.saved_model.SaveOptions())
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') return f, keras_model
return keras_model, f
except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')
return None, None
@try_export
def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')): def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
# YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow # YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
try:
import tensorflow as tf import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
...@@ -349,16 +343,12 @@ def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')): ...@@ -349,16 +343,12 @@ def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
frozen_func = convert_variables_to_constants_v2(m) frozen_func = convert_variables_to_constants_v2(m)
frozen_func.graph.as_graph_def() frozen_func.graph.as_graph_def()
tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False) tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
return f, None
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return f
except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')
@try_export
def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')): def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
# YOLOv5 TensorFlow Lite export # YOLOv5 TensorFlow Lite export
try:
import tensorflow as tf import tensorflow as tf
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
...@@ -384,15 +374,12 @@ def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=c ...@@ -384,15 +374,12 @@ def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=c
tflite_model = converter.convert() tflite_model = converter.convert()
open(f, "wb").write(tflite_model) open(f, "wb").write(tflite_model)
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') return f, None
return f
except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')
@try_export
def export_edgetpu(file, prefix=colorstr('Edge TPU:')): def export_edgetpu(file, prefix=colorstr('Edge TPU:')):
# YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/ # YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
try:
cmd = 'edgetpu_compiler --version' cmd = 'edgetpu_compiler --version'
help_url = 'https://coral.ai/docs/edgetpu/compiler/' help_url = 'https://coral.ai/docs/edgetpu/compiler/'
assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}' assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
...@@ -412,16 +399,12 @@ def export_edgetpu(file, prefix=colorstr('Edge TPU:')): ...@@ -412,16 +399,12 @@ def export_edgetpu(file, prefix=colorstr('Edge TPU:')):
cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {file.parent} {f_tfl}" cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {file.parent} {f_tfl}"
subprocess.run(cmd.split(), check=True) subprocess.run(cmd.split(), check=True)
return f, None
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return f
except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')
@try_export
def export_tfjs(file, prefix=colorstr('TensorFlow.js:')): def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
# YOLOv5 TensorFlow.js export # YOLOv5 TensorFlow.js export
try:
check_requirements(('tensorflowjs',)) check_requirements(('tensorflowjs',))
import re import re
...@@ -447,11 +430,7 @@ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')): ...@@ -447,11 +430,7 @@ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
r'"Identity_2": {"name": "Identity_2"}, ' r'"Identity_2": {"name": "Identity_2"}, '
r'"Identity_3": {"name": "Identity_3"}}}', json) r'"Identity_3": {"name": "Identity_3"}}}', json)
j.write(subst) j.write(subst)
return f, None
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return f
except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')
@smart_inference_mode() @smart_inference_mode()
...@@ -524,22 +503,22 @@ def run( ...@@ -524,22 +503,22 @@ def run(
f = [''] * 10 # exported filenames f = [''] * 10 # exported filenames
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
if jit: if jit:
f[0] = export_torchscript(model, im, file, optimize) f[0], _ = export_torchscript(model, im, file, optimize)
if engine: # TensorRT required before ONNX if engine: # TensorRT required before ONNX
f[1] = export_engine(model, im, file, train, half, dynamic, simplify, workspace, verbose) f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
if onnx or xml: # OpenVINO requires ONNX if onnx or xml: # OpenVINO requires ONNX
f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify) f[2], _ = export_onnx(model, im, file, opset, train, dynamic, simplify)
if xml: # OpenVINO if xml: # OpenVINO
f[3] = export_openvino(model, file, half) f[3], _ = export_openvino(model, file, half)
if coreml: if coreml:
_, f[4] = export_coreml(model, im, file, int8, half) f[4], _ = export_coreml(model, im, file, int8, half)
# TensorFlow Exports # TensorFlow Exports
if any((saved_model, pb, tflite, edgetpu, tfjs)): if any((saved_model, pb, tflite, edgetpu, tfjs)):
if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707 if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow` check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow`
assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.' assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
model, f[5] = export_saved_model(model.cpu(), f[5], model = export_saved_model(model.cpu(),
im, im,
file, file,
dynamic, dynamic,
...@@ -551,19 +530,19 @@ def run( ...@@ -551,19 +530,19 @@ def run(
conf_thres=conf_thres, conf_thres=conf_thres,
keras=keras) keras=keras)
if pb or tfjs: # pb prerequisite to tfjs if pb or tfjs: # pb prerequisite to tfjs
f[6] = export_pb(model, file) f[6], _ = export_pb(model, file)
if tflite or edgetpu: if tflite or edgetpu:
f[7] = export_tflite(model, im, file, int8=int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms) f[7], _ = export_tflite(model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
if edgetpu: if edgetpu:
f[8] = export_edgetpu(file) f[8], _ = export_edgetpu(file)
if tfjs: if tfjs:
f[9] = export_tfjs(file) f[9], _ = export_tfjs(file)
# Finish # Finish
f = [str(x) for x in f if x] # filter out '' and None f = [str(x) for x in f if x] # filter out '' and None
if any(f): if any(f):
h = '--half' if half else '' # --half FP16 inference arg h = '--half' if half else '' # --half FP16 inference arg
LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)' LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
f"\nResults saved to {colorstr('bold', file.parent.resolve())}" f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
f"\nDetect: python detect.py --weights {f[-1]} {h}" f"\nDetect: python detect.py --weights {f[-1]} {h}"
f"\nValidate: python val.py --weights {f[-1]} {h}" f"\nValidate: python val.py --weights {f[-1]} {h}"
......
...@@ -148,6 +148,7 @@ class Profile(contextlib.ContextDecorator): ...@@ -148,6 +148,7 @@ class Profile(contextlib.ContextDecorator):
def __enter__(self): def __enter__(self):
self.start = self.time() self.start = self.time()
return self
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
self.dt = self.time() - self.start # delta-time self.dt = self.time() - self.start # delta-time
...@@ -220,10 +221,10 @@ def methods(instance): ...@@ -220,10 +221,10 @@ def methods(instance):
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")] return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False): def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
# Print function arguments (optional args dict) # Print function arguments (optional args dict)
x = inspect.currentframe().f_back # previous frame x = inspect.currentframe().f_back # previous frame
file, _, fcn, _, _ = inspect.getframeinfo(x) file, _, func, _, _ = inspect.getframeinfo(x)
if args is None: # get args automatically if args is None: # get args automatically
args, _, _, frm = inspect.getargvalues(x) args, _, _, frm = inspect.getargvalues(x)
args = {k: v for k, v in frm.items() if k in args} args = {k: v for k, v in frm.items() if k in args}
...@@ -231,7 +232,7 @@ def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False): ...@@ -231,7 +232,7 @@ def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
file = Path(file).resolve().relative_to(ROOT).with_suffix('') file = Path(file).resolve().relative_to(ROOT).with_suffix('')
except ValueError: except ValueError:
file = Path(file).stem file = Path(file).stem
s = (f'{file}: ' if show_file else '') + (f'{fcn}: ' if show_fcn else '') s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items())) LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
...@@ -255,7 +256,13 @@ def init_seeds(seed=0, deterministic=False): ...@@ -255,7 +256,13 @@ def init_seeds(seed=0, deterministic=False):
def intersect_dicts(da, db, exclude=()): def intersect_dicts(da, db, exclude=()):
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape} return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
def get_default_args(func):
    """Return a dict mapping each keyword parameter of func to its default value."""
    params = inspect.signature(func).parameters
    defaults = {}
    for name, param in params.items():
        if param.default is not inspect.Parameter.empty:
            defaults[name] = param.default
    return defaults
def get_latest_run(search_dir='.'): def get_latest_run(search_dir='.'):
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment