Unverified 提交 66e5d794 authored 作者: Jiacong Fang's avatar Jiacong Fang 提交者: GitHub

Fix TF exports >= 2GB (#6292)

* Fix exporting saved_model: pb exceeds 2GB * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Replace TF v1.x API with TF v2.x API for saved_model export * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Clean up * Remove lambda in tf.function() * Revert "Remove lambda in tf.function()" to be compatible with TF v2.4 This reverts commit 46c7931f11dfdea6ae340c77287c35c30b9e0779. * Fix for pre-commit.ci * Cleanup1 * Cleanup2 * Backwards compatibility update * Update common.py * Update common.py * Cleanup3 Co-authored-by: 's avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 07221f15
...@@ -247,11 +247,11 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F ...@@ -247,11 +247,11 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
def export_saved_model(model, im, file, dynamic, def export_saved_model(model, im, file, dynamic,
tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45, tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
conf_thres=0.25, prefix=colorstr('TensorFlow SavedModel:')): conf_thres=0.25, keras=False, prefix=colorstr('TensorFlow SavedModel:')):
# YOLOv5 TensorFlow SavedModel export # YOLOv5 TensorFlow SavedModel export
try: try:
import tensorflow as tf import tensorflow as tf
from tensorflow import keras from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
from models.tf import TFDetect, TFModel from models.tf import TFDetect, TFModel
...@@ -262,13 +262,26 @@ def export_saved_model(model, im, file, dynamic, ...@@ -262,13 +262,26 @@ def export_saved_model(model, im, file, dynamic,
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz) tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for TensorFlow im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for TensorFlow
_ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size) inputs = tf.keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
keras_model = keras.Model(inputs=inputs, outputs=outputs) keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
keras_model.trainable = False keras_model.trainable = False
keras_model.summary() keras_model.summary()
keras_model.save(f, save_format='tf') if keras:
keras_model.save(f, save_format='tf')
else:
m = tf.function(lambda x: keras_model(x)) # full model
spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
m = m.get_concrete_function(spec)
frozen_func = convert_variables_to_constants_v2(m)
tfm = tf.Module()
tfm.__call__ = tf.function(lambda x: frozen_func(x), [spec])
tfm.__call__(im)
tf.saved_model.save(
tfm,
f,
options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if
check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions())
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return keras_model, f return keras_model, f
except Exception as e: except Exception as e:
......
...@@ -359,7 +359,8 @@ class DetectMultiBackend(nn.Module): ...@@ -359,7 +359,8 @@ class DetectMultiBackend(nn.Module):
if saved_model: # SavedModel if saved_model: # SavedModel
LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...') LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
import tensorflow as tf import tensorflow as tf
model = tf.keras.models.load_model(w) keras = False # assume TF1 saved_model
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...') LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
import tensorflow as tf import tensorflow as tf
...@@ -431,7 +432,7 @@ class DetectMultiBackend(nn.Module): ...@@ -431,7 +432,7 @@ class DetectMultiBackend(nn.Module):
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU) else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3) im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
if self.saved_model: # SavedModel if self.saved_model: # SavedModel
y = self.model(im, training=False).numpy() y = (self.model(im, training=False) if self.keras else self.model(im)[0]).numpy()
elif self.pb: # GraphDef elif self.pb: # GraphDef
y = self.frozen_func(x=self.tf.constant(im)).numpy() y = self.frozen_func(x=self.tf.constant(im)).numpy()
elif self.tflite: # Lite elif self.tflite: # Lite
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论