Unverified commit d40cd0d4, authored by Glenn Jocher, committed by GitHub

Improved `Profile()` inference timing (#9024)

* Improved `Profile()` class
* Update predict.py
* Update val.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update val.py
* Update AutoShape

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent: c0e7a776
classify/predict.py:

```diff
@@ -22,8 +22,8 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
 from models.common import DetectMultiBackend
 from utils.augmentations import classify_transforms
 from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages
-from utils.general import LOGGER, check_file, check_requirements, colorstr, increment_path, print_args
-from utils.torch_utils import select_device, smart_inference_mode, time_sync
+from utils.general import LOGGER, Profile, check_file, check_requirements, colorstr, increment_path, print_args
+from utils.torch_utils import select_device, smart_inference_mode


 @smart_inference_mode()
@@ -44,7 +44,7 @@ def run(
     if is_url and is_file:
         source = check_file(source)  # download

-    seen, dt = 1, [0.0, 0.0, 0.0]
+    dt = Profile(), Profile(), Profile()
     device = select_device(device)

     # Directories
@@ -55,30 +55,27 @@ def run(
     model = DetectMultiBackend(weights, device=device, dnn=dnn, fp16=half)
     model.warmup(imgsz=(1, 3, imgsz, imgsz))  # warmup
     dataset = LoadImages(source, img_size=imgsz, transforms=classify_transforms(imgsz))
-    for path, im, im0s, vid_cap, s in dataset:
+    for seen, (path, im, im0s, vid_cap, s) in enumerate(dataset):
         # Image
-        t1 = time_sync()
-        im = im.unsqueeze(0).to(device)
-        im = im.half() if model.fp16 else im.float()
-        t2 = time_sync()
-        dt[0] += t2 - t1
+        with dt[0]:
+            im = im.unsqueeze(0).to(device)
+            im = im.half() if model.fp16 else im.float()

         # Inference
-        results = model(im)
-        t3 = time_sync()
-        dt[1] += t3 - t2
+        with dt[1]:
+            results = model(im)

         # Post-process
-        p = F.softmax(results, dim=1)  # probabilities
-        i = p.argsort(1, descending=True)[:, :5].squeeze().tolist()  # top 5 indices
-        dt[2] += time_sync() - t3
+        with dt[2]:
+            p = F.softmax(results, dim=1)  # probabilities
+            i = p.argsort(1, descending=True)[:, :5].squeeze().tolist()  # top 5 indices
         # if save:
         #    imshow_cls(im, f=save_dir / Path(path).name, verbose=True)
-        seen += 1
-        LOGGER.info(f"{s}{imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i)}")
+        LOGGER.info(
+            f"{s}{imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i)}, {dt[1].dt * 1E3:.1f}ms")

     # Print results
-    t = tuple(x / seen * 1E3 for x in dt)  # speeds per image
+    t = tuple(x.t / (seen + 1) * 1E3 for x in dt)  # speeds per image
     shape = (1, 3, imgsz, imgsz)
     LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
     LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
```
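The changes above replace manual `time_sync()` bookkeeping with three `Profile` accumulators, one per stage, and the hand-maintained `seen` counter with `enumerate()`. Since `enumerate()` is zero-based, the final index equals the image count minus one, hence the `seen + 1` divisor in the summary. A minimal sketch of the idiom, using `time.sleep` as a stand-in for real work and the `Profile` class defined in utils/general.py later in this diff:

```python
import time

from utils.general import Profile

dt = Profile(), Profile(), Profile()  # pre-process, inference, post-process accumulators
for seen, _ in enumerate(range(3)):  # stand-in for iterating a LoadImages dataset
    with dt[0]:
        time.sleep(0.01)  # pretend pre-processing
    with dt[1]:
        time.sleep(0.05)  # pretend inference
    with dt[2]:
        time.sleep(0.02)  # pretend post-processing

t = tuple(x.t / (seen + 1) * 1E3 for x in dt)  # average ms per image
print('Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image' % t)
```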
classify/val.py:

```diff
@@ -23,8 +23,8 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
 from models.common import DetectMultiBackend
 from utils.dataloaders import create_classification_dataloader
-from utils.general import LOGGER, check_img_size, check_requirements, colorstr, increment_path, print_args
-from utils.torch_utils import select_device, smart_inference_mode, time_sync
+from utils.general import LOGGER, Profile, check_img_size, check_requirements, colorstr, increment_path, print_args
+from utils.torch_utils import select_device, smart_inference_mode


 @smart_inference_mode()
@@ -83,27 +83,24 @@ def run(
                                                       workers=workers)

     model.eval()
-    pred, targets, loss, dt = [], [], 0, [0.0, 0.0, 0.0]
+    pred, targets, loss, dt = [], [], 0, (Profile(), Profile(), Profile())
     n = len(dataloader)  # number of batches
     action = 'validating' if dataloader.dataset.root.stem == 'val' else 'testing'
     desc = f"{pbar.desc[:-36]}{action:>36}" if pbar else f"{action}"
     bar = tqdm(dataloader, desc, n, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}', position=0)
     with torch.cuda.amp.autocast(enabled=device.type != 'cpu'):
         for images, labels in bar:
-            t1 = time_sync()
-            images, labels = images.to(device, non_blocking=True), labels.to(device)
-            t2 = time_sync()
-            dt[0] += t2 - t1
+            with dt[0]:
+                images, labels = images.to(device, non_blocking=True), labels.to(device)

-            y = model(images)
-            t3 = time_sync()
-            dt[1] += t3 - t2
+            with dt[1]:
+                y = model(images)

-            pred.append(y.argsort(1, descending=True)[:, :5])
-            targets.append(labels)
-            if criterion:
-                loss += criterion(y, labels)
-            dt[2] += time_sync() - t3
+            with dt[2]:
+                pred.append(y.argsort(1, descending=True)[:, :5])
+                targets.append(labels)
+                if criterion:
+                    loss += criterion(y, labels)

     loss /= n
     pred, targets = torch.cat(pred), torch.cat(targets)
@@ -122,7 +119,7 @@ def run(
             LOGGER.info(f"{c:>24}{aci.shape[0]:>12}{top1i:>12.3g}{top5i:>12.3g}")

     # Print results
-    t = tuple(x / len(dataloader.dataset.samples) * 1E3 for x in dt)  # speeds per image
+    t = tuple(x.t / len(dataloader.dataset.samples) * 1E3 for x in dt)  # speeds per image
     shape = (1, 3, imgsz, imgsz)
     LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
     LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
```
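Validation here is batched, so each `with dt[i]:` interval covers a whole batch; per-image speeds come from dividing the accumulated `x.t` by the number of images (`len(dataloader.dataset.samples)`), not the number of batches. A toy illustration of that normalization, with hypothetical batch counts and timings:

```python
import time

from utils.general import Profile

dt = Profile(), Profile(), Profile()
batch_size, num_batches = 32, 10
num_images = batch_size * num_batches  # plays the role of len(dataloader.dataset.samples)

for _ in range(num_batches):
    with dt[1]:  # one timed interval per batch, not per image
        time.sleep(0.032)  # pretend a 32-image forward pass takes 32 ms

t = tuple(x.t / num_images * 1E3 for x in dt)  # per-image ms
print(f'{t[1]:.1f}ms inference per image')  # ~1.0ms
```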
detect.py:

```diff
@@ -41,10 +41,10 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
 from models.common import DetectMultiBackend
 from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
-from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
+from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
                            increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
 from utils.plots import Annotator, colors, save_one_box
-from utils.torch_utils import select_device, smart_inference_mode, time_sync
+from utils.torch_utils import select_device, smart_inference_mode


 @smart_inference_mode()
@@ -107,26 +107,23 @@ def run(
     # Run inference
     model.warmup(imgsz=(1 if pt else bs, 3, *imgsz))  # warmup
-    seen, windows, dt = 0, [], [0.0, 0.0, 0.0]
+    seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
     for path, im, im0s, vid_cap, s in dataset:
-        t1 = time_sync()
-        im = torch.from_numpy(im).to(device)
-        im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
-        im /= 255  # 0 - 255 to 0.0 - 1.0
-        if len(im.shape) == 3:
-            im = im[None]  # expand for batch dim
-        t2 = time_sync()
-        dt[0] += t2 - t1
+        with dt[0]:
+            im = torch.from_numpy(im).to(device)
+            im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
+            im /= 255  # 0 - 255 to 0.0 - 1.0
+            if len(im.shape) == 3:
+                im = im[None]  # expand for batch dim

         # Inference
-        visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
-        pred = model(im, augment=augment, visualize=visualize)
-        t3 = time_sync()
-        dt[1] += t3 - t2
+        with dt[1]:
+            visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
+            pred = model(im, augment=augment, visualize=visualize)

         # NMS
-        pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
-        dt[2] += time_sync() - t3
+        with dt[2]:
+            pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

         # Second-stage classifier (optional)
         # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
@@ -201,10 +198,10 @@ def run(
                     vid_writer[i].write(im0)

         # Print time (inference-only)
-        LOGGER.info(f'{s}Done. ({t3 - t2:.3f}s)')
+        LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")

     # Print results
-    t = tuple(x / seen * 1E3 for x in dt)  # speeds per image
+    t = tuple(x.t / seen * 1E3 for x in dt)  # speeds per image
     LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
     if save_txt or save_img:
         s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
```
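Note the two different attributes in play here: `dt[1].dt` is the duration of the most recent `with` block and feeds the per-image log line, while `dt[1].t` is the running total that the final summary averages over `seen` images. A small sketch of the distinction, with made-up timings:

```python
import time

from utils.general import Profile

p = Profile()
for ms in (10, 30, 20):  # hypothetical per-image inference times
    with p:
        time.sleep(ms / 1E3)
    print(f'last: {p.dt * 1E3:.0f}ms')  # per-iteration time, like the detect.py log line
print(f'total: {p.t * 1E3:.0f}ms, mean: {p.t / 3 * 1E3:.0f}ms')  # accumulated, like the summary
```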
(Diff collapsed. Per the commit message, this is the `AutoShape` update in models/common.py.)
utils/general.py:

```diff
@@ -141,16 +141,26 @@ CONFIG_DIR = user_config_dir()  # Ultralytics settings dir

 class Profile(contextlib.ContextDecorator):
-    # Usage: @Profile() decorator or 'with Profile():' context manager
+    # YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
+    def __init__(self, t=0.0):
+        self.t = t
+        self.cuda = torch.cuda.is_available()
+
     def __enter__(self):
-        self.start = time.time()
+        self.start = self.time()

     def __exit__(self, type, value, traceback):
-        print(f'Profile results: {time.time() - self.start:.5f}s')
+        self.dt = self.time() - self.start  # delta-time
+        self.t += self.dt  # accumulate dt
+
+    def time(self):
+        if self.cuda:
+            torch.cuda.synchronize()
+        return time.time()


 class Timeout(contextlib.ContextDecorator):
-    # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
+    # YOLOv5 Timeout class. Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
     def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
         self.seconds = int(seconds)
         self.timeout_message = timeout_msg
```
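Two details of the new class are worth spelling out. First, `Profile.time()` calls `torch.cuda.synchronize()` before reading the wall clock when CUDA is available: GPU kernels launch asynchronously, so without that barrier `time.time()` would return before queued work finished and the interval would undercount. Second, because `Profile` subclasses `contextlib.ContextDecorator`, a single instance can also wrap a function and keep accumulating across calls. A sketch of both uses; the `timed_matmul` function and tensor sizes are made up for illustration:

```python
import torch

from utils.general import Profile

p = Profile()

@p  # decorator form: __enter__/__exit__ run around every call, accumulating into p.t
def timed_matmul():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    a = torch.randn(1024, 1024, device=device)
    return a @ a  # launches asynchronously on CUDA; Profile.time() synchronizes first

for _ in range(5):
    timed_matmul()
print(f'5 calls: {p.t * 1E3:.1f}ms total, last {p.dt * 1E3:.1f}ms')
```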
val.py:

```diff
@@ -37,7 +37,7 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
 from models.common import DetectMultiBackend
 from utils.callbacks import Callbacks
 from utils.dataloaders import create_dataloader
-from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_yaml,
+from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_yaml,
                            coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args,
                            scale_coords, xywh2xyxy, xyxy2xywh)
 from utils.metrics import ConfusionMatrix, ap_per_class, box_iou
@@ -187,26 +187,24 @@ def run(
     names = dict(enumerate(names))
     class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
     s = ('%20s' + '%11s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
-    dt, p, r, f1, mp, mr, map50, map = [0.0, 0.0, 0.0], 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
+    dt, p, r, f1, mp, mr, map50, map = (Profile(), Profile(), Profile()), 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
     loss = torch.zeros(3, device=device)
     jdict, stats, ap, ap_class = [], [], [], []
     callbacks.run('on_val_start')
     pbar = tqdm(dataloader, desc=s, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')  # progress bar
     for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
         callbacks.run('on_val_batch_start')
-        t1 = time_sync()
-        if cuda:
-            im = im.to(device, non_blocking=True)
-            targets = targets.to(device)
-        im = im.half() if half else im.float()  # uint8 to fp16/32
-        im /= 255  # 0 - 255 to 0.0 - 1.0
-        nb, _, height, width = im.shape  # batch size, channels, height, width
-        t2 = time_sync()
-        dt[0] += t2 - t1
+        with dt[0]:
+            if cuda:
+                im = im.to(device, non_blocking=True)
+                targets = targets.to(device)
+            im = im.half() if half else im.float()  # uint8 to fp16/32
+            im /= 255  # 0 - 255 to 0.0 - 1.0
+            nb, _, height, width = im.shape  # batch size, channels, height, width

         # Inference
-        out, train_out = model(im) if training else model(im, augment=augment, val=True)  # inference, loss outputs
-        dt[1] += time_sync() - t2
+        with dt[1]:
+            out, train_out = model(im) if training else model(im, augment=augment, val=True)  # inference, loss outputs

         # Loss
         if compute_loss:
@@ -215,9 +213,8 @@ def run(
         # NMS
         targets[:, 2:] *= torch.tensor((width, height, width, height), device=device)  # to pixels
         lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else []  # for autolabelling
-        t3 = time_sync()
-        out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
-        dt[2] += time_sync() - t3
+        with dt[2]:
+            out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)

         # Metrics
         for si, pred in enumerate(out):
@@ -284,7 +281,7 @@ def run(
             LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))

     # Print speeds
-    t = tuple(x / seen * 1E3 for x in dt)  # speeds per image
+    t = tuple(x.t / seen * 1E3 for x in dt)  # speeds per image
     if not training:
         shape = (batch_size, 3, imgsz, imgsz)
         LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)
```
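One scoping detail in the NMS hunk above: the pixel-scaling and hybrid-label preparation lines stay outside `with dt[2]:`, so only `non_max_suppression` itself is measured; the placement of the `with` block is the timing boundary. A minimal sketch of that choice, with `time.sleep` stand-ins for the work:

```python
import time

from utils.general import Profile

nms = Profile()
time.sleep(0.004)  # label preparation: outside the block, excluded from the measurement
with nms:
    time.sleep(0.006)  # NMS stand-in: the only work that is timed
print(f'NMS: {nms.dt * 1E3:.1f}ms')  # ~6ms, not ~10ms
```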