Commit ec1d8496 (unverified), authored by Glenn Jocher, committed via GitHub

Improved model+EMA checkpointing (#2292)

* Enhanced model+EMA checkpointing
* update
* bug fix
* bug fix 2
* always save optimizer
* ema half
* remove model.float()
* model half
* carry ema/model in fp32
* rm model.float()
* both to float always
* cleanup
* cleanup
Parent: ca5b10b7
test.py
@@ -272,7 +272,6 @@ def test(data,
     if not training:
         s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
         print(f"Results saved to {save_dir}{s}")
-    model.float()  # for training
     maps = np.zeros(nc) + map
     for i, c in enumerate(ap_class):
         maps[c] = ap[i]
train.py
@@ -31,7 +31,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights
 from utils.google_utils import attempt_download
 from utils.loss import ComputeLoss
 from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
-from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first
+from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel

 logger = logging.getLogger(__name__)
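For reference, is_parallel (newly imported above from utils/torch_utils.py) is the helper that lets the checkpoint code unwrap model.module before saving. A minimal sketch of what it checks, assuming the standard torch.nn.parallel wrappers:

    import torch.nn as nn

    def is_parallel(model):
        # True if the model is wrapped in DataParallel or DistributedDataParallel
        return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)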
@@ -136,6 +136,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                        id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
         loggers = {'wandb': wandb}  # loggers dict

+    # EMA
+    ema = ModelEMA(model) if rank in [-1, 0] else None
+
     # Resume
     start_epoch, best_fitness = 0, 0.0
     if pretrained:
@@ -144,6 +147,11 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
             optimizer.load_state_dict(ckpt['optimizer'])
             best_fitness = ckpt['best_fitness']

+        # EMA
+        if ema and ckpt.get('ema'):
+            ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
+            ema.updates = ckpt['ema'][1]
+
         # Results
         if ckpt.get('training_results') is not None:
             results_file.write_text(ckpt['training_results'])  # write results.txt
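With this change a checkpoint carries the EMA as a (model, updates) tuple under its own 'ema' key instead of being folded into 'model'. A minimal sketch of how such a checkpoint could be consumed on resume, assuming the checkpoint layout above; the helper name resume_from is hypothetical:

    import torch

    def resume_from(ckpt_path, model, ema, optimizer):
        # Hypothetical helper illustrating the new checkpoint layout
        ckpt = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(ckpt['model'].float().state_dict())  # FP16 on disk -> FP32 for training
        if ckpt.get('ema'):
            ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())  # restore EMA weights
            ema.updates = ckpt['ema'][1]                                  # restore decay-schedule position
        optimizer.load_state_dict(ckpt['optimizer'])                      # optimizer is now always saved
        return ckpt['epoch'] + 1, ckpt['best_fitness']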
@@ -173,9 +181,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
         logger.info('Using SyncBatchNorm()')

-    # EMA
-    ema = ModelEMA(model) if rank in [-1, 0] else None
-
     # DDP mode
     if cuda and rank != -1:
         model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
@@ -191,7 +196,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     # Process 0
     if rank in [-1, 0]:
-        ema.updates = start_epoch * nb // accumulate  # set EMA updates
         testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt,  # testloader
                                        hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
                                        world_size=opt.world_size, workers=opt.workers,
@@ -335,8 +339,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         # DDP process 0 or single-GPU
         if rank in [-1, 0]:
             # mAP
-            if ema:
-                ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
+            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
             if not opt.notest or final_epoch:  # Calculate mAP
                 results, maps, times = test.test(opt.data,
@@ -378,8 +381,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
                         'training_results': results_file.read_text(),
-                        'model': ema.ema,
-                        'optimizer': None if final_epoch else optimizer.state_dict(),
+                        'model': (model.module if is_parallel(model) else model).half(),
+                        'ema': (ema.ema.half(), ema.updates),
+                        'optimizer': optimizer.state_dict(),
                         'wandb_id': wandb_run.id if wandb else None}

                 # Save last, best and delete
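The checkpoint now stores the unwrapped training model and the EMA side by side, both cast to FP16 to roughly halve file size, and the optimizer state is written every epoch rather than being dropped on the final one. A quick way to inspect the new layout, with a hypothetical run path:

    import torch

    ckpt = torch.load('runs/train/exp/weights/last.pt', map_location='cpu')  # hypothetical path
    print(list(ckpt))  # ['epoch', 'best_fitness', 'training_results', 'model', 'ema', 'optimizer', 'wandb_id']
    print(next(ckpt['model'].parameters()).dtype)  # torch.float16, saved via .half()
    ema_model, ema_updates = ckpt['ema']           # (FP16 EMA model, update counter)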
@@ -387,6 +391,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 if best_fitness == fi:
                     torch.save(ckpt, best)
                 del ckpt
+
+            model.float(), ema.ema.float()
+
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training
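Note that .half() and .float() convert module parameters in place, so building the checkpoint above also converts the live training model and EMA to FP16; the added model.float(), ema.ema.float() line restores FP32 before the next epoch. A standalone illustration:

    import torch
    import torch.nn as nn

    m = nn.Linear(4, 4)
    m.half()   # in-place: parameters become torch.float16
    assert next(m.parameters()).dtype == torch.float16
    m.float()  # restore FP32 so subsequent training math stays in full precision
    assert next(m.parameters()).dtype == torch.float32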
utils/general.py
@@ -484,8 +484,8 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None
 def strip_optimizer(f='weights/best.pt', s=''):  # from utils.general import *; strip_optimizer()
     # Strip optimizer from 'f' to finalize training, optionally save as 's'
     x = torch.load(f, map_location=torch.device('cpu'))
-    for key in 'optimizer', 'training_results', 'wandb_id':
-        x[key] = None
+    for k in 'optimizer', 'training_results', 'wandb_id', 'ema':  # keys
+        x[k] = None
     x['epoch'] = -1
     x['model'].half()  # to FP16
     for p in x['model'].parameters():
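Since 'ema' is now a top-level checkpoint key, strip_optimizer nulls it alongside the optimizer state and training results when finalizing weights. A typical post-training call, with a hypothetical run path:

    from utils.general import strip_optimizer

    strip_optimizer('runs/train/exp/weights/best.pt')  # strips optimizer/EMA/results in place, casts model to FP16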