Unverified commit fa201f96 authored by Glenn Jocher, committed by GitHub

Update `train(hyp, *args)` to accept `hyp` file or dict (#3668)

parent 6d6e2ca6
@@ -39,12 +39,11 @@ from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
 logger = logging.getLogger(__name__)
 
 
-def train(hyp,
+def train(hyp,  # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
           tb_writer=None
           ):
-    logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
     save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
         Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
         opt.single_cls
@@ -56,6 +55,12 @@ def train(hyp,
     best = wdir / 'best.pt'
     results_file = save_dir / 'results.txt'
 
+    # Hyperparameters
+    if isinstance(hyp, str):
+        with open(hyp) as f:
+            hyp = yaml.safe_load(f)  # load hyps dict
+    logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
+
     # Save run settings
     with open(save_dir / 'hyp.yaml', 'w') as f:
         yaml.safe_dump(hyp, f, sort_keys=False)
@@ -529,10 +534,6 @@ if __name__ == '__main__':
         assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
         opt.batch_size = opt.total_batch_size // opt.world_size
 
-    # Hyperparameters
-    with open(opt.hyp) as f:
-        hyp = yaml.safe_load(f)  # load hyps
-
     # Train
     logger.info(opt)
     if not opt.evolve:
@@ -541,7 +542,7 @@ if __name__ == '__main__':
             prefix = colorstr('tensorboard: ')
             logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
             tb_writer = SummaryWriter(opt.save_dir)  # Tensorboard
-        train(hyp, opt, device, tb_writer)
+        train(opt.hyp, opt, device, tb_writer)
 
     # Evolve hyperparameters (optional)
     else:
@@ -575,6 +576,8 @@ if __name__ == '__main__':
                 'mosaic': (1, 0.0, 1.0),  # image mosaic (probability)
                 'mixup': (1, 0.0, 1.0)}  # image mixup (probability)
 
+        with open(opt.hyp) as f:
+            hyp = yaml.safe_load(f)  # load hyps dict
         assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
        opt.notest, opt.nosave = True, True  # only test/save final epoch
         # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
......
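
The net effect: `train()` now normalizes its `hyp` argument itself, so a plain run can hand it the `--hyp` path unchanged (`train(opt.hyp, ...)`), while the evolve branch, which mutates hyperparameters each generation, loads the dict once and keeps passing dicts. A minimal sketch of the pattern in isolation (the hyperparameter names and values below are illustrative assumptions, not project defaults):

import yaml

def train(hyp):  # path/to/hyp.yaml or hyp dictionary, mirroring the patched signature
    if isinstance(hyp, str):  # a YAML path was passed in
        with open(hyp) as f:
            hyp = yaml.safe_load(f)  # load hyps dict
    print('hyperparameters: ' + ', '.join(f'{k}={v}' for k, v in hyp.items()))

train({'lr0': 0.01, 'momentum': 0.937})  # a dict is used as-is
# train('data/hyp.scratch.yaml')         # a path is loaded first (assumes the file exists)

Doing the normalization inside `train()` removes the duplicated top-level YAML load and keeps a single code path for both kinds of caller.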