Unverified commit 045d5d86, authored by Glenn Jocher, committed by GitHub

Update TensorBoard (#3669)

Parent: fa201f96
@@ -42,7 +42,6 @@ logger = logging.getLogger(__name__)
 def train(hyp,  # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
-          tb_writer=None
           ):
     save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
         Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
@@ -74,9 +73,16 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     with open(opt.data) as f:
         data_dict = yaml.safe_load(f)  # data dict

-    # Logging- Doing this before checking the dataset. Might update data_dict
-    loggers = {'wandb': None}  # loggers dict
+    # Loggers
+    loggers = {'wandb': None, 'tb': None}  # loggers dict
     if rank in [-1, 0]:
+        # TensorBoard
+        if not opt.evolve:
+            prefix = colorstr('tensorboard: ')
+            logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
+            loggers['tb'] = SummaryWriter(opt.save_dir)
+
+        # W&B
         opt.hyp = hyp  # add hyperparameters
         run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
         wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
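
For context, a minimal sketch of the loggers-dict pattern this hunk introduces: one slot per backend, initialized only when wanted and guarded at every call site. The log directory and scalar values below are hypothetical, not from the commit.

    # Minimal sketch of the loggers-dict pattern (assumed log dir and values)
    from torch.utils.tensorboard import SummaryWriter

    loggers = {'wandb': None, 'tb': None}  # one slot per logging backend
    loggers['tb'] = SummaryWriter('runs/exp')  # event files go under runs/exp

    if loggers['tb']:  # disabled backends stay None and are skipped
        loggers['tb'].add_scalar('train/box_loss', 0.05, 0)  # tag, value, step
    loggers['tb'].close()

Keeping every backend behind one dict is what lets this commit drop the tb_writer parameter from train() further down.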
@@ -219,8 +225,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         # model._initialize_biases(cf.to(device))
         if plots:
             plot_labels(labels, names, save_dir, loggers)
-            if tb_writer:
-                tb_writer.add_histogram('classes', c, 0)
+            if loggers['tb']:
+                loggers['tb'].add_histogram('classes', c, 0)  # TensorBoard

         # Anchors
         if not opt.noautoanchor:
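
The histogram call above records the distribution of class indices once at step 0, which makes class imbalance visible in TensorBoard's HISTOGRAMS tab. A standalone sketch, with the commit's `c` replaced by illustrative random labels:

    import torch
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter('runs/exp')  # assumed log dir
    c = torch.randint(0, 80, (1000,))  # stand-in for the dataset's class ids
    writer.add_histogram('classes', c, 0)  # tag, values, global_step
    writer.close()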
@@ -341,10 +347,10 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             if plots and ni < 3:
                 f = save_dir / f'train_batch{ni}.jpg'  # filename
                 Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
-                if tb_writer and ni == 0:
+                if loggers['tb'] and ni == 0:  # TensorBoard
                     with warnings.catch_warnings():
                         warnings.simplefilter('ignore')  # suppress jit trace warning
-                        tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
+                        loggers['tb'].add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
                 elif plots and ni == 10 and wandb_logger.wandb:
                     wandb_logger.log({'Mosaics': [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
                                                   save_dir.glob('train*.jpg') if x.exists()]})
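
The graph logging above traces the model on a single image (imgs[0:1] keeps the batch dimension) and suppresses the TracerWarnings that tracing a detection model typically emits; de_parallel() unwraps any DDP/DataParallel wrapper first. A self-contained sketch with a stand-in model:

    import warnings
    import torch
    import torch.nn as nn
    from torch.utils.tensorboard import SummaryWriter

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())  # hypothetical model
    imgs = torch.zeros(4, 3, 64, 64)  # dummy batch of images

    writer = SummaryWriter('runs/exp')  # assumed log dir
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')  # suppress jit trace warnings
        writer.add_graph(torch.jit.trace(model, imgs[0:1], strict=False), [])
    writer.close()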
@@ -352,7 +358,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             # end batch ------------------------------------------------------------------------------------------------

         # Scheduler
-        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
+        lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
         scheduler.step()

         # DDP process 0 or single-GPU
@@ -385,8 +391,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                     'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                     'x/lr0', 'x/lr1', 'x/lr2']  # params
             for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
-                if tb_writer:
-                    tb_writer.add_scalar(tag, x, epoch)  # tensorboard
+                if loggers['tb']:
+                    loggers['tb'].add_scalar(tag, x, epoch)  # TensorBoard
                 if wandb_logger.wandb:
                     wandb_logger.log({tag: x})  # W&B
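
The loop above fans each per-epoch value out to every enabled backend under a namespaced tag ('train/', 'metrics/', 'val/', 'x/'), so TensorBoard groups the scalars into panels automatically. A sketch with made-up numbers:

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter('runs/exp')  # assumed log dir
    tags = ['train/box_loss', 'metrics/mAP_0.5', 'x/lr0']
    values = [0.042, 0.61, 0.0032]  # illustrative numbers only
    for x, tag in zip(values, tags):
        writer.add_scalar(tag, x, 0)  # same (tag, value, epoch) call per metric
    writer.close()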
@@ -537,12 +543,7 @@ if __name__ == '__main__':
     # Train
     logger.info(opt)
     if not opt.evolve:
-        tb_writer = None  # init loggers
-        if opt.global_rank in [-1, 0]:
-            prefix = colorstr('tensorboard: ')
-            logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
-            tb_writer = SummaryWriter(opt.save_dir)  # Tensorboard
-        train(opt.hyp, opt, device, tb_writer)
+        train(opt.hyp, opt, device)

     # Evolve hyperparameters (optional)
     else:
...