Unverified commit fbf41e09, authored by Glenn Jocher, committed by GitHub

Add `train.run()` method (#3700)

* Update train.py explicit arguments
* Update train.py
* Add run method (usage sketch below)
Parent: c1af67dc
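
The "Add run method" item introduces a programmatic entry point, so train.py can be driven from Python as well as from the CLI. A minimal usage sketch, mirroring the usage comment inside the new run() function (keywords must match existing parse_opt() flags such as imgsz and weights; the values shown are illustrative):

    import train

    # Each keyword overrides the matching argparse default from parse_opt(),
    # e.g. train at image size 320 starting from the yolov5m.pt checkpoint.
    train.run(imgsz=320, weights='yolov5m.pt')

The same call works from a notebook or another script, because parse_opt(known=True) uses parse_known_args() and therefore tolerates argv tokens it does not recognize.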
@@ -46,8 +46,9 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
           ):
-    save_dir, epochs, batch_size, weights, single_cls = \
-        opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls
+    save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, notest, nosave, workers, = \
+        opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
+        opt.resume, opt.notest, opt.nosave, opt.workers

     # Directories
     save_dir = Path(save_dir)
@@ -70,34 +71,34 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         yaml.safe_dump(vars(opt), f, sort_keys=False)

     # Configure
-    plots = not opt.evolve  # create plots
+    plots = not evolve  # create plots
     cuda = device.type != 'cpu'
     init_seeds(2 + RANK)
-    with open(opt.data) as f:
+    with open(data) as f:
         data_dict = yaml.safe_load(f)  # data dict

     # Loggers
     loggers = {'wandb': None, 'tb': None}  # loggers dict
     if RANK in [-1, 0]:
         # TensorBoard
-        if not opt.evolve:
+        if not evolve:
             prefix = colorstr('tensorboard: ')
             logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
-            loggers['tb'] = SummaryWriter(opt.save_dir)
+            loggers['tb'] = SummaryWriter(str(save_dir))

         # W&B
         opt.hyp = hyp  # add hyperparameters
         run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
         wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
         loggers['wandb'] = wandb_logger.wandb
-        data_dict = wandb_logger.data_dict
-        if wandb_logger.wandb:
+        if loggers['wandb']:
+            data_dict = wandb_logger.data_dict
             weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp  # may update weights, epochs if resuming

     nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
     names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
-    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check
-    is_coco = opt.data.endswith('coco.yaml') and nc == 80  # COCO dataset
+    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data)  # check
+    is_coco = data.endswith('coco.yaml') and nc == 80  # COCO dataset

     # Model
     pretrained = weights.endswith('.pt')
@@ -105,14 +106,14 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         with torch_distributed_zero_first(RANK):
             weights = attempt_download(weights)  # download if not found locally
         ckpt = torch.load(weights, map_location=device)  # load checkpoint
-        model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
-        exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
+        model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+        exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
         state_dict = ckpt['model'].float().state_dict()  # to FP32
         state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
         model.load_state_dict(state_dict, strict=False)  # load
         logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
     else:
-        model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+        model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
     with torch_distributed_zero_first(RANK):
         check_dataset(data_dict)  # check
     train_path = data_dict['train']
@@ -182,7 +183,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         # Epochs
         start_epoch = ckpt['epoch'] + 1
-        if opt.resume:
+        if resume:
             assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
         if epochs < start_epoch:
             logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
@@ -210,20 +211,20 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # Trainloader
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
                                             hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
-                                            workers=opt.workers,
+                                            workers=workers,
                                             image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
     nb = len(dataloader)  # number of batches
-    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
+    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, data, nc - 1)

     # Process 0
     if RANK in [-1, 0]:
         testloader = create_dataloader(test_path, imgsz_test, batch_size // WORLD_SIZE * 2, gs, single_cls,
-                                       hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
-                                       workers=opt.workers,
+                                       hyp=hyp, cache=opt.cache_images and not notest, rect=True, rank=-1,
+                                       workers=workers,
                                        pad=0.5, prefix=colorstr('val: '))[0]

-        if not opt.resume:
+        if not resume:
             labels = np.concatenate(dataset.labels, 0)
             c = torch.tensor(labels[:, 0])  # classes
             # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
@@ -356,8 +357,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                     with warnings.catch_warnings():
                         warnings.simplefilter('ignore')  # suppress jit trace warning
                         loggers['tb'].add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
-                elif plots and ni == 10 and wandb_logger.wandb:
-                    wandb_logger.log({'Mosaics': [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
+                elif plots and ni == 10 and loggers['wandb']:
+                    wandb_logger.log({'Mosaics': [loggers['wandb'].Image(str(x), caption=x.name) for x in
                                                   save_dir.glob('train*.jpg') if x.exists()]})

             # end batch ------------------------------------------------------------------------------------------------
@@ -371,7 +372,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             # mAP
             ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
-            if not opt.notest or final_epoch:  # Calculate mAP
+            if not notest or final_epoch:  # Calculate mAP
                 wandb_logger.current_epoch = epoch + 1
                 results, maps, _ = test.test(data_dict,
                                              batch_size=batch_size // WORLD_SIZE * 2,
@@ -398,7 +399,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                 if loggers['tb']:
                     loggers['tb'].add_scalar(tag, x, epoch)  # TensorBoard
-                if wandb_logger.wandb:
+                if loggers['wandb']:
                     wandb_logger.log({tag: x})  # W&B

             # Update best mAP
@@ -408,7 +409,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                 wandb_logger.end_epoch(best_result=best_fitness == fi)

             # Save model
-            if (not opt.nosave) or (final_epoch and not opt.evolve):  # if save
+            if (not nosave) or (final_epoch and not evolve):  # if save
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
                         'training_results': results_file.read_text(),
@@ -416,13 +417,13 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                         'ema': deepcopy(ema.ema).half(),
                         'updates': ema.updates,
                         'optimizer': optimizer.state_dict(),
-                        'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}
+                        'wandb_id': wandb_logger.wandb_run.id if loggers['wandb'] else None}

                 # Save last, best and delete
                 torch.save(ckpt, last)
                 if best_fitness == fi:
                     torch.save(ckpt, best)
-                if wandb_logger.wandb:
+                if loggers['wandb']:
                     if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
                         wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
                 del ckpt
@@ -433,15 +434,15 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
         if plots:
             plot_results(save_dir=save_dir)  # save as results.png
-            if wandb_logger.wandb:
+            if loggers['wandb']:
                 files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
-                wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
+                wandb_logger.log({"Results": [loggers['wandb'].Image(str(save_dir / f), caption=f) for f in files
                                               if (save_dir / f).exists()]})

-        if not opt.evolve:
+        if not evolve:
            if is_coco:  # COCO dataset
                for m in [last, best] if best.exists() else [last]:  # speed, mAP tests
-                    results, _, _ = test.test(opt.data,
+                    results, _, _ = test.test(data,
                                               batch_size=batch_size // WORLD_SIZE * 2,
                                               imgsz=imgsz_test,
                                               conf_thres=0.001,
@@ -457,8 +458,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             for f in last, best:
                 if f.exists():
                     strip_optimizer(f)  # strip optimizers
-                if wandb_logger.wandb:  # Log the stripped model
-                    wandb_logger.wandb.log_artifact(str(best if best.exists() else last), type='model',
+                if loggers['wandb']:  # Log the stripped model
+                    loggers['wandb'].log_artifact(str(best if best.exists() else last), type='model',
                                                     name='run_' + wandb_logger.wandb_run.id + '_model',
                                                     aliases=['latest', 'best', 'stripped'])
         wandb_logger.finish_run()
@@ -467,7 +468,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     return results


-def parse_opt():
+def parse_opt(known=False):
     parser = argparse.ArgumentParser()
     parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
     parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
@@ -503,7 +504,7 @@ def parse_opt():
     parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
    parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
     parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
-    opt = parser.parse_args()
+    opt = parser.parse_known_args()[0] if known else parser.parse_args()
     return opt
@@ -633,6 +634,14 @@ def main(opt):
                 f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')


+def run(**kwargs):
+    # Usage: import train; train.run(imgsz=320, weights='yolov5m.pt')
+    opt = parse_opt(True)
+    for k, v in kwargs.items():
+        setattr(opt, k, v)
+    main(opt)
+
+
 if __name__ == "__main__":
     opt = parse_opt()
     main(opt)
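
For reference, the override pattern that run() relies on is just parse_known_args() plus setattr: parse the argparse defaults without failing on unknown CLI tokens, then overwrite individual fields with the supplied keywords. A self-contained sketch of that pattern follows; the two flags are stand-ins for the full argument list in parse_opt(), not the real set.

    import argparse

    def parse_opt(known=False):
        # Stand-in for train.parse_opt(): only two representative flags.
        parser = argparse.ArgumentParser()
        parser.add_argument('--weights', type=str, default='yolov5s.pt')
        parser.add_argument('--imgsz', type=int, default=640)
        # parse_known_args() returns (namespace, leftover_args) and does not error
        # on unrecognized tokens, which matters when argv belongs to a notebook or host script.
        return parser.parse_known_args()[0] if known else parser.parse_args()

    def run(**kwargs):
        # Same shape as the new train.run(): apply keyword overrides on top of the defaults.
        opt = parse_opt(True)
        for k, v in kwargs.items():
            setattr(opt, k, v)
        return opt

    print(run(imgsz=320, weights='yolov5m.pt'))
    # -> Namespace(weights='yolov5m.pt', imgsz=320)

Note that setattr also silently adds keys that have no corresponding flag, so typos in keyword names are not caught.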