Unverified commit e3e5122f, authored by Katteria, committed by GitHub

Add PaddlePaddle export and inference (#9240)

* Add PaddlePaddle model export test in the YOLOv5 Docker environment with paddlepaddle-gpu v2.2
* [pre-commit.ci] auto fixes from pre-commit.com hooks (https://pre-commit.ci)
* Cleanup Paddle export
* [pre-commit.ci] auto fixes from pre-commit.com hooks (https://pre-commit.ci)
* Update common.py
* Update export.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks (https://pre-commit.ci)
* Update export.py
* Update export.py
* Update export.py
* Use PyTorch2Paddle
* [pre-commit.ci] auto fixes from pre-commit.com hooks (https://pre-commit.ci)
* Paddle no longer requires ONNX
* Update export.py
* Update export.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks (https://pre-commit.ci)
* Update benchmarks.py
* Add inference code for PaddlePaddle
* [pre-commit.ci] auto fixes from pre-commit.com hooks (https://pre-commit.ci)
* Update common.py
* Update common.py
* Add paddlepaddle-gpu install if CUDA is available
* [pre-commit.ci] auto fixes from pre-commit.com hooks (https://pre-commit.ci)
* Update common.py
* Update common.py
* Update common.py

Signed-off-by: Katteria <39751846+kisaragychihaya@users.noreply.github.com>
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Parent 57ef676a
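With this change, `paddle` joins the valid `--include` targets in export.py, and the exported `*_paddle_model/` directory loads straight back through DetectMultiBackend. A minimal usage sketch (weight and directory names are the repo defaults):

    $ python export.py --weights yolov5s.pt --include paddle   # writes yolov5s_paddle_model/
    $ python detect.py --weights yolov5s_paddle_model/         # PaddlePaddle inference via DetectMultiBackend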
export.py

@@ -15,6 +15,7 @@ TensorFlow GraphDef           | `pb`                  | yolov5s.pb
 TensorFlow Lite               | `tflite`              | yolov5s.tflite
 TensorFlow Edge TPU           | `edgetpu`             | yolov5s_edgetpu.tflite
 TensorFlow.js                 | `tfjs`                | yolov5s_web_model/
+PaddlePaddle                  | `paddle`              | yolov5s_paddle_model/

 Requirements:
     $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu  # CPU
@@ -54,7 +55,6 @@ from pathlib import Path
 import pandas as pd
 import torch
-import yaml
 from torch.utils.mobile_optimizer import optimize_for_mobile

 FILE = Path(__file__).resolve()
@@ -68,7 +68,7 @@ from models.experimental import attempt_load
 from models.yolo import ClassificationModel, Detect
 from utils.dataloaders import LoadImages
 from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
-                           check_yaml, colorstr, file_size, get_default_args, print_args, url2file)
+                           check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save)
 from utils.torch_utils import select_device, smart_inference_mode
@@ -85,7 +85,8 @@ def export_formats():
         ['TensorFlow GraphDef', 'pb', '.pb', True, True],
         ['TensorFlow Lite', 'tflite', '.tflite', True, False],
         ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
-        ['TensorFlow.js', 'tfjs', '_web_model', False, False],]
+        ['TensorFlow.js', 'tfjs', '_web_model', False, False],
+        ['PaddlePaddle', 'paddle', '_paddle_model', True, True],]
     return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
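The new row can be sanity-checked by printing the table; a quick sketch (output abridged, alignment approximate):

    from export import export_formats

    df = export_formats()
    print(df.tail(2))
    #            Format Argument         Suffix    CPU    GPU
    # 10  TensorFlow.js     tfjs     _web_model  False  False
    # 11   PaddlePaddle   paddle  _paddle_model   True   True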
@@ -180,7 +181,7 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):

 @try_export
-def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
+def export_openvino(file, metadata, half, prefix=colorstr('OpenVINO:')):
     # YOLOv5 OpenVINO export
     check_requirements('openvino-dev')  # requires openvino-dev: https://pypi.org/project/openvino-dev/
     import openvino.inference_engine as ie
@@ -189,9 +190,23 @@ def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
     f = str(file).replace('.pt', f'_openvino_model{os.sep}')

     cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
-    subprocess.check_output(cmd.split())  # export
-    with open(Path(f) / file.with_suffix('.yaml').name, 'w') as g:
-        yaml.dump({'stride': int(max(model.stride)), 'names': model.names}, g)  # add metadata.yaml
+    subprocess.run(cmd.split(), check=True, env=os.environ)  # export
+    yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata)  # add metadata.yaml
+    return f, None
+
+
+@try_export
+def export_paddle(model, im, file, metadata, prefix=colorstr('PaddlePaddle:')):
+    # YOLOv5 Paddle export
+    check_requirements(('paddlepaddle', 'x2paddle'))
+    import x2paddle
+    from x2paddle.convert import pytorch2paddle
+
+    LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')
+    f = str(file).replace('.pt', f'_paddle_model{os.sep}')
+
+    pytorch2paddle(module=model, save_dir=f, jit_type='trace', input_examples=[im])  # export
+    yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata)  # add metadata.yaml
     return f, None
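Combined with the `metadata` dict introduced in run() below, the Paddle export can be driven end to end from Python; a minimal sketch, assuming the repo root is the working directory:

    from export import run

    run(weights='yolov5s.pt', include=('paddle',))  # writes yolov5s_paddle_model/ beside the .pt
    # the directory also gains yolov5s.yaml, the {'stride': ..., 'names': ...} metadata saved by yaml_save above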
@@ -464,7 +479,7 @@ def run(
     fmts = tuple(export_formats()['Argument'][1:])  # --include arguments
     flags = [x in include for x in fmts]
     assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'
-    jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags  # export booleans
+    jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags  # export booleans
     file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)  # PyTorch weights

     # Load PyTorch model
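The unpacking above is positional, so the flag order must track the `Argument` column of export_formats(); a sketch of how `--include paddle` becomes the `paddle` boolean:

    from export import export_formats

    include = ('paddle',)
    fmts = tuple(export_formats()['Argument'][1:])  # ('torchscript', ..., 'tfjs', 'paddle')
    flags = [x in include for x in fmts]
    jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags
    assert paddle and sum(flags) == 1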
@@ -497,29 +512,28 @@ def run(
     if half and not coreml:
         im, model = im.half(), model.half()  # to FP16
     shape = tuple((y[0] if isinstance(y, tuple) else y).shape)  # model output shape
+    metadata = {'stride': int(max(model.stride)), 'names': model.names}  # model metadata
     LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")

     # Exports
-    f = [''] * 10  # exported filenames
+    f = [''] * len(fmts)  # exported filenames
     warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)  # suppress TracerWarning
-    if jit:
+    if jit:  # TorchScript
         f[0], _ = export_torchscript(model, im, file, optimize)
     if engine:  # TensorRT required before ONNX
         f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
     if onnx or xml:  # OpenVINO requires ONNX
         f[2], _ = export_onnx(model, im, file, opset, train, dynamic, simplify)
     if xml:  # OpenVINO
-        f[3], _ = export_openvino(model, file, half)
+        f[3], _ = export_openvino(file, metadata, half)
-    if coreml:
+    if coreml:  # CoreML
         f[4], _ = export_coreml(model, im, file, int8, half)
-    # TensorFlow Exports
-    if any((saved_model, pb, tflite, edgetpu, tfjs)):
+    if any((saved_model, pb, tflite, edgetpu, tfjs)):  # TensorFlow formats
         if int8 or edgetpu:  # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
             check_requirements('flatbuffers==1.12')  # required before `import tensorflow`
         assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
         assert not isinstance(model, ClassificationModel), 'ClassificationModel export to TF formats not yet supported.'
-        f[5], model = export_saved_model(model.cpu(),
+        f[5], s_model = export_saved_model(model.cpu(),
                                          im,
                                          file,
                                          dynamic,
@@ -531,13 +545,15 @@ def run(
                                          conf_thres=conf_thres,
                                          keras=keras)
         if pb or tfjs:  # pb prerequisite to tfjs
-            f[6], _ = export_pb(model, file)
+            f[6], _ = export_pb(s_model, file)
         if tflite or edgetpu:
-            f[7], _ = export_tflite(model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
+            f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
         if edgetpu:
             f[8], _ = export_edgetpu(file)
         if tfjs:
             f[9], _ = export_tfjs(file)
+    if paddle:  # PaddlePaddle
+        f[10], _ = export_paddle(model, im, file, metadata)

     # Finish
     f = [str(x) for x in f if x]  # filter out '' and None
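Since `f` is now sized by `len(fmts)`, each format keeps a fixed slot; the mapping implied by the if-chain above is:

    # f[i] -> exported artifact
    #  0 TorchScript   1 TensorRT   2 ONNX     3 OpenVINO   4 CoreML
    #  5 SavedModel    6 GraphDef   7 TFLite   8 Edge TPU   9 TF.js
    # 10 PaddlePaddle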
common.py
@@ -320,14 +320,16 @@ class DetectMultiBackend(nn.Module):
         #   TensorFlow GraphDef:    *.pb
         #   TensorFlow Lite:        *.tflite
         #   TensorFlow Edge TPU:    *_edgetpu.tflite
+        #   PaddlePaddle:           *_paddle_model
         from models.experimental import attempt_download, attempt_load  # scoped to avoid circular import

         super().__init__()
         w = str(weights[0] if isinstance(weights, list) else weights)
-        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self._model_type(w)  # get backend
+        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = self._model_type(w)  # type
         w = attempt_download(w)  # download if not local
         fp16 &= pt or jit or onnx or engine  # FP16
         stride = 32  # default stride
+        cuda = torch.cuda.is_available() and device.type != 'cpu'  # use CUDA
         if pt:  # PyTorch
             model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
@@ -351,7 +353,6 @@ class DetectMultiBackend(nn.Module):
             net = cv2.dnn.readNetFromONNX(w)
         elif onnx:  # ONNX Runtime
             LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
-            cuda = torch.cuda.is_available() and device.type != 'cpu'
             check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
             import onnxruntime
             providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
@@ -408,8 +409,7 @@ class DetectMultiBackend(nn.Module):
             LOGGER.info(f'Loading {w} for CoreML inference...')
             import coremltools as ct
             model = ct.models.MLModel(w)
-        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
-            if saved_model:  # SavedModel
+        elif saved_model:  # TF SavedModel
             LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
             import tensorflow as tf
             keras = False  # assume TF1 saved_model
@@ -423,7 +423,7 @@ class DetectMultiBackend(nn.Module):
                 ge = x.graph.as_graph_element
                 return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))

-            gd = tf.Graph().as_graph_def()  # graph_def
+            gd = tf.Graph().as_graph_def()  # TF GraphDef
             with open(w, 'rb') as f:
                 gd.ParseFromString(f.read())
             frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs="Identity:0")
@@ -433,21 +433,34 @@ class DetectMultiBackend(nn.Module):
             except ImportError:
                 import tensorflow as tf
                 Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
-            if edgetpu:  # Edge TPU https://coral.ai/software/#edgetpu-runtime
+            if edgetpu:  # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
                 LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
                 delegate = {
                     'Linux': 'libedgetpu.so.1',
                     'Darwin': 'libedgetpu.1.dylib',
                     'Windows': 'edgetpu.dll'}[platform.system()]
                 interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
-            else:  # Lite
+            else:  # TFLite
                 LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
                 interpreter = Interpreter(model_path=w)  # load TFLite model
             interpreter.allocate_tensors()  # allocate
             input_details = interpreter.get_input_details()  # inputs
             output_details = interpreter.get_output_details()  # outputs
-        elif tfjs:
+        elif tfjs:  # TF.js
             raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported')
+        elif paddle:  # PaddlePaddle
+            LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
+            check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
+            import paddle.inference as pdi
+            if not Path(w).is_file():  # if not *.pdmodel
+                w = next(Path(w).rglob('*.pdmodel'))  # get *.pdmodel file from *_paddle_model dir
+            weights = Path(w).with_suffix('.pdiparams')
+            config = pdi.Config(str(w), str(weights))
+            if cuda:
+                config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
+            predictor = pdi.create_predictor(config)
+            input_names = predictor.get_input_names()
+            input_handle = predictor.get_input_handle(input_names[0])
         else:
             raise NotImplementedError(f'ERROR: {w} is not a supported format')
@@ -502,6 +515,13 @@ class DetectMultiBackend(nn.Module):
                 else:
                     k = 'var_' + str(sorted(int(k.replace('var_', '')) for k in y)[-1])  # output key
                 y = y[k]  # output
+        elif self.paddle:  # PaddlePaddle
+            im = im.cpu().numpy().astype("float32")
+            self.input_handle.copy_from_cpu(im)
+            self.predictor.run()
+            output_names = self.predictor.get_output_names()
+            output_handle = self.predictor.get_output_handle(output_names[0])
+            y = output_handle.copy_to_cpu()
         else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
             im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
             if self.saved_model:  # SavedModel
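Outside DetectMultiBackend, the load-and-forward pair added above boils down to a few Paddle Inference calls; a standalone sketch, assuming a CPU build and the default 640x640 input (directory name per the export example earlier):

    from pathlib import Path

    import numpy as np
    import paddle.inference as pdi

    w = next(Path('yolov5s_paddle_model').rglob('*.pdmodel'))  # model topology
    config = pdi.Config(str(w), str(w.with_suffix('.pdiparams')))  # topology + weights
    predictor = pdi.create_predictor(config)

    im = np.zeros((1, 3, 640, 640), dtype='float32')  # BCHW input, as in forward()
    input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
    input_handle.copy_from_cpu(im)
    predictor.run()
    output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
    y = output_handle.copy_to_cpu()  # numpy predictions, e.g. shape (1, 25200, 85) for COCO yolov5s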
@@ -542,13 +562,13 @@ class DetectMultiBackend(nn.Module):
     def _model_type(p='path/to/model.pt'):
         # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
         from export import export_formats
-        suffixes = list(export_formats().Suffix) + ['.xml']  # export suffixes
-        check_suffix(p, suffixes)  # checks
+        sf = list(export_formats().Suffix) + ['.xml']  # export suffixes
+        check_suffix(p, sf)  # checks
         p = Path(p).name  # eliminate trailing separators
-        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, xml2 = (s in p for s in suffixes)
+        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, xml2 = (s in p for s in sf)
         xml |= xml2  # *_openvino_model or *.xml
         tflite &= not edgetpu  # *.tflite
-        return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs
+        return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle

     @staticmethod
     def _load_metadata(f=Path('path/to/meta.yaml')):
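A quick check that the suffix matching picks up the new format; `_model_type` only inspects the name, so the path need not exist:

    from models.common import DetectMultiBackend

    pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = \
        DetectMultiBackend._model_type('path/to/yolov5s_paddle_model')
    assert paddle and not any((pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs))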
benchmarks.py
@@ -61,7 +61,7 @@ def run(
     device = select_device(device)
     for i, (name, f, suffix, cpu, gpu) in export.export_formats().iterrows():  # index, (name, file, suffix, CPU, GPU)
         try:
-            assert i not in (9, 10), 'inference not supported'  # Edge TPU and TF.js are unsupported
+            assert i not in (9, 10, 11), 'inference not supported'  # Edge TPU, TF.js and Paddle are unsupported
             assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13'  # CoreML
             if 'cpu' in device.type:
                 assert cpu, 'inference not supported on CPU'
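The asserted indices are row positions in export_formats(); a one-liner to confirm which formats get skipped:

    print(list(export.export_formats().Format[[9, 10, 11]]))  # ['TensorFlow Edge TPU', 'TensorFlow.js', 'PaddlePaddle']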