Unverified 提交 ca5b10b7 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Update train.py (#2290)

* Update train.py * Update train.py * Update train.py * Update train.py * Create train.py
上级 0070995b
......@@ -146,8 +146,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# Results
if ckpt.get('training_results') is not None:
with open(results_file, 'w') as file:
file.write(ckpt['training_results']) # write results.txt
results_file.write_text(ckpt['training_results']) # write results.txt
# Epochs
start_epoch = ckpt['epoch'] + 1
......@@ -354,7 +353,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# Write
with open(results_file, 'a') as f:
f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
f.write(s + '%10.4g' * 7 % results + '\n') # append metrics, val_loss
if len(opt.name) and opt.bucket:
os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
......@@ -375,15 +374,13 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
best_fitness = fi
# Save model
save = (not opt.nosave) or (final_epoch and not opt.evolve)
if save:
with open(results_file, 'r') as f: # create checkpoint
ckpt = {'epoch': epoch,
'best_fitness': best_fitness,
'training_results': f.read(),
'model': ema.ema,
'optimizer': None if final_epoch else optimizer.state_dict(),
'wandb_id': wandb_run.id if wandb else None}
if (not opt.nosave) or (final_epoch and not opt.evolve): # if save
ckpt = {'epoch': epoch,
'best_fitness': best_fitness,
'training_results': results_file.read_text(),
'model': ema.ema,
'optimizer': None if final_epoch else optimizer.state_dict(),
'wandb_id': wandb_run.id if wandb else None}
# Save last, best and delete
torch.save(ckpt, last)
......@@ -396,9 +393,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
if rank in [-1, 0]:
# Strip optimizers
final = best if best.exists() else last # final model
for f in [last, best]:
for f in last, best:
if f.exists():
strip_optimizer(f) # strip optimizers
strip_optimizer(f)
if opt.bucket:
os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
......@@ -415,17 +412,17 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# Test best.pt
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
for conf, iou, save_json in ([0.25, 0.45, False], [0.001, 0.65, True]): # speed, mAP tests
for m in (last, best) if best.exists() else (last): # speed, mAP tests
results, _, _ = test.test(opt.data,
batch_size=batch_size * 2,
imgsz=imgsz_test,
conf_thres=conf,
iou_thres=iou,
model=attempt_load(final, device).half(),
conf_thres=0.001,
iou_thres=0.7,
model=attempt_load(m, device).half(),
single_cls=opt.single_cls,
dataloader=testloader,
save_dir=save_dir,
save_json=save_json,
save_json=True,
plots=False)
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论