Unverified 提交 2c631754 authored 作者: Diego Montes's avatar Diego Montes 提交者: GitHub

Add nms and agnostic nms to export.py (#5938)

* add nms and agnostic nms to export.py * fix agnostic implies nms * reorder args to group TF args * PEP8 120 char Co-authored-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 a42af30d
...@@ -328,6 +328,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' ...@@ -328,6 +328,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
opset=14, # ONNX: opset version opset=14, # ONNX: opset version
verbose=False, # TensorRT: verbose log verbose=False, # TensorRT: verbose log
workspace=4, # TensorRT: workspace size (GB) workspace=4, # TensorRT: workspace size (GB)
nms=False, # TF: add NMS to model
agnostic_nms=False, # TF: add agnostic NMS to model
topk_per_class=100, # TF.js NMS: topk per class to keep topk_per_class=100, # TF.js NMS: topk per class to keep
topk_all=100, # TF.js NMS: topk for all classes to keep topk_all=100, # TF.js NMS: topk for all classes to keep
iou_thres=0.45, # TF.js NMS: IoU threshold iou_thres=0.45, # TF.js NMS: IoU threshold
...@@ -381,9 +383,9 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' ...@@ -381,9 +383,9 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
if any(tf_exports): if any(tf_exports):
pb, tflite, tfjs = tf_exports[1:] pb, tflite, tfjs = tf_exports[1:]
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.' assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
model = export_saved_model(model, im, file, dynamic, tf_nms=tfjs, agnostic_nms=tfjs, model = export_saved_model(model, im, file, dynamic, tf_nms=nms or agnostic_nms or tfjs,
topk_per_class=topk_per_class, topk_all=topk_all, conf_thres=conf_thres, agnostic_nms=agnostic_nms or tfjs, topk_per_class=topk_per_class, topk_all=topk_all,
iou_thres=iou_thres) # keras model conf_thres=conf_thres, iou_thres=iou_thres) # keras model
if pb or tfjs: # pb prerequisite to tfjs if pb or tfjs: # pb prerequisite to tfjs
export_pb(model, im, file) export_pb(model, im, file)
if tflite: if tflite:
...@@ -414,6 +416,8 @@ def parse_opt(): ...@@ -414,6 +416,8 @@ def parse_opt():
parser.add_argument('--opset', type=int, default=14, help='ONNX: opset version') parser.add_argument('--opset', type=int, default=14, help='ONNX: opset version')
parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log') parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)') parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
parser.add_argument('--nms', action='store_true', help='TF: add NMS to model')
parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model')
parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep') parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep') parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold') parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论