Unverified 提交 53349dac，作者: Phil2020，提交者: GitHub

Scope TF imports in `DetectMultiBackend()` (#5792)

* tensorflow or tflite exclusively as interpreter As per bug report https://github.com/ultralytics/yolov5/issues/5709 I think there should be only one attempt to assign interpreter, and it appears tflite is only ever needed for the case of edgetpu model. * Scope imports * Nested definition line fix * Update common.py Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
上级 f2ca30a4
...@@ -337,19 +337,21 @@ class DetectMultiBackend(nn.Module): ...@@ -337,19 +337,21 @@ class DetectMultiBackend(nn.Module):
context = model.create_execution_context() context = model.create_execution_context()
batch_size = bindings['images'].shape[0] batch_size = bindings['images'].shape[0]
else: # TensorFlow model (TFLite, pb, saved_model) else: # TensorFlow model (TFLite, pb, saved_model)
import tensorflow as tf
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...')
import tensorflow as tf
def wrap_frozen_graph(gd, inputs, outputs): def wrap_frozen_graph(gd, inputs, outputs):
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs), return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
tf.nest.map_structure(x.graph.as_graph_element, outputs)) tf.nest.map_structure(x.graph.as_graph_element, outputs))
LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...')
graph_def = tf.Graph().as_graph_def() graph_def = tf.Graph().as_graph_def()
graph_def.ParseFromString(open(w, 'rb').read()) graph_def.ParseFromString(open(w, 'rb').read())
frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0") frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
elif saved_model: elif saved_model:
LOGGER.info(f'Loading {w} for TensorFlow saved_model inference...') LOGGER.info(f'Loading {w} for TensorFlow saved_model inference...')
import tensorflow as tf
model = tf.keras.models.load_model(w) model = tf.keras.models.load_model(w)
elif tflite: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python elif tflite: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
if 'edgetpu' in w.lower(): if 'edgetpu' in w.lower():
...@@ -361,6 +363,7 @@ class DetectMultiBackend(nn.Module): ...@@ -361,6 +363,7 @@ class DetectMultiBackend(nn.Module):
interpreter = tfli.Interpreter(model_path=w, experimental_delegates=[tfli.load_delegate(delegate)]) interpreter = tfli.Interpreter(model_path=w, experimental_delegates=[tfli.load_delegate(delegate)])
else: else:
LOGGER.info(f'Loading {w} for TensorFlow Lite inference...') LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
import tensorflow as tf
interpreter = tf.lite.Interpreter(model_path=w) # load TFLite model interpreter = tf.lite.Interpreter(model_path=w) # load TFLite model
interpreter.allocate_tensors() # allocate interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs input_details = interpreter.get_input_details() # inputs
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论