Unverified commit a297efc3, authored by Raffaele Galliera, committed by GitHub

Edge TPU inference fix (#6686)

* refactor: use edgetpu flag
* fix: remove bitwise-and assignment to tflite
* Cleanup and fix tflite
* Cleanup

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Parent 03653790
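The core of this fix is deciding the backend from the weights filename once, up front, instead of re-testing `'edgetpu' in w.lower()` at load time. A minimal sketch of how such a flag derivation could look, assuming `w` is the weights path; the helper name and the `_web_model` suffix for TF.js are assumptions for illustration, not part of this diff:

```python
from pathlib import Path

def backend_flags(w):
    # Hypothetical helper: derive backend booleans from the weights path.
    # Edge TPU exports keep the .tflite suffix but end in '_edgetpu.tflite',
    # so plain-TFLite detection must explicitly exclude them.
    w = str(w)
    suffix = Path(w).suffix.lower()
    edgetpu = w.lower().endswith('_edgetpu.tflite')  # TensorFlow Edge TPU
    tflite = suffix == '.tflite' and not edgetpu     # TensorFlow Lite (CPU)
    pb = suffix == '.pb'                             # TensorFlow GraphDef
    tfjs = w.lower().endswith('_web_model')          # TF.js (assumed suffix)
    return pb, tflite, edgetpu, tfjs
```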
@@ -279,17 +279,17 @@ class DetectMultiBackend(nn.Module):
     # YOLOv5 MultiBackend class for python inference on various backends
     def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
         # Usage:
         #   PyTorch:      weights = *.pt
         #   TorchScript:            *.torchscript
-        #   CoreML:                 *.mlmodel
-        #   OpenVINO:               *.xml
-        #   TensorFlow:             *_saved_model
-        #   TensorFlow:             *.pb
-        #   TensorFlow Lite:        *.tflite
-        #   TensorFlow Edge TPU:    *_edgetpu.tflite
-        #   ONNX Runtime:           *.onnx
-        #   OpenCV DNN:             *.onnx with dnn=True
-        #   TensorRT:               *.engine
+        #   ONNX Runtime:           *.onnx
+        #   ONNX OpenCV DNN:        *.onnx with --dnn
+        #   OpenVINO:               *.xml
+        #   CoreML:                 *.mlmodel
+        #   TensorRT:               *.engine
+        #   TensorFlow SavedModel:  *_saved_model
+        #   TensorFlow GraphDef:    *.pb
+        #   TensorFlow Lite:        *.tflite
+        #   TensorFlow Edge TPU:    *_edgetpu.tflite
         from models.experimental import attempt_download, attempt_load  # scoped to avoid circular import
         super().__init__()
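For context, each entry in the reordered usage table above selects a backend purely from the weights path. A few illustrative constructor calls; the filenames are examples, and the import assumes the class lives in models/common.py as in the YOLOv5 repo:

```python
from models.common import DetectMultiBackend

model = DetectMultiBackend(weights='yolov5s.pt')                   # PyTorch
model = DetectMultiBackend(weights='yolov5s.onnx', dnn=True)       # ONNX via OpenCV DNN
model = DetectMultiBackend(weights='yolov5s-int8_edgetpu.tflite')  # TensorFlow Edge TPU
```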
@@ -367,19 +367,19 @@ class DetectMultiBackend(nn.Module):
             def wrap_frozen_graph(gd, inputs, outputs):
                 x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
-                return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
-                               tf.nest.map_structure(x.graph.as_graph_element, outputs))
-            graph_def = tf.Graph().as_graph_def()
-            graph_def.ParseFromString(open(w, 'rb').read())
-            frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
-        elif tflite:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
+                ge = x.graph.as_graph_element
+                return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
+            gd = tf.Graph().as_graph_def()  # graph_def
+            gd.ParseFromString(open(w, 'rb').read())
+            frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs="Identity:0")
+        elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
             try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
                 from tflite_runtime.interpreter import Interpreter, load_delegate
             except ImportError:
                 import tensorflow as tf
                 Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
-            if 'edgetpu' in w.lower():  # Edge TPU https://coral.ai/software/#edgetpu-runtime
+            if edgetpu:  # Edge TPU https://coral.ai/software/#edgetpu-runtime
                 LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
                 delegate = {'Linux': 'libedgetpu.so.1',
                             'Darwin': 'libedgetpu.1.dylib',
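The hunk is cut off inside the `delegate` dict. Below is a self-contained sketch of the delegate-loading pattern it implements; the `'Windows': 'edgetpu.dll'` entry and the `make_interpreter` wrapper are assumptions drawn from the Coral docs linked in the code, not lines of this diff:

```python
import platform

try:  # prefer the lightweight tflite_runtime package if installed
    from tflite_runtime.interpreter import Interpreter, load_delegate
except ImportError:
    import tensorflow as tf
    Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate

def make_interpreter(w, edgetpu=False):
    # Sketch: attach the platform-specific Edge TPU runtime library as a
    # delegate; otherwise fall back to a plain CPU TFLite interpreter.
    if edgetpu:
        lib = {'Linux': 'libedgetpu.so.1',
               'Darwin': 'libedgetpu.1.dylib',
               'Windows': 'edgetpu.dll'}[platform.system()]
        return Interpreter(model_path=w, experimental_delegates=[load_delegate(lib)])
    return Interpreter(model_path=w)
```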
@@ -391,6 +391,8 @@ class DetectMultiBackend(nn.Module):
             interpreter.allocate_tensors()  # allocate
             input_details = interpreter.get_input_details()  # inputs
             output_details = interpreter.get_output_details()  # outputs
+        elif tfjs:
+            raise Exception('ERROR: YOLOv5 TF.js inference is not supported')
         self.__dict__.update(locals())  # assign all variables to self

     def forward(self, im, augment=False, visualize=False, val=False):
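Once `allocate_tensors()` has run, the cached input/output details drive the standard TFLite set/invoke/get cycle used in `forward` below. A minimal sketch; the function name is hypothetical, and `im` is assumed to be an array matching the model's input spec:

```python
def tflite_forward(interpreter, input_details, output_details, im):
    # Sketch: feed one batch through an already-allocated TFLite interpreter.
    interpreter.set_tensor(input_details[0]['index'], im)
    interpreter.invoke()
    return interpreter.get_tensor(output_details[0]['index'])
```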
@@ -436,7 +438,7 @@ class DetectMultiBackend(nn.Module):
             y = (self.model(im, training=False) if self.keras else self.model(im)[0]).numpy()
         elif self.pb:  # GraphDef
             y = self.frozen_func(x=self.tf.constant(im)).numpy()
-        elif self.tflite:  # Lite
+        else:  # Lite or Edge TPU
             input, output = self.input_details[0], self.output_details[0]
             int8 = input['dtype'] == np.uint8  # is TFLite quantized uint8 model
             if int8:
...
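The final hunk is truncated at the `if int8:` branch. The usual pattern from here is to quantize the float input with the input tensor's (scale, zero_point) pair and dequantize the output the same way; a sketch under that assumption rather than the elided diff lines:

```python
import numpy as np

def quantized_forward(interpreter, input, output, im):
    # Sketch: uint8-quantized TFLite round-trip using the (scale, zero_point)
    # pair stored in each tensor's 'quantization' detail.
    scale, zero_point = input['quantization']
    im = (im / scale + zero_point).astype(np.uint8)     # float32 -> uint8
    interpreter.set_tensor(input['index'], im)
    interpreter.invoke()
    y = interpreter.get_tensor(output['index'])
    scale, zero_point = output['quantization']
    return (y.astype(np.float32) - zero_point) * scale  # uint8 -> float32
```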