Unverified 提交 e0700cce authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Daemon `plot_labels()` for faster start (#9057)

* Daemon `plot_labels()` for faster start * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update train.py Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: 's avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
上级 27fb6fd8
......@@ -52,7 +52,7 @@ from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss
from utils.metrics import fitness
from utils.plots import plot_evolve, plot_labels
from utils.plots import plot_evolve
from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
smart_resume, torch_distributed_zero_first)
......@@ -215,15 +215,11 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
prefix=colorstr('val: '))[0]
if not resume:
if plots:
plot_labels(labels, names, save_dir)
# Anchors
if not opt.noautoanchor:
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) # run AutoAnchor
model.half().float() # pre-reduce anchor precision
callbacks.run('on_pretrain_routine_end')
callbacks.run('on_pretrain_routine_end', labels, names, plots)
# DDP mode
if cuda and RANK != -1:
......
......@@ -3,6 +3,8 @@
Callback utils
"""
import threading
class Callbacks:
""""
......@@ -55,17 +57,20 @@ class Callbacks:
"""
return self._callbacks[hook] if hook else self._callbacks
def run(self, hook, *args, **kwargs):
def run(self, hook, *args, thread=False, **kwargs):
"""
Loop through the registered actions and fire all callbacks
Loop through the registered actions and fire all callbacks on main thread
Args:
hook: The name of the hook to check, defaults to all
args: Arguments to receive from YOLOv5
thread: (boolean) Run callbacks in daemon thread
kwargs: Keyword Arguments to receive from YOLOv5
"""
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
for logger in self._callbacks[hook]:
logger['callback'](*args, **kwargs)
if thread:
threading.Thread(target=logger['callback'], args=args, kwargs=kwargs, daemon=True).start()
else:
logger['callback'](*args, **kwargs)
......@@ -622,7 +622,7 @@ def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry
dir.mkdir(parents=True, exist_ok=True) # make directory
if threads > 1:
pool = ThreadPool(threads)
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
pool.close()
pool.join()
else:
......
......@@ -11,10 +11,10 @@ import pkg_resources as pkg
import torch
from torch.utils.tensorboard import SummaryWriter
from utils.general import colorstr, cv2
from utils.general import colorstr, cv2, threaded
from utils.loggers.clearml.clearml_utils import ClearmlLogger
from utils.loggers.wandb.wandb_utils import WandbLogger
from utils.plots import plot_images, plot_results
from utils.plots import plot_images, plot_labels, plot_results
from utils.torch_utils import de_parallel
LOGGERS = ('csv', 'tb', 'wandb', 'clearml') # *.csv, TensorBoard, Weights & Biases, ClearML
......@@ -110,13 +110,15 @@ class Loggers():
# Callback runs on train start
pass
def on_pretrain_routine_end(self):
def on_pretrain_routine_end(self, labels, names, plots):
# Callback runs on pre-train routine end
if plots:
plot_labels(labels, names, self.save_dir)
paths = self.save_dir.glob('*labels*.jpg') # training labels
if self.wandb:
self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
if self.clearml:
pass # ClearML saves these images automatically using hooks
# if self.clearml:
# pass # ClearML saves these images automatically using hooks
def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
# Callback runs on train batch end
......
......@@ -340,7 +340,6 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_
@try_except # known issue https://github.com/ultralytics/yolov5/issues/5395
@Timeout(30) # known issue https://github.com/ultralytics/yolov5/issues/5611
def plot_labels(labels, names=(), save_dir=Path('')):
# plot dataset labels
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论