Unverified 提交 4a295b1a authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Add `@threaded` decorator (#7813)

* Add `@threaded` decorator * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ciCo-authored-by: 's avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
上级 9d8ed37d
...@@ -48,8 +48,8 @@ from utils.dataloaders import create_dataloader ...@@ -48,8 +48,8 @@ from utils.dataloaders import create_dataloader
from utils.downloads import attempt_download from utils.downloads import attempt_download
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements, from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path, check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path,
init_seeds, intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights, init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
methods, one_cycle, print_args, print_mutation, strip_optimizer) one_cycle, print_args, print_mutation, strip_optimizer)
from utils.loggers import Loggers from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss from utils.loss import ComputeLoss
......
...@@ -14,6 +14,7 @@ import random ...@@ -14,6 +14,7 @@ import random
import re import re
import shutil import shutil
import signal import signal
import threading
import time import time
import urllib import urllib
from datetime import datetime from datetime import datetime
...@@ -167,6 +168,16 @@ def try_except(func): ...@@ -167,6 +168,16 @@ def try_except(func):
return handler return handler
def threaded(func):
# Multi-threads a target function and returns thread. Usage: @threaded decorator
def wrapper(*args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread
return wrapper
def methods(instance): def methods(instance):
# Get class/instance methods # Get class/instance methods
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")] return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
......
...@@ -5,7 +5,6 @@ Logging utils ...@@ -5,7 +5,6 @@ Logging utils
import os import os
import warnings import warnings
from threading import Thread
import pkg_resources as pkg import pkg_resources as pkg
import torch import torch
...@@ -109,7 +108,7 @@ class Loggers(): ...@@ -109,7 +108,7 @@ class Loggers():
self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), []) self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
if ni < 3: if ni < 3:
f = self.save_dir / f'train_batch{ni}.jpg' # filename f = self.save_dir / f'train_batch{ni}.jpg' # filename
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start() plot_images(imgs, targets, paths, f)
if self.wandb and ni == 10: if self.wandb and ni == 10:
files = sorted(self.save_dir.glob('train*.jpg')) files = sorted(self.save_dir.glob('train*.jpg'))
self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]}) self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
...@@ -132,7 +131,7 @@ class Loggers(): ...@@ -132,7 +131,7 @@ class Loggers():
def on_fit_epoch_end(self, vals, epoch, best_fitness, fi): def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
# Callback runs at the end of each fit (train+val) epoch # Callback runs at the end of each fit (train+val) epoch
x = {k: v for k, v in zip(self.keys, vals)} # dict x = dict(zip(self.keys, vals))
if self.csv: if self.csv:
file = self.save_dir / 'results.csv' file = self.save_dir / 'results.csv'
n = len(x) + 1 # number of cols n = len(x) + 1 # number of cols
...@@ -171,7 +170,7 @@ class Loggers(): ...@@ -171,7 +170,7 @@ class Loggers():
self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC') self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')
if self.wandb: if self.wandb:
self.wandb.log({k: v for k, v in zip(self.keys[3:10], results)}) # log best.pt val results self.wandb.log(dict(zip(self.keys[3:10], results)))
self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]}) self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
# Calling wandb.log. TODO: Refactor this into WandbLogger.log_model # Calling wandb.log. TODO: Refactor this into WandbLogger.log_model
if not self.opt.evolve: if not self.opt.evolve:
......
...@@ -19,7 +19,7 @@ import torch ...@@ -19,7 +19,7 @@ import torch
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords, from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords,
increment_path, is_ascii, try_except, xywh2xyxy, xyxy2xywh) increment_path, is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh)
from utils.metrics import fitness from utils.metrics import fitness
# Settings # Settings
...@@ -32,9 +32,9 @@ class Colors: ...@@ -32,9 +32,9 @@ class Colors:
# Ultralytics color palette https://ultralytics.com/ # Ultralytics color palette https://ultralytics.com/
def __init__(self): def __init__(self):
# hex = matplotlib.colors.TABLEAU_COLORS.values() # hex = matplotlib.colors.TABLEAU_COLORS.values()
hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB', hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7') '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
self.palette = [self.hex2rgb('#' + c) for c in hex] self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
self.n = len(self.palette) self.n = len(self.palette)
def __call__(self, i, bgr=False): def __call__(self, i, bgr=False):
...@@ -100,7 +100,7 @@ class Annotator: ...@@ -100,7 +100,7 @@ class Annotator:
if label: if label:
tf = max(self.lw - 1, 1) # font thickness tf = max(self.lw - 1, 1) # font thickness
w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
outside = p1[1] - h - 3 >= 0 # label fits outside box outside = p1[1] - h >= 3
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
cv2.putText(self.im, cv2.putText(self.im,
...@@ -184,6 +184,7 @@ def output_to_target(output): ...@@ -184,6 +184,7 @@ def output_to_target(output):
return np.array(targets) return np.array(targets)
@threaded
def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16): def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16):
# Plot image grid with labels # Plot image grid with labels
if isinstance(images, torch.Tensor): if isinstance(images, torch.Tensor):
...@@ -420,7 +421,7 @@ def plot_results(file='path/to/results.csv', dir=''): ...@@ -420,7 +421,7 @@ def plot_results(file='path/to/results.csv', dir=''):
ax = ax.ravel() ax = ax.ravel()
files = list(save_dir.glob('results*.csv')) files = list(save_dir.glob('results*.csv'))
assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.' assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
for fi, f in enumerate(files): for f in files:
try: try:
data = pd.read_csv(f) data = pd.read_csv(f)
s = [x.strip() for x in data.columns] s = [x.strip() for x in data.columns]
......
...@@ -23,7 +23,6 @@ import json ...@@ -23,7 +23,6 @@ import json
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
from threading import Thread
import numpy as np import numpy as np
import torch import torch
...@@ -255,10 +254,8 @@ def run( ...@@ -255,10 +254,8 @@ def run(
# Plot images # Plot images
if plots and batch_i < 3: if plots and batch_i < 3:
f = save_dir / f'val_batch{batch_i}_labels.jpg' # labels plot_images(im, targets, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names) # labels
Thread(target=plot_images, args=(im, targets, paths, f, names), daemon=True).start() plot_images(im, output_to_target(out), paths, save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred
f = save_dir / f'val_batch{batch_i}_pred.jpg' # predictions
Thread(target=plot_images, args=(im, output_to_target(out), paths, f, names), daemon=True).start()
callbacks.run('on_val_batch_end') callbacks.run('on_val_batch_end')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论