Unverified 提交 518c0957 · 作者: Ayush Chaurasia · 提交者: GitHub

W&B resume ddp from run link fix (#2579)

* W&B resume ddp from run link fix * Native DDP W&B support for training, resuming
上级 dc51e80b
...@@ -33,7 +33,7 @@ from utils.google_utils import attempt_download ...@@ -33,7 +33,7 @@ from utils.google_utils import attempt_download
from utils.loss import ComputeLoss from utils.loss import ComputeLoss
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
from utils.wandb_logging.wandb_utils import WandbLogger, resume_and_get_id from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -496,7 +496,7 @@ if __name__ == '__main__': ...@@ -496,7 +496,7 @@ if __name__ == '__main__':
check_requirements() check_requirements()
# Resume # Resume
wandb_run = resume_and_get_id(opt) wandb_run = check_wandb_resume(opt)
if opt.resume and not wandb_run: # resume an interrupted run if opt.resume and not wandb_run: # resume an interrupted run
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist' assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
......
...@@ -23,7 +23,7 @@ except ImportError: ...@@ -23,7 +23,7 @@ except ImportError:
WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'


def remove_prefix(from_string, prefix=WANDB_ARTIFACT_PREFIX):
    """Return ``from_string`` with a leading ``prefix`` removed.

    Args:
        from_string: String to strip.
        prefix: Leading marker to remove; defaults to ``WANDB_ARTIFACT_PREFIX``.

    Returns:
        ``from_string`` without ``prefix`` when it starts with ``prefix``,
        otherwise ``from_string`` unchanged.
    """
    # Slice only when the prefix is really present: an unconditional slice
    # would silently drop the first len(prefix) characters of any other string.
    if from_string.startswith(prefix):
        return from_string[len(prefix):]
    return from_string
...@@ -33,35 +33,73 @@ def check_wandb_config_file(data_config_file): ...@@ -33,35 +33,73 @@ def check_wandb_config_file(data_config_file):
return wandb_config return wandb_config
return data_config_file return data_config_file
def get_run_info(run_path):
    """Split a ``wandb-artifact://`` run link into its components.

    Args:
        run_path: Run link; after ``WANDB_ARTIFACT_PREFIX`` is stripped, the
            last path component is taken as the run id and the one before it
            as the project.

    Returns:
        Tuple ``(run_id, project, model_artifact_name)`` where
        ``model_artifact_name`` is ``'run_' + run_id + '_model'``.
    """
    run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))
    # Use .name, not .stem: .stem drops everything after the last dot, which
    # would truncate a project or run id such as 'my.project' to 'my'.
    run_id = run_path.name
    project = run_path.parent.name
    model_artifact_name = 'run_' + run_id + '_model'
    return run_id, project, model_artifact_name
def check_wandb_resume(opt):
    """Prepare a W&B resume and report whether ``opt.resume`` is a W&B run link.

    On DDP worker processes (``opt.global_rank`` not in ``[-1, 0]``) the data
    config is first localised via ``process_wandb_config_ddp_mode``. If
    ``opt.resume`` is a ``wandb-artifact://`` link, those workers also download
    the run's ``:latest`` model artifact and point ``opt.weights`` at its
    ``last.pt``. For rank -1/0 nothing is downloaded here; resuming the run
    itself appears to be handled in ``WandbLogger.__init__``.

    Args:
        opt: Options object; reads ``global_rank`` and ``resume`` and may
            overwrite ``weights`` (and ``data``, through the helper).

    Returns:
        ``True`` when ``opt.resume`` is a string starting with
        ``WANDB_ARTIFACT_PREFIX``, otherwise ``None``.
    """
    is_ddp_worker = opt.global_rank not in [-1, 0]
    if is_ddp_worker:
        process_wandb_config_ddp_mode(opt)

    if isinstance(opt.resume, str) and opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
        if is_ddp_worker:  # For resuming DDP runs
            _, project, model_artifact_name = get_run_info(opt.resume)
            api = wandb.Api()
            artifact = api.artifact(project + '/' + model_artifact_name + ':latest')
            modeldir = artifact.download()
            opt.weights = str(Path(modeldir) / "last.pt")
        # True for every rank, so the caller skips its local-checkpoint resume.
        return True
    return None
def process_wandb_config_ddp_mode(opt):
    """Replace W&B-artifact dataset links in the data config with local paths.

    Loads the YAML file at ``opt.data``. For each of the ``'train'`` and
    ``'val'`` entries that is a string starting with ``WANDB_ARTIFACT_PREFIX``,
    downloads the artifact version ``opt.artifact_alias`` and rewrites the
    entry to ``<download dir>/data/images``. If anything was downloaded, the
    rewritten config is saved as ``wandb_local_data.yaml`` and ``opt.data`` is
    set to that file; otherwise ``opt`` is left untouched.

    Args:
        opt: Options object; reads ``data`` and ``artifact_alias`` and may
            overwrite ``data``.

    Raises:
        KeyError: If the data config has no ``'train'`` or ``'val'`` entry.
    """
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # data dict

    api = None
    download_dirs = []
    for split in ('train', 'val'):
        source = data_dict[split]
        # Only plain strings can be artifact links; leave any other entry as-is.
        if not (isinstance(source, str) and source.startswith(WANDB_ARTIFACT_PREFIX)):
            continue
        if api is None:  # create the API client once, and only when needed
            api = wandb.Api()
        artifact = api.artifact(remove_prefix(source) + ':' + opt.artifact_alias)
        split_dir = artifact.download()
        download_dirs.append(split_dir)
        data_dict[split] = str(Path(split_dir) / 'data/images/')

    if download_dirs:
        # Prefer the val download dir, but fall back to the train dir so a
        # config with only an artifact train split no longer hits Path(None).
        ddp_data_path = str(Path(download_dirs[-1]) / 'wandb_local_data.yaml')
        with open(ddp_data_path, 'w') as f:
            yaml.dump(data_dict, f)
        opt.data = ddp_data_path
class WandbLogger(): class WandbLogger():
def __init__(self, opt, name, run_id, data_dict, job_type='Training'): def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
# Pre-training routine -- # Pre-training routine --
self.job_type = job_type self.job_type = job_type
self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
if self.wandb: # It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
if isinstance(opt.resume, str): # checks resume from artifact
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
run_id, project, model_artifact_name = get_run_info(opt.resume)
model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name
assert wandb, 'install wandb to resume wandb runs'
# Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
self.wandb_run = wandb.init(id=run_id, project=project, resume='allow')
opt.resume = model_artifact_name
elif self.wandb:
self.wandb_run = wandb.init(config=opt, self.wandb_run = wandb.init(config=opt,
resume="allow", resume="allow",
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem, project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
name=name, name=name,
job_type=job_type, job_type=job_type,
id=run_id) if not wandb.run else wandb.run id=run_id) if not wandb.run else wandb.run
if self.wandb_run:
if self.job_type == 'Training': if self.job_type == 'Training':
if not opt.resume: if not opt.resume:
wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论