Unverified 提交 b2adc7c3 authored 作者: Louis Combaldieu's avatar Louis Combaldieu 提交者: GitHub

Fix export for 1-channel images (#6780)

Export failed for 1-channel input shape, 1-liner fix
上级 cea994b3
...@@ -260,9 +260,9 @@ def export_saved_model(model, im, file, dynamic, ...@@ -260,9 +260,9 @@ def export_saved_model(model, im, file, dynamic,
batch_size, ch, *imgsz = list(im.shape) # BCHW batch_size, ch, *imgsz = list(im.shape) # BCHW
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz) tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for TensorFlow im = tf.zeros((batch_size, *imgsz, ch)) # BHWC order for TensorFlow
_ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
inputs = tf.keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size) inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size)
outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs) keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
keras_model.trainable = False keras_model.trainable = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论