Unverified 提交 2317f86c authored 作者: Kalen Michael's avatar Kalen Michael 提交者: GitHub

Optimised Callback Class to Reduce Code and Fix Errors (#4688)

* added callbacks * added back callback to main * added save_dir to callback output * reduced code count * updated callbacks * added default callback class to main, added missing parameters to on_model_save * Glenn updates Co-authored-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 54874518
...@@ -56,7 +56,7 @@ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) ...@@ -56,7 +56,7 @@ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
def train(hyp, # path/to/hyp.yaml or hyp dictionary def train(hyp, # path/to/hyp.yaml or hyp dictionary
opt, opt,
device, device,
callbacks=Callbacks() callbacks
): ):
save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \ save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \ Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
...@@ -231,7 +231,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary ...@@ -231,7 +231,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
model.half().float() # pre-reduce anchor precision model.half().float() # pre-reduce anchor precision
callbacks.on_pretrain_routine_end() callbacks.run('on_pretrain_routine_end')
# DDP mode # DDP mode
if cuda and RANK != -1: if cuda and RANK != -1:
...@@ -333,7 +333,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary ...@@ -333,7 +333,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB) mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % ( pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])) f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots, opt.sync_bn) callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn)
# end batch ------------------------------------------------------------------------------------------------ # end batch ------------------------------------------------------------------------------------------------
# Scheduler # Scheduler
...@@ -342,7 +342,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary ...@@ -342,7 +342,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
if RANK in [-1, 0]: if RANK in [-1, 0]:
# mAP # mAP
callbacks.on_train_epoch_end(epoch=epoch) callbacks.run('on_train_epoch_end', epoch=epoch)
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights']) ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
final_epoch = (epoch + 1 == epochs) or stopper.possible_stop final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
if not noval or final_epoch: # Calculate mAP if not noval or final_epoch: # Calculate mAP
...@@ -364,7 +364,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary ...@@ -364,7 +364,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
if fi > best_fitness: if fi > best_fitness:
best_fitness = fi best_fitness = fi
log_vals = list(mloss) + list(results) + lr log_vals = list(mloss) + list(results) + lr
callbacks.on_fit_epoch_end(log_vals, epoch, best_fitness, fi) callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
# Save model # Save model
if (not nosave) or (final_epoch and not evolve): # if save if (not nosave) or (final_epoch and not evolve): # if save
...@@ -381,7 +381,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary ...@@ -381,7 +381,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
if best_fitness == fi: if best_fitness == fi:
torch.save(ckpt, best) torch.save(ckpt, best)
del ckpt del ckpt
callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi) callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
# Stop Single-GPU # Stop Single-GPU
if RANK == -1 and stopper(epoch=epoch, fitness=fi): if RANK == -1 and stopper(epoch=epoch, fitness=fi):
...@@ -418,7 +418,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary ...@@ -418,7 +418,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
for f in last, best: for f in last, best:
if f.exists(): if f.exists():
strip_optimizer(f) # strip optimizers strip_optimizer(f) # strip optimizers
callbacks.on_train_end(last, best, plots, epoch) callbacks.run('on_train_end', last, best, plots, epoch)
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}") LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
torch.cuda.empty_cache() torch.cuda.empty_cache()
...@@ -467,7 +467,7 @@ def parse_opt(known=False): ...@@ -467,7 +467,7 @@ def parse_opt(known=False):
return opt return opt
def main(opt): def main(opt, callbacks=Callbacks()):
# Checks # Checks
set_logging(RANK) set_logging(RANK)
if RANK in [-1, 0]: if RANK in [-1, 0]:
...@@ -505,7 +505,7 @@ def main(opt): ...@@ -505,7 +505,7 @@ def main(opt):
# Train # Train
if not opt.evolve: if not opt.evolve:
train(opt.hyp, opt, device) train(opt.hyp, opt, device, callbacks)
if WORLD_SIZE > 1 and RANK == 0: if WORLD_SIZE > 1 and RANK == 0:
_ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')] _ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]
...@@ -585,7 +585,7 @@ def main(opt): ...@@ -585,7 +585,7 @@ def main(opt):
hyp[k] = round(hyp[k], 5) # significant digits hyp[k] = round(hyp[k], 5) # significant digits
# Train mutation # Train mutation
results = train(hyp.copy(), opt, device) results = train(hyp.copy(), opt, device, callbacks)
# Write mutation results # Write mutation results
print_mutation(results, hyp.copy(), save_dir, opt.bucket) print_mutation(results, hyp.copy(), save_dir, opt.bucket)
......
...@@ -9,6 +9,7 @@ class Callbacks: ...@@ -9,6 +9,7 @@ class Callbacks:
Handles all registered callbacks for YOLOv5 Hooks Handles all registered callbacks for YOLOv5 Hooks
""" """
# Define the available callbacks
_callbacks = { _callbacks = {
'on_pretrain_routine_start': [], 'on_pretrain_routine_start': [],
'on_pretrain_routine_end': [], 'on_pretrain_routine_end': [],
...@@ -34,16 +35,13 @@ class Callbacks: ...@@ -34,16 +35,13 @@ class Callbacks:
'teardown': [], 'teardown': [],
} }
def __init__(self):
return
def register_action(self, hook, name='', callback=None): def register_action(self, hook, name='', callback=None):
""" """
Register a new action to a callback hook Register a new action to a callback hook
Args: Args:
hook The callback hook name to register the action to hook The callback hook name to register the action to
name The name of the action name The name of the action for later reference
callback The callback to fire callback The callback to fire
""" """
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}" assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
...@@ -62,118 +60,17 @@ class Callbacks: ...@@ -62,118 +60,17 @@ class Callbacks:
else: else:
return self._callbacks return self._callbacks
def run_callbacks(self, hook, *args, **kwargs): def run(self, hook, *args, **kwargs):
""" """
Loop through the registered actions and fire all callbacks Loop through the registered actions and fire all callbacks
"""
for logger in self._callbacks[hook]:
# print(f"Running callbacks.{logger['callback'].__name__}()")
logger['callback'](*args, **kwargs)
def on_pretrain_routine_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each pretraining routine
"""
self.run_callbacks('on_pretrain_routine_start', *args, **kwargs)
def on_pretrain_routine_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each pretraining routine
"""
self.run_callbacks('on_pretrain_routine_end', *args, **kwargs)
def on_train_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each training
"""
self.run_callbacks('on_train_start', *args, **kwargs)
def on_train_epoch_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each training epoch
"""
self.run_callbacks('on_train_epoch_start', *args, **kwargs)
def on_train_batch_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each training batch
"""
self.run_callbacks('on_train_batch_start', *args, **kwargs)
def optimizer_step(self, *args, **kwargs): Args:
""" hook The name of the hook to check, defaults to all
Fires all registered callbacks on each optimizer step args Arguments to receive from YOLOv5
""" kwargs Keyword Arguments to receive from YOLOv5
self.run_callbacks('optimizer_step', *args, **kwargs)
def on_before_zero_grad(self, *args, **kwargs):
"""
Fires all registered callbacks before zero grad
"""
self.run_callbacks('on_before_zero_grad', *args, **kwargs)
def on_train_batch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each training batch
"""
self.run_callbacks('on_train_batch_end', *args, **kwargs)
def on_train_epoch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each training epoch
"""
self.run_callbacks('on_train_epoch_end', *args, **kwargs)
def on_val_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of the validation
"""
self.run_callbacks('on_val_start', *args, **kwargs)
def on_val_batch_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each validation batch
"""
self.run_callbacks('on_val_batch_start', *args, **kwargs)
def on_val_image_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each val image
"""
self.run_callbacks('on_val_image_end', *args, **kwargs)
def on_val_batch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each validation batch
"""
self.run_callbacks('on_val_batch_end', *args, **kwargs)
def on_val_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of the validation
"""
self.run_callbacks('on_val_end', *args, **kwargs)
def on_fit_epoch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each fit (train+val) epoch
"""
self.run_callbacks('on_fit_epoch_end', *args, **kwargs)
def on_model_save(self, *args, **kwargs):
"""
Fires all registered callbacks after each model save
""" """
self.run_callbacks('on_model_save', *args, **kwargs)
def on_train_end(self, *args, **kwargs): assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
"""
Fires all registered callbacks at the end of training
"""
self.run_callbacks('on_train_end', *args, **kwargs)
def teardown(self, *args, **kwargs): for logger in self._callbacks[hook]:
""" logger['callback'](*args, **kwargs)
Fires all registered callbacks before teardown
"""
self.run_callbacks('teardown', *args, **kwargs)
...@@ -216,7 +216,7 @@ def run(data, ...@@ -216,7 +216,7 @@ def run(data,
save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt')) save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
if save_json: if save_json:
save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary
callbacks.on_val_image_end(pred, predn, path, names, img[si]) callbacks.run('on_val_image_end', pred, predn, path, names, img[si])
# Plot images # Plot images
if plots and batch_i < 3: if plots and batch_i < 3:
...@@ -253,7 +253,7 @@ def run(data, ...@@ -253,7 +253,7 @@ def run(data,
# Plots # Plots
if plots: if plots:
confusion_matrix.plot(save_dir=save_dir, names=list(names.values())) confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
callbacks.on_val_end() callbacks.run('on_val_end')
# Save JSON # Save JSON
if save_json and len(jdict): if save_json and len(jdict):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论