Unverified commit b74929c9, authored by Kalen Michael, committed by GitHub

Add `train.py` and `val.py` callbacks (#4220)

* Added callbacks
* Update callbacks.py
* Update train.py
* Update val.py
* Fix CamelCase, add staticmethod
* Refactor logger into callbacks
* Cleanup
* New callback on_val_image_end()
* Add curves and results images to TensorBoard

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Parent d8f18834
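In short, this commit replaces direct Loggers method calls in train.py and val.py with a generic hook registry: every public method on the Loggers instance is registered as an action under the hook of the same name, and the training/validation loops fire hooks by name. A rough end-to-end sketch of the flow, assuming the YOLOv5 repo root is on PYTHONPATH (MyLogger is a hypothetical stand-in for utils.loggers.Loggers):

from utils.callbacks import Callbacks
from utils.general import methods

class MyLogger:
    # any public method whose name matches a hook doubles as an action
    def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
        print(f'batch {ni} finished')

callbacks, logger = Callbacks(), MyLogger()
for k in methods(logger):  # same registration loop train.py uses
    callbacks.register_action(k, callback=getattr(logger, k))
callbacks.on_train_batch_end(0, None, None, None, None, False)  # -> 'batch 0 finished'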
train.py:

@@ -34,7 +34,7 @@ from utils.autoanchor import check_anchors
 from utils.datasets import create_dataloader
 from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
     strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
-    check_requirements, print_mutation, set_logging, one_cycle, colorstr
+    check_requirements, print_mutation, set_logging, one_cycle, colorstr, methods
 from utils.downloads import attempt_download
 from utils.loss import ComputeLoss
 from utils.plots import plot_labels, plot_evolution
@@ -42,6 +42,7 @@ from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_di
 from utils.loggers.wandb.wandb_utils import check_wandb_resume
 from utils.metrics import fitness
 from utils.loggers import Loggers
+from utils.callbacks import Callbacks

 LOGGER = logging.getLogger(__name__)
 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
@@ -52,6 +53,7 @@ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
 def train(hyp,  # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
+          callbacks=Callbacks()
           ):
     save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
         Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
@@ -77,12 +79,16 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # Loggers
     if RANK in [-1, 0]:
-        loggers = Loggers(save_dir, weights, opt, hyp, LOGGER).start()  # loggers dict
+        loggers = Loggers(save_dir, weights, opt, hyp, LOGGER)  # loggers instance
         if loggers.wandb:
             data_dict = loggers.wandb.data_dict
             if resume:
                 weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp

+        # Register actions
+        for k in methods(loggers):
+            callbacks.register_action(k, callback=getattr(loggers, k))
+
     # Config
     plots = not evolve  # create plots
     cuda = device.type != 'cpu'
@@ -215,13 +221,15 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
         # model._initialize_biases(cf.to(device))
         if plots:
-            plot_labels(labels, names, save_dir, loggers)
+            plot_labels(labels, names, save_dir)

         # Anchors
         if not opt.noautoanchor:
             check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
         model.half().float()  # pre-reduce anchor precision

+        callbacks.on_pretrain_routine_end()
+
     # DDP mode
     if cuda and RANK != -1:
         model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
@@ -329,8 +337,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                 mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                 pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
                     f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
-                loggers.on_train_batch_end(ni, model, imgs, targets, paths, plots)
+                callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots)
             # end batch ------------------------------------------------------------------------------------------------

         # Scheduler
@@ -339,7 +346,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         if RANK in [-1, 0]:
             # mAP
-            loggers.on_train_epoch_end(epoch)
+            callbacks.on_train_epoch_end(epoch=epoch)
             ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
             if not noval or final_epoch:  # Calculate mAP
@@ -353,14 +360,14 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                                            save_json=is_coco and final_epoch,
                                            verbose=nc < 50 and final_epoch,
                                            plots=plots and final_epoch,
-                                           loggers=loggers,
+                                           callbacks=callbacks,
                                            compute_loss=compute_loss)

             # Update best mAP
             fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
             if fi > best_fitness:
                 best_fitness = fi
-            loggers.on_train_val_end(mloss, results, lr, epoch, best_fitness, fi)
+            callbacks.on_fit_epoch_end(mloss, results, lr, epoch, best_fitness, fi)

             # Save model
             if (not nosave) or (final_epoch and not evolve):  # if save
@@ -377,7 +384,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                 if best_fitness == fi:
                     torch.save(ckpt, best)
                 del ckpt
-                loggers.on_model_save(last, epoch, final_epoch, best_fitness, fi)
+                callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)

         # end epoch ----------------------------------------------------------------------------------------------------
     # end training -----------------------------------------------------------------------------------------------------
@@ -400,7 +407,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         for f in last, best:
             if f.exists():
                 strip_optimizer(f)  # strip optimizers
-        loggers.on_train_end(last, best, plots)
+        callbacks.on_train_end(last, best, plots, epoch)
+        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")

     torch.cuda.empty_cache()
     return results
@@ -448,6 +456,7 @@ def parse_opt(known=False):
 def main(opt):
+    # Checks
     set_logging(RANK)
     if RANK in [-1, 0]:
         print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
...
utils/callbacks.py (new file):

#!/usr/bin/env python
class Callbacks:
""""
Handles all registered callbacks for YOLOv5 Hooks
"""
_callbacks = {
'on_pretrain_routine_start': [],
'on_pretrain_routine_end': [],
'on_train_start': [],
'on_train_epoch_start': [],
'on_train_batch_start': [],
'optimizer_step': [],
'on_before_zero_grad': [],
'on_train_batch_end': [],
'on_train_epoch_end': [],
'on_val_start': [],
'on_val_batch_start': [],
'on_val_image_end': [],
'on_val_batch_end': [],
'on_val_end': [],
'on_fit_epoch_end': [], # fit = train + val
'on_model_save': [],
'on_train_end': [],
'teardown': [],
}
def __init__(self):
return
def register_action(self, hook, name='', callback=None):
"""
Register a new action to a callback hook
Args:
hook The callback hook name to register the action to
name The name of the action
callback The callback to fire
"""
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
assert callable(callback), f"callback '{callback}' is not callable"
self._callbacks[hook].append({'name': name, 'callback': callback})
def get_registered_actions(self, hook=None):
""""
Returns all the registered actions by callback hook
Args:
hook The name of the hook to check, defaults to all
"""
if hook:
return self._callbacks[hook]
else:
return self._callbacks
@staticmethod
def run_callbacks(register, *args, **kwargs):
"""
Loop through the registered actions and fire all callbacks
"""
for logger in register:
# print(f"Running callbacks.{logger['callback'].__name__}()")
logger['callback'](*args, **kwargs)
def on_pretrain_routine_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each pretraining routine
"""
self.run_callbacks(self._callbacks['on_pretrain_routine_start'], *args, **kwargs)
def on_pretrain_routine_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each pretraining routine
"""
self.run_callbacks(self._callbacks['on_pretrain_routine_end'], *args, **kwargs)
def on_train_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each training
"""
self.run_callbacks(self._callbacks['on_train_start'], *args, **kwargs)
def on_train_epoch_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each training epoch
"""
self.run_callbacks(self._callbacks['on_train_epoch_start'], *args, **kwargs)
def on_train_batch_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each training batch
"""
self.run_callbacks(self._callbacks['on_train_batch_start'], *args, **kwargs)
def optimizer_step(self, *args, **kwargs):
"""
Fires all registered callbacks on each optimizer step
"""
self.run_callbacks(self._callbacks['optimizer_step'], *args, **kwargs)
def on_before_zero_grad(self, *args, **kwargs):
"""
Fires all registered callbacks before zero grad
"""
self.run_callbacks(self._callbacks['on_before_zero_grad'], *args, **kwargs)
def on_train_batch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each training batch
"""
self.run_callbacks(self._callbacks['on_train_batch_end'], *args, **kwargs)
def on_train_epoch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each training epoch
"""
self.run_callbacks(self._callbacks['on_train_epoch_end'], *args, **kwargs)
def on_val_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of the validation
"""
self.run_callbacks(self._callbacks['on_val_start'], *args, **kwargs)
def on_val_batch_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each validation batch
"""
self.run_callbacks(self._callbacks['on_val_batch_start'], *args, **kwargs)
def on_val_image_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each val image
"""
self.run_callbacks(self._callbacks['on_val_image_end'], *args, **kwargs)
def on_val_batch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each validation batch
"""
self.run_callbacks(self._callbacks['on_val_batch_end'], *args, **kwargs)
def on_val_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of the validation
"""
self.run_callbacks(self._callbacks['on_val_end'], *args, **kwargs)
def on_fit_epoch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each fit (train+val) epoch
"""
self.run_callbacks(self._callbacks['on_fit_epoch_end'], *args, **kwargs)
def on_model_save(self, *args, **kwargs):
"""
Fires all registered callbacks after each model save
"""
self.run_callbacks(self._callbacks['on_model_save'], *args, **kwargs)
def on_train_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of training
"""
self.run_callbacks(self._callbacks['on_train_end'], *args, **kwargs)
def teardown(self, *args, **kwargs):
"""
Fires all registered callbacks before teardown
"""
self.run_callbacks(self._callbacks['teardown'], *args, **kwargs)
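A quick usage sketch of the class above; the hook name comes from the _callbacks dict, while the 'announce' action itself is a made-up example:

callbacks = Callbacks()
callbacks.register_action('on_train_epoch_end', name='announce',
                          callback=lambda epoch: print(f'epoch {epoch} complete'))
callbacks.on_train_epoch_end(epoch=2)  # runs every registered action -> 'epoch 2 complete'
print(callbacks.get_registered_actions('on_train_epoch_end'))  # [{'name': 'announce', 'callback': <function <lambda> at ...>}]

Note that _callbacks is a class attribute rather than an instance attribute, so actions registered through one Callbacks() instance are visible to every other instance in the process.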
utils/general.py:

@@ -67,6 +67,11 @@ def try_except(func):
     return handler


+def methods(instance):
+    # Get class/instance methods
+    return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
+
+
 def set_logging(rank=-1, verbose=True):
     logging.basicConfig(
         format="%(message)s",
...
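To illustrate what methods() returns (Widget is a made-up class): it keeps every callable attribute except dunder names, so single-underscore helpers are included too, and dir() yields the names alphabetically:

class Widget:
    version = 1  # attribute, not callable -> filtered out
    def start(self):
        pass
    def _helper(self):
        pass

print(methods(Widget()))  # ['_helper', 'start']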
utils/loggers/__init__.py:

@@ -29,10 +29,12 @@ class Loggers():
         self.hyp = hyp
         self.logger = logger  # for printing results to console
         self.include = include
+        self.keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
+                     'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',  # metrics
+                     'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
+                     'x/lr0', 'x/lr1', 'x/lr2']  # params
         for k in LOGGERS:
             setattr(self, k, None)  # init empty logger dictionary
-
-    def start(self):
         self.csv = True  # always log to csv

         # Message
@@ -57,7 +59,11 @@ class Loggers():
         else:
             self.wandb = None

-        return self
+    def on_pretrain_routine_end(self):
+        # Callback runs on pre-train routine end
+        paths = self.save_dir.glob('*labels*.jpg')  # training labels
+        if self.wandb:
+            self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})

     def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
         # Callback runs on train batch end
@@ -78,8 +84,8 @@ class Loggers():
         if self.wandb:
             self.wandb.current_epoch = epoch + 1

-    def on_val_batch_end(self, pred, predn, path, names, im):
-        # Callback runs on train batch end
+    def on_val_image_end(self, pred, predn, path, names, im):
+        # Callback runs on val image end
         if self.wandb:
             self.wandb.val_one_image(pred, predn, path, names, im)

@@ -89,25 +95,20 @@ class Loggers():
             files = sorted(self.save_dir.glob('val*.jpg'))
             self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})

-    def on_train_val_end(self, mloss, results, lr, epoch, best_fitness, fi):
-        # Callback runs on val end during training
+    def on_fit_epoch_end(self, mloss, results, lr, epoch, best_fitness, fi):
+        # Callback runs at the end of each fit (train+val) epoch
         vals = list(mloss) + list(results) + lr
-        keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
-                'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',  # metrics
-                'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
-                'x/lr0', 'x/lr1', 'x/lr2']  # params
-        x = {k: v for k, v in zip(keys, vals)}  # dict
+        x = {k: v for k, v in zip(self.keys, vals)}  # dict

         if self.csv:
             file = self.save_dir / 'results.csv'
             n = len(x) + 1  # number of cols
-            s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n')  # add header
+            s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + self.keys)).rstrip(',') + '\n')  # add header
             with open(file, 'a') as f:
                 f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')

         if self.tb:
             for k, v in x.items():
-                self.tb.add_scalar(k, v, epoch)  # TensorBoard
+                self.tb.add_scalar(k, v, epoch)

         if self.wandb:
             self.wandb.log(x)
@@ -119,20 +120,22 @@ class Loggers():
         if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
             self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)

-    def on_train_end(self, last, best, plots):
+    def on_train_end(self, last, best, plots, epoch):
         # Callback runs on training end
         if plots:
             plot_results(dir=self.save_dir)  # save results.png
         files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
         files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()]  # filter

+        if self.tb:
+            from PIL import Image
+            import numpy as np
+            for f in files:
+                self.tb.add_image(f.stem, np.asarray(Image.open(f)), epoch, dataformats='HWC')
+
         if self.wandb:
             wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
             wandb.log_artifact(str(best if best.exists() else last), type='model',
                                name='run_' + self.wandb.wandb_run.id + '_model',
                                aliases=['latest', 'best', 'stripped'])
             self.wandb.finish_run()
-
-    def log_images(self, paths):
-        # Log images
-        if self.wandb:
-            self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
utils/plots.py:

@@ -281,7 +281,7 @@ def plot_study_txt(path='', x=None):  # from utils.plots import *; plot_study_tx
     plt.savefig(str(Path(path).name) + '.png', dpi=300)


-def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
+def plot_labels(labels, names=(), save_dir=Path('')):
     # plot dataset labels
     print('Plotting labels... ')
     c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes
@@ -324,10 +324,6 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
     matplotlib.use('Agg')
     plt.close()

-    # loggers
-    if loggers:
-        loggers.log_images(save_dir.glob('*labels*.jpg'))
-

 def plot_evolution(yaml_file='data/hyp.finetune.yaml'):  # from utils.plots import *; plot_evolution()
     # Plot hyperparameter evolution results in evolve.txt
...
val.py:

@@ -25,7 +25,7 @@ from utils.general import coco80_to_coco91_class, check_dataset, check_file, che
 from utils.metrics import ap_per_class, ConfusionMatrix
 from utils.plots import plot_images, output_to_target, plot_study_txt
 from utils.torch_utils import select_device, time_sync
-from utils.loggers import Loggers
+from utils.callbacks import Callbacks


 def save_one_txt(predn, save_conf, shape, file):
@@ -97,7 +97,7 @@ def run(data,
         dataloader=None,
         save_dir=Path(''),
         plots=True,
-        loggers=Loggers(),
+        callbacks=Callbacks(),
         compute_loss=None,
         ):
     # Initialize/load model and set device
@@ -213,7 +213,7 @@ def run(data,
                 save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
             if save_json:
                 save_one_json(predn, jdict, path, class_map)  # append to COCO-JSON dictionary
-            loggers.on_val_batch_end(pred, predn, path, names, img[si])
+            callbacks.on_val_image_end(pred, predn, path, names, img[si])

         # Plot images
         if plots and batch_i < 3:
@@ -250,7 +250,7 @@ def run(data,
     # Plots
     if plots:
         confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
-        loggers.on_val_end()
+        callbacks.on_val_end()

     # Save JSON
     if save_json and len(jdict):
@@ -282,7 +282,7 @@ def run(data,
         model.float()  # for training
     if not training:
         s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
-        print(f"Results saved to {save_dir}{s}")
+        print(f"Results saved to {colorstr('bold', save_dir)}{s}")
     maps = np.zeros(nc) + map
     for i, c in enumerate(ap_class):
         maps[c] = ap[i]
...