Unverified 提交 e0700cce authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Daemon `plot_labels()` for faster start (#9057)

* Daemon `plot_labels()` for faster start * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update train.py Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: 's avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
上级 27fb6fd8
...@@ -52,7 +52,7 @@ from utils.loggers import Loggers ...@@ -52,7 +52,7 @@ from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss from utils.loss import ComputeLoss
from utils.metrics import fitness from utils.metrics import fitness
from utils.plots import plot_evolve, plot_labels from utils.plots import plot_evolve
from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer, from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
smart_resume, torch_distributed_zero_first) smart_resume, torch_distributed_zero_first)
...@@ -215,15 +215,11 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio ...@@ -215,15 +215,11 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
prefix=colorstr('val: '))[0] prefix=colorstr('val: '))[0]
if not resume: if not resume:
if plots:
plot_labels(labels, names, save_dir)
# Anchors
if not opt.noautoanchor: if not opt.noautoanchor:
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) # run AutoAnchor
model.half().float() # pre-reduce anchor precision model.half().float() # pre-reduce anchor precision
callbacks.run('on_pretrain_routine_end') callbacks.run('on_pretrain_routine_end', labels, names, plots)
# DDP mode # DDP mode
if cuda and RANK != -1: if cuda and RANK != -1:
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
Callback utils Callback utils
""" """
import threading
class Callbacks: class Callbacks:
"""" """"
...@@ -55,17 +57,20 @@ class Callbacks: ...@@ -55,17 +57,20 @@ class Callbacks:
""" """
return self._callbacks[hook] if hook else self._callbacks return self._callbacks[hook] if hook else self._callbacks
def run(self, hook, *args, **kwargs): def run(self, hook, *args, thread=False, **kwargs):
""" """
Loop through the registered actions and fire all callbacks Loop through the registered actions and fire all callbacks on main thread
Args: Args:
hook: The name of the hook to check, defaults to all hook: The name of the hook to check, defaults to all
args: Arguments to receive from YOLOv5 args: Arguments to receive from YOLOv5
thread: (boolean) Run callbacks in daemon thread
kwargs: Keyword Arguments to receive from YOLOv5 kwargs: Keyword Arguments to receive from YOLOv5
""" """
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}" assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
for logger in self._callbacks[hook]: for logger in self._callbacks[hook]:
logger['callback'](*args, **kwargs) if thread:
threading.Thread(target=logger['callback'], args=args, kwargs=kwargs, daemon=True).start()
else:
logger['callback'](*args, **kwargs)
...@@ -622,7 +622,7 @@ def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry ...@@ -622,7 +622,7 @@ def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry
dir.mkdir(parents=True, exist_ok=True) # make directory dir.mkdir(parents=True, exist_ok=True) # make directory
if threads > 1: if threads > 1:
pool = ThreadPool(threads) pool = ThreadPool(threads)
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
pool.close() pool.close()
pool.join() pool.join()
else: else:
......
...@@ -11,10 +11,10 @@ import pkg_resources as pkg ...@@ -11,10 +11,10 @@ import pkg_resources as pkg
import torch import torch
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from utils.general import colorstr, cv2 from utils.general import colorstr, cv2, threaded
from utils.loggers.clearml.clearml_utils import ClearmlLogger from utils.loggers.clearml.clearml_utils import ClearmlLogger
from utils.loggers.wandb.wandb_utils import WandbLogger from utils.loggers.wandb.wandb_utils import WandbLogger
from utils.plots import plot_images, plot_results from utils.plots import plot_images, plot_labels, plot_results
from utils.torch_utils import de_parallel from utils.torch_utils import de_parallel
LOGGERS = ('csv', 'tb', 'wandb', 'clearml') # *.csv, TensorBoard, Weights & Biases, ClearML LOGGERS = ('csv', 'tb', 'wandb', 'clearml') # *.csv, TensorBoard, Weights & Biases, ClearML
...@@ -110,13 +110,15 @@ class Loggers(): ...@@ -110,13 +110,15 @@ class Loggers():
# Callback runs on train start # Callback runs on train start
pass pass
def on_pretrain_routine_end(self): def on_pretrain_routine_end(self, labels, names, plots):
# Callback runs on pre-train routine end # Callback runs on pre-train routine end
if plots:
plot_labels(labels, names, self.save_dir)
paths = self.save_dir.glob('*labels*.jpg') # training labels paths = self.save_dir.glob('*labels*.jpg') # training labels
if self.wandb: if self.wandb:
self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]}) self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
if self.clearml: # if self.clearml:
pass # ClearML saves these images automatically using hooks # pass # ClearML saves these images automatically using hooks
def on_train_batch_end(self, ni, model, imgs, targets, paths, plots): def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
# Callback runs on train batch end # Callback runs on train batch end
......
...@@ -340,7 +340,6 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_ ...@@ -340,7 +340,6 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_
@try_except # known issue https://github.com/ultralytics/yolov5/issues/5395 @try_except # known issue https://github.com/ultralytics/yolov5/issues/5395
@Timeout(30) # known issue https://github.com/ultralytics/yolov5/issues/5611
def plot_labels(labels, names=(), save_dir=Path('')): def plot_labels(labels, names=(), save_dir=Path('')):
# plot dataset labels # plot dataset labels
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ") LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论