未验证 提交 ffb6e110 作者: Ayush Chaurasia 提交者: GitHub

W&B: Update Tables API and comply with new dataset_check (#3772)

* Update Tables API and Windows path fix
* Update dataset check
上级 09246a5a
...@@ -136,7 +136,6 @@ class WandbLogger(): ...@@ -136,7 +136,6 @@ class WandbLogger():
def check_and_upload_dataset(self, opt): def check_and_upload_dataset(self, opt):
assert wandb, 'Install wandb to upload dataset' assert wandb, 'Install wandb to upload dataset'
check_dataset(self.data_dict)
config_path = self.log_dataset_artifact(check_file(opt.data), config_path = self.log_dataset_artifact(check_file(opt.data),
opt.single_cls, opt.single_cls,
'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem) 'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
...@@ -171,9 +170,11 @@ class WandbLogger(): ...@@ -171,9 +170,11 @@ class WandbLogger():
data_dict['val'] = str(val_path) data_dict['val'] = str(val_path)
self.val_table = self.val_artifact.get("val") self.val_table = self.val_artifact.get("val")
self.map_val_table_path() self.map_val_table_path()
wandb.log({"validation dataset": self.val_table})
if self.val_artifact is not None: if self.val_artifact is not None:
self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation") self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"]) self.result_table = wandb.Table(["epoch", "id", "ground truth", "prediction", "avg_confidence"])
if opt.bbox_interval == -1: if opt.bbox_interval == -1:
self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1 self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
return data_dict return data_dict
...@@ -181,7 +182,7 @@ class WandbLogger(): ...@@ -181,7 +182,7 @@ class WandbLogger():
def download_dataset_artifact(self, path, alias): def download_dataset_artifact(self, path, alias):
if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX): if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX):
artifact_path = Path(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias) artifact_path = Path(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
dataset_artifact = wandb.use_artifact(artifact_path.as_posix()) dataset_artifact = wandb.use_artifact(artifact_path.as_posix().replace("\\","/"))
assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'" assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'"
datadir = dataset_artifact.download() datadir = dataset_artifact.download()
return datadir, dataset_artifact return datadir, dataset_artifact
...@@ -216,6 +217,7 @@ class WandbLogger(): ...@@ -216,6 +217,7 @@ class WandbLogger():
def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False): def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
with open(data_file) as f: with open(data_file) as f:
data = yaml.safe_load(f) # data dict data = yaml.safe_load(f) # data dict
check_dataset(data)
nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names']) nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
names = {k: v for k, v in enumerate(names)} # to index dictionary names = {k: v for k, v in enumerate(names)} # to index dictionary
self.train_artifact = self.create_dataset_table(LoadImagesAndLabels( self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
...@@ -228,6 +230,7 @@ class WandbLogger(): ...@@ -228,6 +230,7 @@ class WandbLogger():
data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val') data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val')
path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1)) # updated data.yaml path path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1)) # updated data.yaml path
data.pop('download', None) data.pop('download', None)
data.pop('path', None)
with open(path, 'w') as f: with open(path, 'w') as f:
yaml.safe_dump(data, f) yaml.safe_dump(data, f)
...@@ -297,6 +300,7 @@ class WandbLogger(): ...@@ -297,6 +300,7 @@ class WandbLogger():
id = self.val_table_map[Path(path).name] id = self.val_table_map[Path(path).name]
self.result_table.add_data(self.current_epoch, self.result_table.add_data(self.current_epoch,
id, id,
self.val_table.data[id][1],
wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set), wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set),
total_conf / max(1, len(box_data)) total_conf / max(1, len(box_data))
) )
...@@ -312,11 +316,12 @@ class WandbLogger(): ...@@ -312,11 +316,12 @@ class WandbLogger():
wandb.log(self.log_dict) wandb.log(self.log_dict)
self.log_dict = {} self.log_dict = {}
if self.result_artifact: if self.result_artifact:
train_results = wandb.JoinedTable(self.val_table, self.result_table, "id") self.result_artifact.add(self.result_table, 'result')
self.result_artifact.add(train_results, 'result')
wandb.log_artifact(self.result_artifact, aliases=['latest', 'last', 'epoch ' + str(self.current_epoch), wandb.log_artifact(self.result_artifact, aliases=['latest', 'last', 'epoch ' + str(self.current_epoch),
('best' if best_result else '')]) ('best' if best_result else '')])
self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
wandb.log({"evaluation": self.result_table})
self.result_table = wandb.Table(["epoch", "id", "ground truth", "prediction", "avg_confidence"])
self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation") self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
def finish_run(self): def finish_run(self):
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论