Unverified commit afb98605 authored by Glenn Jocher, committed by GitHub

Add Paddle exports to benchmarks (#9459)

* Add Paddle exports to benchmarks
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* Update plots.py
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* Update common.py
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Parent 06083740
@@ -65,7 +65,7 @@ def run(
     model_type = type(attempt_load(weights, fuse=False))  # DetectionModel, SegmentationModel, etc.
     for i, (name, f, suffix, cpu, gpu) in export.export_formats().iterrows():  # index, (name, file, suffix, CPU, GPU)
         try:
-            assert i not in (9, 10, 11), 'inference not supported'  # Edge TPU, TF.js and Paddle are unsupported
+            assert i not in (9, 10), 'inference not supported'  # Edge TPU and TF.js are unsupported
             assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13'  # CoreML
             if 'cpu' in device.type:
                 assert cpu, 'inference not supported on CPU'
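Note on the benchmarks change: the indices come from the export.export_formats() table of (name, file, suffix, CPU, GPU) rows, and dropping index 11 from the skip list lets PaddlePaddle exports run inference during benchmarking. Below is a minimal, self-contained sketch of that skip logic; the DataFrame rows, index positions, and CPU/GPU flags are illustrative stand-ins, not copied from the real export_formats() table.

```python
import pandas as pd

# Illustrative subset of an export-formats table; index positions mirror the
# comment in the diff (9 = Edge TPU, 10 = TF.js, 11 = PaddlePaddle).
formats = pd.DataFrame(
    [['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
     ['TensorFlow.js', 'tfjs', '_web_model', False, False],
     ['PaddlePaddle', 'paddle', '_paddle_model', True, True]],
    columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'],
    index=[9, 10, 11])

for i, (name, f, suffix, cpu, gpu) in formats.iterrows():
    try:
        assert i not in (9, 10), 'inference not supported'  # Paddle (11) is no longer skipped
        print(f'{i}: {name} would be benchmarked')
    except AssertionError as e:
        print(f'{i}: {name} skipped ({e})')
```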
@@ -460,8 +460,8 @@ class DetectMultiBackend(nn.Module):
             if cuda:
                 config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
             predictor = pdi.create_predictor(config)
-            input_names = predictor.get_input_names()
-            input_handle = predictor.get_input_handle(input_names[0])
+            input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
+            output_names = predictor.get_output_names()
         else:
             raise NotImplementedError(f'ERROR: {w} is not a supported format')
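For context on the DetectMultiBackend change: Paddle Inference resolves tensors by name, so the constructor now looks up the single input handle directly and keeps the list of output names for later use in forward(). A hedged setup sketch follows, assuming an exported model.pdmodel / model.pdiparams pair (those file names are illustrative, not taken from the diff).

```python
import paddle.inference as pdi

# Build a predictor from an exported Paddle model (paths are placeholders)
config = pdi.Config('model.pdmodel', 'model.pdiparams')
# config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)  # optional, CUDA only
predictor = pdi.create_predictor(config)

# One input handle, resolved by name, plus the output names saved for inference
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
output_names = predictor.get_output_names()
```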
@@ -517,12 +517,10 @@ class DetectMultiBackend(nn.Module):
             k = 'var_' + str(sorted(int(k.replace('var_', '')) for k in y)[-1])  # output key
             y = y[k]  # output
         elif self.paddle:  # PaddlePaddle
-            im = im.cpu().numpy().astype("float32")
+            im = im.cpu().numpy().astype(np.float32)
             self.input_handle.copy_from_cpu(im)
             self.predictor.run()
-            output_names = self.predictor.get_output_names()
-            output_handle = self.predictor.get_output_handle(output_names[0])
-            y = output_handle.copy_to_cpu()
+            y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
         else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
             im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
             if self.saved_model:  # SavedModel
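The forward-pass side of the same change: the input batch is cast with np.float32 instead of a string dtype, and y is now gathered as a list of every named output tensor rather than only the first one. A self-contained sketch of that per-call flow; the helper name and the 640x640 input shape are assumptions for illustration.

```python
import numpy as np


def paddle_forward(predictor, input_handle, output_names, im):
    """Run one Paddle Inference pass on a float32 NCHW batch and return all outputs."""
    input_handle.copy_from_cpu(im)   # copy the batch into the input tensor
    predictor.run()                  # execute the graph
    # Collect every named output; the result is a list of NumPy arrays, not a single array
    return [predictor.get_output_handle(x).copy_to_cpu() for x in output_names]


# Example call (predictor, input_handle, output_names as in the setup sketch above):
# y = paddle_forward(predictor, input_handle, output_names,
#                    np.zeros((1, 3, 640, 640), dtype=np.float32))
```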
@@ -99,9 +99,9 @@ def plot_images_and_masks(images, targets, masks, paths=None, fname='images.jpg'
                         if mh != h or mw != w:
                             mask = image_masks[j].astype(np.uint8)
                             mask = cv2.resize(mask, (w, h))
-                            mask = mask.astype(np.bool)
+                            mask = mask.astype(bool)
                         else:
-                            mask = image_masks[j].astype(np.bool)
+                            mask = image_masks[j].astype(bool)
                         with contextlib.suppress(Exception):
                             im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
         annotator.fromarray(im)
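On the plots.py change: np.bool was deprecated in NumPy 1.20 and removed in 1.24, so plain bool is the drop-in replacement; the resulting boolean mask then selects the pixels that get blended with the class color at 60% opacity. A tiny NumPy sketch of that blend (the shapes and color below are made up for illustration):

```python
import numpy as np

im = np.full((4, 4, 3), 255.0)        # dummy image tile, all white
mask = np.zeros((4, 4), dtype=bool)   # plain bool, since np.bool is gone in NumPy >= 1.24
mask[1:3, 1:3] = True                 # pretend these pixels belong to the instance
color = np.array([255, 56, 56])       # illustrative class color

# Same blend as the diff: keep 40% of the image, add 60% of the class color
im[mask] = im[mask] * 0.4 + color * 0.6
print(im[1, 1])  # masked pixel pulled toward the color; unmasked pixels stay white
```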