Unverified 提交 7316b78e authored 作者: Ayush Chaurasia's avatar Ayush Chaurasia 提交者: GitHub

W&B: Refactor the wandb_utils.py file (#4496)

* Improve docstrings and run names * default wandb login prompt with timeout * return key * Update api_key check logic * Properly support zipped dataset feature * update docstring * Revert tuorial change * extend changes to log_dataset * add run name * bug fix * bug fix * Update comment * fix import check * remove unused import * Hardcore .yaml file extension * reduce code * Reformat using pycharm * Remove redundant try catch * More refactoring and bug fixes * retry * Reformat using pycharm * respect LOGGERS include list * Fix * fix * refactor constructor * refactor * refactor * refactor * PyCharm reformat Co-authored-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 d1182c4f
...@@ -38,6 +38,19 @@ def check_wandb_config_file(data_config_file): ...@@ -38,6 +38,19 @@ def check_wandb_config_file(data_config_file):
return data_config_file return data_config_file
def check_wandb_dataset(data_file):
is_wandb_artifact = False
if check_file(data_file) and data_file.endswith('.yaml'):
with open(data_file, errors='ignore') as f:
data_dict = yaml.safe_load(f)
is_wandb_artifact = (data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX) or
data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX))
if is_wandb_artifact:
return data_dict
else:
return check_dataset(data_file)
def get_run_info(run_path): def get_run_info(run_path):
run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX)) run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))
run_id = run_path.stem run_id = run_path.stem
...@@ -104,7 +117,7 @@ class WandbLogger(): ...@@ -104,7 +117,7 @@ class WandbLogger():
- Initialize WandbLogger instance - Initialize WandbLogger instance
- Upload dataset if opt.upload_dataset is True - Upload dataset if opt.upload_dataset is True
- Setup trainig processes if job_type is 'Training' - Setup trainig processes if job_type is 'Training'
arguments: arguments:
opt (namespace) -- Commandline arguments for this run opt (namespace) -- Commandline arguments for this run
run_id (str) -- Run ID of W&B run to be resumed run_id (str) -- Run ID of W&B run to be resumed
...@@ -147,26 +160,24 @@ class WandbLogger(): ...@@ -147,26 +160,24 @@ class WandbLogger():
allow_val_change=True) if not wandb.run else wandb.run allow_val_change=True) if not wandb.run else wandb.run
if self.wandb_run: if self.wandb_run:
if self.job_type == 'Training': if self.job_type == 'Training':
if not opt.resume: if opt.upload_dataset:
if opt.upload_dataset: if not opt.resume:
self.wandb_artifact_data_dict = self.check_and_upload_dataset(opt) self.wandb_artifact_data_dict = self.check_and_upload_dataset(opt)
elif opt.data.endswith('_wandb.yaml'): # When dataset is W&B artifact if opt.resume:
with open(opt.data, errors='ignore') as f: # resume from artifact
data_dict = yaml.safe_load(f) if isinstance(opt.resume, str) and opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
self.data_dict = data_dict self.data_dict = dict(self.wandb_run.config.data_dict)
else: # Local .yaml dataset file or .zip file else: # local resume
self.data_dict = check_dataset(opt.data) self.data_dict = check_wandb_dataset(opt.data)
else: else:
self.data_dict = check_dataset(opt.data) self.data_dict = check_wandb_dataset(opt.data)
self.wandb_artifact_data_dict = self.wandb_artifact_data_dict or self.data_dict
self.setup_training(opt) # write data_dict to config. useful for resuming from artifacts. Do this only when not resuming.
if not self.wandb_artifact_data_dict:
self.wandb_artifact_data_dict = self.data_dict
# write data_dict to config. useful for resuming from artifacts. Do this only when not resuming.
if not opt.resume:
self.wandb_run.config.update({'data_dict': self.wandb_artifact_data_dict}, self.wandb_run.config.update({'data_dict': self.wandb_artifact_data_dict},
allow_val_change=True) allow_val_change=True)
self.setup_training(opt)
if self.job_type == 'Dataset Creation': if self.job_type == 'Dataset Creation':
self.data_dict = self.check_and_upload_dataset(opt) self.data_dict = self.check_and_upload_dataset(opt)
...@@ -174,10 +185,10 @@ class WandbLogger(): ...@@ -174,10 +185,10 @@ class WandbLogger():
def check_and_upload_dataset(self, opt): def check_and_upload_dataset(self, opt):
""" """
Check if the dataset format is compatible and upload it as W&B artifact Check if the dataset format is compatible and upload it as W&B artifact
arguments: arguments:
opt (namespace)-- Commandline arguments for current run opt (namespace)-- Commandline arguments for current run
returns: returns:
Updated dataset info dictionary where local dataset paths are replaced by WAND_ARFACT_PREFIX links. Updated dataset info dictionary where local dataset paths are replaced by WAND_ARFACT_PREFIX links.
""" """
...@@ -196,10 +207,10 @@ class WandbLogger(): ...@@ -196,10 +207,10 @@ class WandbLogger():
- Attempt to download model checkpoint and dataset artifacts if opt.resume stats with WANDB_ARTIFACT_PREFIX - Attempt to download model checkpoint and dataset artifacts if opt.resume stats with WANDB_ARTIFACT_PREFIX
- Update data_dict, to contain info of previous run if resumed and the paths of dataset artifact if downloaded - Update data_dict, to contain info of previous run if resumed and the paths of dataset artifact if downloaded
- Setup log_dict, initialize bbox_interval - Setup log_dict, initialize bbox_interval
arguments: arguments:
opt (namespace) -- commandline arguments for this run opt (namespace) -- commandline arguments for this run
""" """
self.log_dict, self.current_epoch = {}, 0 self.log_dict, self.current_epoch = {}, 0
self.bbox_interval = opt.bbox_interval self.bbox_interval = opt.bbox_interval
...@@ -211,9 +222,7 @@ class WandbLogger(): ...@@ -211,9 +222,7 @@ class WandbLogger():
opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str( opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str(
self.weights), config.save_period, config.batch_size, config.bbox_interval, config.epochs, \ self.weights), config.save_period, config.batch_size, config.bbox_interval, config.epochs, \
config.hyp config.hyp
data_dict = dict(self.wandb_run.config.data_dict) # eliminates the need for config file to resume data_dict = self.data_dict
else:
data_dict = self.data_dict
if self.val_artifact is None: # If --upload_dataset is set, use the existing artifact, don't download if self.val_artifact is None: # If --upload_dataset is set, use the existing artifact, don't download
self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'), self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'),
opt.artifact_alias) opt.artifact_alias)
...@@ -243,11 +252,11 @@ class WandbLogger(): ...@@ -243,11 +252,11 @@ class WandbLogger():
def download_dataset_artifact(self, path, alias): def download_dataset_artifact(self, path, alias):
""" """
download the model checkpoint artifact if the path starts with WANDB_ARTIFACT_PREFIX download the model checkpoint artifact if the path starts with WANDB_ARTIFACT_PREFIX
arguments: arguments:
path -- path of the dataset to be used for training path -- path of the dataset to be used for training
alias (str)-- alias of the artifact to be download/used for training alias (str)-- alias of the artifact to be download/used for training
returns: returns:
(str, wandb.Artifact) -- path of the downladed dataset and it's corresponding artifact object if dataset (str, wandb.Artifact) -- path of the downladed dataset and it's corresponding artifact object if dataset
is found otherwise returns (None, None) is found otherwise returns (None, None)
...@@ -263,7 +272,7 @@ class WandbLogger(): ...@@ -263,7 +272,7 @@ class WandbLogger():
def download_model_artifact(self, opt): def download_model_artifact(self, opt):
""" """
download the model checkpoint artifact if the resume path starts with WANDB_ARTIFACT_PREFIX download the model checkpoint artifact if the resume path starts with WANDB_ARTIFACT_PREFIX
arguments: arguments:
opt (namespace) -- Commandline arguments for this run opt (namespace) -- Commandline arguments for this run
""" """
...@@ -281,7 +290,7 @@ class WandbLogger(): ...@@ -281,7 +290,7 @@ class WandbLogger():
def log_model(self, path, opt, epoch, fitness_score, best_model=False): def log_model(self, path, opt, epoch, fitness_score, best_model=False):
""" """
Log the model checkpoint as W&B artifact Log the model checkpoint as W&B artifact
arguments: arguments:
path (Path) -- Path of directory containing the checkpoints path (Path) -- Path of directory containing the checkpoints
opt (namespace) -- Command line arguments for this run opt (namespace) -- Command line arguments for this run
...@@ -305,14 +314,14 @@ class WandbLogger(): ...@@ -305,14 +314,14 @@ class WandbLogger():
def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False): def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
""" """
Log the dataset as W&B artifact and return the new data file with W&B links Log the dataset as W&B artifact and return the new data file with W&B links
arguments: arguments:
data_file (str) -- the .yaml file with information about the dataset like - path, classes etc. data_file (str) -- the .yaml file with information about the dataset like - path, classes etc.
single_class (boolean) -- train multi-class data as single-class single_class (boolean) -- train multi-class data as single-class
project (str) -- project name. Used to construct the artifact path project (str) -- project name. Used to construct the artifact path
overwrite_config (boolean) -- overwrites the data.yaml file if set to true otherwise creates a new overwrite_config (boolean) -- overwrites the data.yaml file if set to true otherwise creates a new
file with _wandb postfix. Eg -> data_wandb.yaml file with _wandb postfix. Eg -> data_wandb.yaml
returns: returns:
the new .yaml file with artifact links. it can be used to start training directly from artifacts the new .yaml file with artifact links. it can be used to start training directly from artifacts
""" """
...@@ -359,12 +368,12 @@ class WandbLogger(): ...@@ -359,12 +368,12 @@ class WandbLogger():
def create_dataset_table(self, dataset, class_to_id, name='dataset'): def create_dataset_table(self, dataset, class_to_id, name='dataset'):
""" """
Create and return W&B artifact containing W&B Table of the dataset. Create and return W&B artifact containing W&B Table of the dataset.
arguments: arguments:
dataset (LoadImagesAndLabels) -- instance of LoadImagesAndLabels class used to iterate over the data to build Table dataset (LoadImagesAndLabels) -- instance of LoadImagesAndLabels class used to iterate over the data to build Table
class_to_id (dict(int, str)) -- hash map that maps class ids to labels class_to_id (dict(int, str)) -- hash map that maps class ids to labels
name (str) -- name of the artifact name (str) -- name of the artifact
returns: returns:
dataset artifact to be logged or used dataset artifact to be logged or used
""" """
...@@ -401,7 +410,7 @@ class WandbLogger(): ...@@ -401,7 +410,7 @@ class WandbLogger():
def log_training_progress(self, predn, path, names): def log_training_progress(self, predn, path, names):
""" """
Build evaluation Table. Uses reference from validation dataset table. Build evaluation Table. Uses reference from validation dataset table.
arguments: arguments:
predn (list): list of predictions in the native space in the format - [xmin, ymin, xmax, ymax, confidence, class] predn (list): list of predictions in the native space in the format - [xmin, ymin, xmax, ymax, confidence, class]
path (str): local path of the current evaluation image path (str): local path of the current evaluation image
...@@ -431,7 +440,7 @@ class WandbLogger(): ...@@ -431,7 +440,7 @@ class WandbLogger():
def val_one_image(self, pred, predn, path, names, im): def val_one_image(self, pred, predn, path, names, im):
""" """
Log validation data for one image. updates the result Table if validation dataset is uploaded and log bbox media panel Log validation data for one image. updates the result Table if validation dataset is uploaded and log bbox media panel
arguments: arguments:
pred (list): list of scaled predictions in the format - [xmin, ymin, xmax, ymax, confidence, class] pred (list): list of scaled predictions in the format - [xmin, ymin, xmax, ymax, confidence, class]
predn (list): list of predictions in the native space - [xmin, ymin, xmax, ymax, confidence, class] predn (list): list of predictions in the native space - [xmin, ymin, xmax, ymax, confidence, class]
...@@ -453,7 +462,7 @@ class WandbLogger(): ...@@ -453,7 +462,7 @@ class WandbLogger():
def log(self, log_dict): def log(self, log_dict):
""" """
save the metrics to the logging dictionary save the metrics to the logging dictionary
arguments: arguments:
log_dict (Dict) -- metrics/media to be logged in current step log_dict (Dict) -- metrics/media to be logged in current step
""" """
...@@ -464,7 +473,7 @@ class WandbLogger(): ...@@ -464,7 +473,7 @@ class WandbLogger():
def end_epoch(self, best_result=False): def end_epoch(self, best_result=False):
""" """
commit the log_dict, model artifacts and Tables to W&B and flush the log_dict. commit the log_dict, model artifacts and Tables to W&B and flush the log_dict.
arguments: arguments:
best_result (boolean): Boolean representing if the result of this evaluation is best or not best_result (boolean): Boolean representing if the result of this evaluation is best or not
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论