Unverified 提交 c09964c2 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Update inference default to multi_label=False (#2252)

* Update inference default to multi_label=False * bug fix * Update plots.py * Update plots.py
上级 ab2da5ed
...@@ -7,7 +7,7 @@ import numpy as np ...@@ -7,7 +7,7 @@ import numpy as np
import requests import requests
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image, ImageDraw from PIL import Image
from utils.datasets import letterbox from utils.datasets import letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh
......
...@@ -106,7 +106,7 @@ def test(data, ...@@ -106,7 +106,7 @@ def test(data,
with torch.no_grad(): with torch.no_grad():
# Run model # Run model
t = time_synchronized() t = time_synchronized()
inf_out, train_out = model(img, augment=augment) # inference and training outputs out, train_out = model(img, augment=augment) # inference and training outputs
t0 += time_synchronized() - t t0 += time_synchronized() - t
# Compute loss # Compute loss
...@@ -117,11 +117,11 @@ def test(data, ...@@ -117,11 +117,11 @@ def test(data,
targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels
lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
t = time_synchronized() t = time_synchronized()
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb) out = non_max_suppression(out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb, multi_label=True)
t1 += time_synchronized() - t t1 += time_synchronized() - t
# Statistics per image # Statistics per image
for si, pred in enumerate(output): for si, pred in enumerate(out):
labels = targets[targets[:, 0] == si, 1:] labels = targets[targets[:, 0] == si, 1:]
nl = len(labels) nl = len(labels)
tcls = labels[:, 0].tolist() if nl else [] # target class tcls = labels[:, 0].tolist() if nl else [] # target class
...@@ -209,7 +209,7 @@ def test(data, ...@@ -209,7 +209,7 @@ def test(data,
f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels
Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start() Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()
f = save_dir / f'test_batch{batch_i}_pred.jpg' # predictions f = save_dir / f'test_batch{batch_i}_pred.jpg' # predictions
Thread(target=plot_images, args=(img, output_to_target(output), paths, f, names), daemon=True).start() Thread(target=plot_images, args=(img, output_to_target(out), paths, f, names), daemon=True).start()
# Compute statistics # Compute statistics
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
......
...@@ -390,11 +390,12 @@ def wh_iou(wh1, wh2): ...@@ -390,11 +390,12 @@ def wh_iou(wh1, wh2):
return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter) return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()): def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
"""Performs Non-Maximum Suppression (NMS) on inference results labels=()):
"""Runs Non-Maximum Suppression (NMS) on inference results
Returns: Returns:
detections with shape: nx6 (x1, y1, x2, y2, conf, cls) list of detections, on (n,6) tensor per image [xyxy, conf, cls]
""" """
nc = prediction.shape[2] - 5 # number of classes nc = prediction.shape[2] - 5 # number of classes
...@@ -406,7 +407,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non ...@@ -406,7 +407,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms() max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
time_limit = 10.0 # seconds to quit after time_limit = 10.0 # seconds to quit after
redundant = True # require redundant detections redundant = True # require redundant detections
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
merge = False # use merge-NMS merge = False # use merge-NMS
t = time.time() t = time.time()
......
...@@ -54,7 +54,7 @@ def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5): ...@@ -54,7 +54,7 @@ def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
return filtfilt(b, a, data) # forward-backward filter return filtfilt(b, a, data) # forward-backward filter
def plot_one_box(x, img, color=None, label=None, line_thickness=None): def plot_one_box(x, img, color=None, label=None, line_thickness=3):
# Plots one bounding box on image img # Plots one bounding box on image img
tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
color = color or [random.randint(0, 255) for _ in range(3)] color = color or [random.randint(0, 255) for _ in range(3)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论