Unverified 提交 fa201f96 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Update `train(hyp, *args)` to accept `hyp` file or dict (#3668)

上级 6d6e2ca6
......@@ -39,12 +39,11 @@ from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
logger = logging.getLogger(__name__)
def train(hyp,
def train(hyp, # path/to/hyp.yaml or hyp dictionary
opt,
device,
tb_writer=None
):
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
opt.single_cls
......@@ -56,6 +55,12 @@ def train(hyp,
best = wdir / 'best.pt'
results_file = save_dir / 'results.txt'
# Hyperparameters
if isinstance(hyp, str):
with open(hyp) as f:
hyp = yaml.safe_load(f) # load hyps dict
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
# Save run settings
with open(save_dir / 'hyp.yaml', 'w') as f:
yaml.safe_dump(hyp, f, sort_keys=False)
......@@ -529,10 +534,6 @@ if __name__ == '__main__':
assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
opt.batch_size = opt.total_batch_size // opt.world_size
# Hyperparameters
with open(opt.hyp) as f:
hyp = yaml.safe_load(f) # load hyps
# Train
logger.info(opt)
if not opt.evolve:
......@@ -541,7 +542,7 @@ if __name__ == '__main__':
prefix = colorstr('tensorboard: ')
logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
train(hyp, opt, device, tb_writer)
train(opt.hyp, opt, device, tb_writer)
# Evolve hyperparameters (optional)
else:
......@@ -575,6 +576,8 @@ if __name__ == '__main__':
'mosaic': (1, 0.0, 1.0), # image mixup (probability)
'mixup': (1, 0.0, 1.0)} # image mixup (probability)
with open(opt.hyp) as f:
hyp = yaml.safe_load(f) # load hyps dict
assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
opt.notest, opt.nosave = True, True # only test/save final epoch
# ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论