Unverified 提交 4b52e19a authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

COCO evolution fix (#3388)

* COCO evolution fix * cleanup * update print * print fix
上级 21a9607e
...@@ -62,7 +62,6 @@ def train(hyp, opt, device, tb_writer=None): ...@@ -62,7 +62,6 @@ def train(hyp, opt, device, tb_writer=None):
init_seeds(2 + rank) init_seeds(2 + rank)
with open(opt.data) as f: with open(opt.data) as f:
data_dict = yaml.safe_load(f) # data dict data_dict = yaml.safe_load(f) # data dict
is_coco = opt.data.endswith('coco.yaml')
# Logging- Doing this before checking the dataset. Might update data_dict # Logging- Doing this before checking the dataset. Might update data_dict
loggers = {'wandb': None} # loggers dict loggers = {'wandb': None} # loggers dict
...@@ -78,6 +77,7 @@ def train(hyp, opt, device, tb_writer=None): ...@@ -78,6 +77,7 @@ def train(hyp, opt, device, tb_writer=None):
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
is_coco = opt.data.endswith('coco.yaml') and nc == 80 # COCO dataset
# Model # Model
pretrained = weights.endswith('.pt') pretrained = weights.endswith('.pt')
...@@ -358,6 +358,7 @@ def train(hyp, opt, device, tb_writer=None): ...@@ -358,6 +358,7 @@ def train(hyp, opt, device, tb_writer=None):
single_cls=opt.single_cls, single_cls=opt.single_cls,
dataloader=testloader, dataloader=testloader,
save_dir=save_dir, save_dir=save_dir,
save_json=is_coco and final_epoch,
verbose=nc < 50 and final_epoch, verbose=nc < 50 and final_epoch,
plots=plots and final_epoch, plots=plots and final_epoch,
wandb_logger=wandb_logger, wandb_logger=wandb_logger,
...@@ -409,41 +410,38 @@ def train(hyp, opt, device, tb_writer=None): ...@@ -409,41 +410,38 @@ def train(hyp, opt, device, tb_writer=None):
# end epoch ---------------------------------------------------------------------------------------------------- # end epoch ----------------------------------------------------------------------------------------------------
# end training # end training
if rank in [-1, 0]: if rank in [-1, 0]:
# Plots logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
if plots: if plots:
plot_results(save_dir=save_dir) # save as results.png plot_results(save_dir=save_dir) # save as results.png
if wandb_logger.wandb: if wandb_logger.wandb:
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]] files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
if (save_dir / f).exists()]}) if (save_dir / f).exists()]})
# Test best.pt
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600)) if not opt.evolve:
if opt.data.endswith('coco.yaml') and nc == 80: # if COCO if is_coco: # COCO dataset
for m in [last, best] if best.exists() else [last]: # speed, mAP tests for m in [last, best] if best.exists() else [last]: # speed, mAP tests
results, _, _ = test.test(opt.data, results, _, _ = test.test(opt.data,
batch_size=batch_size * 2, batch_size=batch_size * 2,
imgsz=imgsz_test, imgsz=imgsz_test,
conf_thres=0.001, conf_thres=0.001,
iou_thres=0.7, iou_thres=0.7,
model=attempt_load(m, device).half(), model=attempt_load(m, device).half(),
single_cls=opt.single_cls, single_cls=opt.single_cls,
dataloader=testloader, dataloader=testloader,
save_dir=save_dir, save_dir=save_dir,
save_json=True, save_json=True,
plots=False, plots=False,
is_coco=is_coco) is_coco=is_coco)
# Strip optimizers # Strip optimizers
final = best if best.exists() else last # final model for f in last, best:
for f in last, best: if f.exists():
if f.exists(): strip_optimizer(f) # strip optimizers
strip_optimizer(f) # strip optimizers if wandb_logger.wandb: # Log the stripped model
if opt.bucket: wandb_logger.wandb.log_artifact(str(best if best.exists() else last), type='model',
os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload name='run_' + wandb_logger.wandb_run.id + '_model',
if wandb_logger.wandb and not opt.evolve: # Log the stripped model aliases=['latest', 'best', 'stripped'])
wandb_logger.wandb.log_artifact(str(final), type='model',
name='run_' + wandb_logger.wandb_run.id + '_model',
aliases=['latest', 'best', 'stripped'])
wandb_logger.finish_run() wandb_logger.finish_run()
else: else:
dist.destroy_process_group() dist.destroy_process_group()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论