Unverified commit 1bf93652, authored by Ayush Chaurasia, committed by GitHub

W&B DDP fix (#2574)

Parent commit: 0d891c60
...@@ -66,14 +66,16 @@ def train(hyp, opt, device, tb_writer=None): ...@@ -66,14 +66,16 @@ def train(hyp, opt, device, tb_writer=None):
is_coco = opt.data.endswith('coco.yaml') is_coco = opt.data.endswith('coco.yaml')
# Logging- Doing this before checking the dataset. Might update data_dict # Logging- Doing this before checking the dataset. Might update data_dict
loggers = {'wandb': None} # loggers dict
if rank in [-1, 0]: if rank in [-1, 0]:
opt.hyp = hyp # add hyperparameters opt.hyp = hyp # add hyperparameters
run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict) wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
loggers['wandb'] = wandb_logger.wandb
data_dict = wandb_logger.data_dict data_dict = wandb_logger.data_dict
if wandb_logger.wandb: if wandb_logger.wandb:
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
loggers = {'wandb': wandb_logger.wandb} # loggers dict
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
...@@ -381,6 +383,7 @@ def train(hyp, opt, device, tb_writer=None): ...@@ -381,6 +383,7 @@ def train(hyp, opt, device, tb_writer=None):
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95] fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
if fi > best_fitness: if fi > best_fitness:
best_fitness = fi best_fitness = fi
wandb_logger.end_epoch(best_result=best_fitness == fi)
# Save model # Save model
if (not opt.nosave) or (final_epoch and not opt.evolve): # if save if (not opt.nosave) or (final_epoch and not opt.evolve): # if save
...@@ -402,7 +405,6 @@ def train(hyp, opt, device, tb_writer=None): ...@@ -402,7 +405,6 @@ def train(hyp, opt, device, tb_writer=None):
wandb_logger.log_model( wandb_logger.log_model(
last.parent, opt, epoch, fi, best_model=best_fitness == fi) last.parent, opt, epoch, fi, best_model=best_fitness == fi)
del ckpt del ckpt
wandb_logger.end_epoch(best_result=best_fitness == fi)
# end epoch ---------------------------------------------------------------------------------------------------- # end epoch ----------------------------------------------------------------------------------------------------
# end training # end training
...@@ -442,10 +444,10 @@ def train(hyp, opt, device, tb_writer=None): ...@@ -442,10 +444,10 @@ def train(hyp, opt, device, tb_writer=None):
wandb_logger.wandb.log_artifact(str(final), type='model', wandb_logger.wandb.log_artifact(str(final), type='model',
name='run_' + wandb_logger.wandb_run.id + '_model', name='run_' + wandb_logger.wandb_run.id + '_model',
aliases=['last', 'best', 'stripped']) aliases=['last', 'best', 'stripped'])
wandb_logger.finish_run()
else: else:
dist.destroy_process_group() dist.destroy_process_group()
torch.cuda.empty_cache() torch.cuda.empty_cache()
wandb_logger.finish_run()
return results return results
......
...@@ -16,9 +16,9 @@ from utils.general import colorstr, xywh2xyxy, check_dataset ...@@ -16,9 +16,9 @@ from utils.general import colorstr, xywh2xyxy, check_dataset
try: try:
import wandb import wandb
from wandb import init, finish
except ImportError: except ImportError:
wandb = None wandb = None
print(f"{colorstr('wandb: ')}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
WANDB_ARTIFACT_PREFIX = 'wandb-artifact://' WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
...@@ -71,6 +71,9 @@ class WandbLogger(): ...@@ -71,6 +71,9 @@ class WandbLogger():
self.data_dict = self.setup_training(opt, data_dict) self.data_dict = self.setup_training(opt, data_dict)
if self.job_type == 'Dataset Creation': if self.job_type == 'Dataset Creation':
self.data_dict = self.check_and_upload_dataset(opt) self.data_dict = self.check_and_upload_dataset(opt)
else:
print(f"{colorstr('wandb: ')}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
def check_and_upload_dataset(self, opt): def check_and_upload_dataset(self, opt):
assert wandb, 'Install wandb to upload dataset' assert wandb, 'Install wandb to upload dataset'
......
Markdown formatting is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this comment first!
Register or sign in to comment