Unverified 提交 3f74cd9e authored 作者: Adrian Holovaty's avatar Adrian Holovaty 提交者: GitHub

Parameterize max_det + inference default at 1000 (#3215)

* Added max_det parameters in various places * 120 character line * PEP8 * 120 character line * Update inference default to 1000 instances * Update inference default to 1000 instances Co-authored-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 13a1c726
...@@ -68,7 +68,8 @@ def detect(opt): ...@@ -68,7 +68,8 @@ def detect(opt):
pred = model(img, augment=opt.augment)[0] pred = model(img, augment=opt.augment)[0]
# Apply NMS # Apply NMS
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms) pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, opt.classes, opt.agnostic_nms,
max_det=opt.max_det)
t2 = time_synchronized() t2 = time_synchronized()
# Apply Classifier # Apply Classifier
...@@ -153,6 +154,7 @@ if __name__ == '__main__': ...@@ -153,6 +154,7 @@ if __name__ == '__main__':
parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)') parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold') parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS') parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
parser.add_argument('--max-det', type=int, default=1000, help='maximum number of detections per image')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--view-img', action='store_true', help='display results') parser.add_argument('--view-img', action='store_true', help='display results')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
......
...@@ -215,12 +215,13 @@ class NMS(nn.Module): ...@@ -215,12 +215,13 @@ class NMS(nn.Module):
conf = 0.25 # confidence threshold conf = 0.25 # confidence threshold
iou = 0.45 # IoU threshold iou = 0.45 # IoU threshold
classes = None # (optional list) filter by class classes = None # (optional list) filter by class
max_det = 1000 # maximum number of detections per image
def __init__(self): def __init__(self):
super(NMS, self).__init__() super(NMS, self).__init__()
def forward(self, x): def forward(self, x):
return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) return non_max_suppression(x[0], self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det)
class AutoShape(nn.Module): class AutoShape(nn.Module):
...@@ -228,6 +229,7 @@ class AutoShape(nn.Module): ...@@ -228,6 +229,7 @@ class AutoShape(nn.Module):
conf = 0.25 # NMS confidence threshold conf = 0.25 # NMS confidence threshold
iou = 0.45 # NMS IoU threshold iou = 0.45 # NMS IoU threshold
classes = None # (optional list) filter by class classes = None # (optional list) filter by class
max_det = 1000 # maximum number of detections per image
def __init__(self, model): def __init__(self, model):
super(AutoShape, self).__init__() super(AutoShape, self).__init__()
...@@ -285,7 +287,7 @@ class AutoShape(nn.Module): ...@@ -285,7 +287,7 @@ class AutoShape(nn.Module):
t.append(time_synchronized()) t.append(time_synchronized())
# Post-process # Post-process
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS y = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det) # NMS
for i in range(n): for i in range(n):
scale_coords(shape1, y[i][:, :4], shape0[i]) scale_coords(shape1, y[i][:, :4], shape0[i])
......
...@@ -482,7 +482,7 @@ def wh_iou(wh1, wh2): ...@@ -482,7 +482,7 @@ def wh_iou(wh1, wh2):
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False, def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
labels=()): labels=(), max_det=300):
"""Runs Non-Maximum Suppression (NMS) on inference results """Runs Non-Maximum Suppression (NMS) on inference results
Returns: Returns:
...@@ -498,7 +498,6 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non ...@@ -498,7 +498,6 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
# Settings # Settings
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
max_det = 300 # maximum number of detections per image
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms() max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
time_limit = 10.0 # seconds to quit after time_limit = 10.0 # seconds to quit after
redundant = True # require redundant detections redundant = True # require redundant detections
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论