Unverified commit efe60b56, authored by Glenn Jocher, committed by GitHub

Refactor train.py and val.py `loggers` (#4137)

* Update loggers
* Config
* Update val.py
* cleanup
* fix1
* fix2
* fix3 and reformat
* format sweep.py
* Logger() class
* cleanup
* cleanup2
* wandb package import fix
* wandb package import fix2
* txt fix
* fix4
* fix5
* fix6
* drop wandb into utils/loggers
* fix 7
* rename loggers/wandb_logging to loggers/wandb
* Update message
* Update message
* Update message
* cleanup
* Fix x axis bug
* fix rank 0 issue
* cleanup
Parent 63dd65e7
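At a glance, this PR collapses the per-backend bookkeeping that train.py and val.py used to do (a `loggers` dict plus scattered `if loggers['wandb']:` / `if loggers['tb']:` checks) into a single `Loggers` facade whose callback methods are invoked at fixed points in the training loop. A minimal, self-contained sketch of that dispatch pattern (illustrative only; the stand-in classes below are not from the diff):

class Loggers:
    def __init__(self, tb=None, wandb=None):
        self.tb = tb        # TensorBoard-like backend, or None
        self.wandb = wandb  # W&B-like backend, or None

    def on_train_val_end(self, tags, vals, epoch):
        # One call site; each configured backend is consulted internally.
        if self.tb:
            for tag, val in zip(tags, vals):
                self.tb.add_scalar(tag, val, epoch)
        if self.wandb:
            self.wandb.log(dict(zip(tags, vals)))

class PrintBackend:
    # Stand-in for SummaryWriter; prints instead of writing event files.
    def add_scalar(self, tag, val, epoch):
        print(f'epoch {epoch}: {tag}={val}')

loggers = Loggers(tb=PrintBackend())
loggers.on_train_val_end(['val/box_loss'], [0.05], epoch=0)    # logs to the one backend
Loggers().on_train_val_end(['val/box_loss'], [0.05], epoch=0)  # all-None: silently a no-op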
train.py

@@ -10,7 +10,6 @@ import os
 import random
 import sys
 import time
-import warnings
 from copy import deepcopy
 from pathlib import Path
 from threading import Thread

@@ -24,7 +23,6 @@ import yaml
 from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Adam, SGD, lr_scheduler
-from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm

 FILE = Path(__file__).absolute()

@@ -42,8 +40,9 @@ from utils.google_utils import attempt_download
 from utils.loss import ComputeLoss
 from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
 from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
-from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
+from utils.loggers.wandb.wandb_utils import check_wandb_resume
 from utils.metrics import fitness
+from utils.loggers import Loggers

 LOGGER = logging.getLogger(__name__)
 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
@@ -76,37 +75,23 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     with open(save_dir / 'opt.yaml', 'w') as f:
         yaml.safe_dump(vars(opt), f, sort_keys=False)

-    # Configure
+    # Config
     plots = not evolve  # create plots
     cuda = device.type != 'cpu'
     init_seeds(1 + RANK)
     with open(data) as f:
         data_dict = yaml.safe_load(f)  # data dict
-
-    # Loggers
-    loggers = {'wandb': None, 'tb': None}  # loggers dict
-    if RANK in [-1, 0]:
-        # TensorBoard
-        if plots:
-            prefix = colorstr('tensorboard: ')
-            LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
-            loggers['tb'] = SummaryWriter(str(save_dir))
-
-        # W&B
-        opt.hyp = hyp  # add hyperparameters
-        run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
-        run_id = run_id if opt.resume else None  # start fresh run if transfer learning
-        wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
-        loggers['wandb'] = wandb_logger.wandb
-        if loggers['wandb']:
-            data_dict = wandb_logger.data_dict
-            weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp  # may update values if resuming
-
     nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
     names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
     assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}'  # check
     is_coco = data.endswith('coco.yaml') and nc == 80  # COCO dataset

+    # Loggers
+    if RANK in [-1, 0]:
+        loggers = Loggers(save_dir, results_file, weights, opt, hyp, data_dict, LOGGER).start()  # loggers dict
+        if loggers.wandb and resume:
+            weights, epochs, hyp, data_dict = opt.weights, opt.epochs, opt.hyp, loggers.wandb.data_dict

     # Model
     pretrained = weights.endswith('.pt')
     if pretrained:
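The `RANK in [-1, 0]` guard means only the main process constructs loggers: -1 for single-GPU/CPU runs and 0 for the first DDP worker, following the RANK environment variable that torch.distributed launchers set. A standalone sketch of the gating (the `make_loggers` helper is hypothetical):

import os

RANK = int(os.getenv('RANK', -1))  # -1 when not launched by torch.distributed

def make_loggers():
    # Hypothetical stand-in for Loggers(...).start(); real writers are costly,
    # so replicas other than the main process never create them.
    print('creating TensorBoard writer / W&B run')
    return 'loggers'

loggers = make_loggers() if RANK in [-1, 0] else None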
@@ -351,16 +336,11 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                 pbar.set_description(s)

                 # Plot
-                if plots and ni < 3:
-                    f = save_dir / f'train_batch{ni}.jpg'  # filename
-                    Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
-                    if loggers['tb'] and ni == 0:  # TensorBoard
-                        with warnings.catch_warnings():
-                            warnings.simplefilter('ignore')  # suppress jit trace warning
-                            loggers['tb'].add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
-                elif plots and ni == 10 and loggers['wandb']:
-                    wandb_logger.log({'Mosaics': [loggers['wandb'].Image(str(x), caption=x.name) for x in
-                                                  save_dir.glob('train*.jpg') if x.exists()]})
+                if plots:
+                    if ni < 3:
+                        f = save_dir / f'train_batch{ni}.jpg'  # filename
+                        Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
+                    loggers.on_train_batch_end(ni, model, imgs)

             # end batch ------------------------------------------------------------------------------------------------
@@ -368,13 +348,12 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
         scheduler.step()

-        # DDP process 0 or single-GPU
         if RANK in [-1, 0]:
             # mAP
+            loggers.on_train_epoch_end(epoch)
             ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
             if not noval or final_epoch:  # Calculate mAP
-                wandb_logger.current_epoch = epoch + 1
                 results, maps, _ = val.run(data_dict,
                                            batch_size=batch_size // WORLD_SIZE * 2,
                                            imgsz=imgsz,
@@ -385,29 +364,14 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                                            save_json=is_coco and final_epoch,
                                            verbose=nc < 50 and final_epoch,
                                            plots=plots and final_epoch,
-                                           wandb_logger=wandb_logger,
+                                           loggers=loggers,
                                            compute_loss=compute_loss)

-            # Write
-            with open(results_file, 'a') as f:
-                f.write(s + '%10.4g' * 7 % results + '\n')  # append metrics, val_loss
-
-            # Log
-            tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
-                    'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
-                    'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
-                    'x/lr0', 'x/lr1', 'x/lr2']  # params
-            for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
-                if loggers['tb']:
-                    loggers['tb'].add_scalar(tag, x, epoch)  # TensorBoard
-                if loggers['wandb']:
-                    wandb_logger.log({tag: x})  # W&B
-
             # Update best mAP
             fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
             if fi > best_fitness:
                 best_fitness = fi
-            wandb_logger.end_epoch(best_result=best_fitness == fi)
+            loggers.on_train_val_end(mloss, results, lr, epoch, s, best_fitness, fi)

             # Save model
             if (not nosave) or (final_epoch and not evolve):  # if save
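`on_train_val_end` receives the raw pieces (`mloss`, `results`, `lr`, the progress string `s`) and pairs them with the fixed 13-tag schema removed above: 3 train losses, 4 metrics, 3 val losses, 3 learning rates, with the last element of `mloss` dropped. A toy check of the pairing, with made-up values (the assumption here is that `mloss` carries a trailing total-loss element, which `[:-1]` discards):

tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
        'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
        'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
        'x/lr0', 'x/lr1', 'x/lr2']  # params
mloss = [0.05, 0.03, 0.01, 0.09]  # box, obj, cls, total (total dropped below)
results = [0.60, 0.50, 0.55, 0.35, 0.04, 0.02, 0.01]  # P, R, mAP@.5, mAP@.5:.95, val losses
lr = [0.01, 0.01, 0.1]  # one value per optimizer param group
vals = list(mloss[:-1]) + list(results) + lr
assert len(vals) == len(tags) == 13
print(dict(zip(tags, vals)))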
@@ -418,16 +382,14 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                         'ema': deepcopy(ema.ema).half(),
                         'updates': ema.updates,
                         'optimizer': optimizer.state_dict(),
-                        'wandb_id': wandb_logger.wandb_run.id if loggers['wandb'] else None}
+                        'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None}

                 # Save last, best and delete
                 torch.save(ckpt, last)
                 if best_fitness == fi:
                     torch.save(ckpt, best)
-                if loggers['wandb']:
-                    if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
-                        wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
                 del ckpt
+                loggers.on_model_save(last, epoch, final_epoch, best_fitness, fi)

         # end epoch ----------------------------------------------------------------------------------------------------
     # end training -----------------------------------------------------------------------------------------------------
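`on_model_save` (shown in the new Loggers class below) keeps the old cadence: upload a checkpoint to W&B every `save_period` epochs, skip the final epoch (handled by `on_train_end` instead), and do nothing when `save_period` is -1. A quick trace of the condition for assumed values save_period=5, epochs=20:

save_period, epochs = 5, 20  # assumed values for illustration
for epoch in range(epochs):
    final_epoch = epoch + 1 == epochs
    if ((epoch + 1) % save_period == 0 and not final_epoch) and save_period != -1:
        print(f'epoch {epoch}: upload checkpoint')  # fires at epochs 4, 9, 14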
@@ -435,10 +397,6 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     LOGGER.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
     if plots:
         plot_results(save_dir=save_dir)  # save as results.png
-        if loggers['wandb']:
-            files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
-            wandb_logger.log({"Results": [loggers['wandb'].Image(str(save_dir / f), caption=f) for f in files
-                                          if (save_dir / f).exists()]})
     if not evolve:
         if is_coco:  # COCO dataset

@@ -458,11 +416,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         for f in last, best:
             if f.exists():
                 strip_optimizer(f)  # strip optimizers
-        if loggers['wandb']:  # Log the stripped model
-            loggers['wandb'].log_artifact(str(best if best.exists() else last), type='model',
-                                          name='run_' + wandb_logger.wandb_run.id + '_model',
-                                          aliases=['latest', 'best', 'stripped'])
-        wandb_logger.finish_run()
+        loggers.on_train_end(last, best)

     torch.cuda.empty_cache()
     return results
utils/loggers/__init__.py (new file)
# YOLOv5 experiment logging utils

import warnings

import torch
from torch.utils.tensorboard import SummaryWriter

from utils.general import colorstr, emojis
from utils.loggers.wandb.wandb_utils import WandbLogger
from utils.torch_utils import de_parallel

LOGGERS = ('txt', 'tb', 'wandb')  # text-file, TensorBoard, Weights & Biases

try:
    import wandb

    assert hasattr(wandb, '__version__')  # verify package import not local dir
except (ImportError, AssertionError):
    wandb = None
class Loggers():
    # YOLOv5 Loggers class
    def __init__(self, save_dir=None, results_file=None, weights=None, opt=None, hyp=None,
                 data_dict=None, logger=None, include=LOGGERS):
        self.save_dir = save_dir
        self.results_file = results_file
        self.weights = weights
        self.opt = opt
        self.hyp = hyp
        self.data_dict = data_dict
        self.logger = logger  # for printing results to console
        self.include = include
        for k in LOGGERS:
            setattr(self, k, None)  # init empty logger dictionary

    def start(self):
        self.txt = True  # always log to txt

        # Message
        try:
            import wandb
        except ImportError:
            prefix = colorstr('Weights & Biases: ')
            s = f"{prefix}run 'pip install wandb' to automatically track and visualize YOLOv5 🚀 runs (RECOMMENDED)"
            print(emojis(s))

        # TensorBoard
        s = self.save_dir
        if 'tb' in self.include and not self.opt.evolve:
            prefix = colorstr('TensorBoard: ')
            self.logger.info(f"{prefix}Start with 'tensorboard --logdir {s.parent}', view at http://localhost:6006/")
            self.tb = SummaryWriter(str(s))

        # W&B
        try:
            assert 'wandb' in self.include and wandb
            run_id = torch.load(self.weights).get('wandb_id') if self.opt.resume else None
            self.opt.hyp = self.hyp  # add hyperparameters
            self.wandb = WandbLogger(self.opt, s.stem, run_id, self.data_dict)
        except:
            self.wandb = None

        return self

    def on_train_batch_end(self, ni, model, imgs):
        # Callback runs on train batch end
        if ni == 0:
            with warnings.catch_warnings():
                warnings.simplefilter('ignore')  # suppress jit trace warning
                self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
        if self.wandb and ni == 10:
            files = sorted(self.save_dir.glob('train*.jpg'))
            self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})

    def on_train_epoch_end(self, epoch):
        # Callback runs on train epoch end
        if self.wandb:
            self.wandb.current_epoch = epoch + 1
    def on_val_batch_end(self, pred, predn, path, names, im):
        # Callback runs on val batch end
        if self.wandb:
            self.wandb.val_one_image(pred, predn, path, names, im)
    def on_val_end(self):
        # Callback runs on val end
        if self.wandb:
            files = sorted(self.save_dir.glob('val*.jpg'))
            self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})

    def on_train_val_end(self, mloss, results, lr, epoch, s, best_fitness, fi):
        # Callback runs on validation end during training
        vals = list(mloss[:-1]) + list(results) + lr
        tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
                'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
                'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                'x/lr0', 'x/lr1', 'x/lr2']  # params
        if self.txt:
            with open(self.results_file, 'a') as f:
                f.write(s + '%10.4g' * 7 % results + '\n')  # append metrics, val_loss
        if self.tb:
            for x, tag in zip(vals, tags):
                self.tb.add_scalar(tag, x, epoch)  # TensorBoard
        if self.wandb:
            self.wandb.log({k: v for k, v in zip(tags, vals)})
            self.wandb.end_epoch(best_result=best_fitness == fi)

    def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
        # Callback runs on model save event
        if self.wandb:
            if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
                self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)

    def on_train_end(self, last, best):
        # Callback runs on training end
        files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
        files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()]  # filter
        if self.wandb:
            wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
            wandb.log_artifact(str(best if best.exists() else last), type='model',
                               name='run_' + self.wandb.wandb_run.id + '_model',
                               aliases=['latest', 'best', 'stripped'])
            self.wandb.finish_run()

    def log_images(self, paths):
        # Log images
        if self.wandb:
            self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
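`on_train_batch_end` above traces the model once, on the very first batch, and writes the graph to TensorBoard while muting the TracerWarnings that `torch.jit.trace` tends to emit. The same call on a toy module (a sketch assuming torch and tensorboard are installed; the model and log dir are placeholders):

import warnings

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())  # toy stand-in for the YOLO model
imgs = torch.zeros(1, 3, 64, 64)  # one dummy image, like imgs[0:1] above

tb = SummaryWriter('runs/graph_demo')  # placeholder log dir
with warnings.catch_warnings():
    warnings.simplefilter('ignore')  # suppress jit trace warning
    tb.add_graph(torch.jit.trace(model, imgs, strict=False), [])
tb.close()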
utils/loggers/wandb/sweep.py (moved from utils/wandb_logging/sweep.py)

 import sys
 from pathlib import Path

 import wandb

 FILE = Path(__file__).absolute()
 sys.path.append(FILE.parents[2].as_posix())  # add utils/ to path

 from train import train, parse_opt
-import test
 from utils.general import increment_path
 from utils.torch_utils import select_device
utils/loggers/wandb/sweep.yaml
@@ -14,7 +14,7 @@
 # You can use grid, bayesian and hyperopt search strategy
 # For more info on configuring sweeps visit - https://docs.wandb.ai/guides/sweeps/configuration

-program: utils/wandb_logging/sweep.py
+program: utils/loggers/wandb/sweep.py
 method: random
 metric:
   name: metrics/mAP_0.5
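The only change here is the `program` path following the module move. For reference, a sweep built from this file can also be driven from Python; a hedged sketch (the project name is an assumption, and `wandb.agent` with no `function` argument runs the config's `program`):

import yaml

import wandb

with open('utils/loggers/wandb/sweep.yaml') as f:
    sweep_config = yaml.safe_load(f)  # picks up the program path changed above

sweep_id = wandb.sweep(sweep_config, project='yolov5-sweeps')  # illustrative project name
wandb.agent(sweep_id, count=1)  # launches utils/loggers/wandb/sweep.py once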
utils/loggers/wandb/wandb_utils.py (moved from utils/wandb_logging/wandb_utils.py)
"""Utilities and tools for tracking runs with Weights & Biases.""" """Utilities and tools for tracking runs with Weights & Biases."""
import logging import logging
import os import os
import sys import sys
...@@ -8,15 +9,18 @@ from pathlib import Path ...@@ -8,15 +9,18 @@ from pathlib import Path
import yaml import yaml
from tqdm import tqdm from tqdm import tqdm
sys.path.append(str(Path(__file__).parent.parent.parent)) # add utils/ to path FILE = Path(__file__).absolute()
sys.path.append(FILE.parents[3].as_posix()) # add yolov5/ to path
from utils.datasets import LoadImagesAndLabels from utils.datasets import LoadImagesAndLabels
from utils.datasets import img2label_paths from utils.datasets import img2label_paths
from utils.general import colorstr, check_dataset, check_file from utils.general import check_dataset, check_file
try: try:
import wandb import wandb
from wandb import init, finish
except ImportError: assert hasattr(wandb, '__version__') # verify package import not local dir
except (ImportError, AssertionError):
wandb = None wandb = None
RANK = int(os.getenv('RANK', -1)) RANK = int(os.getenv('RANK', -1))
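The replacement derives the repo root from the file's own location, which is why the comment changes from 'add utils/ to path' to 'add yolov5/ to path': from `utils/loggers/wandb/wandb_utils.py`, `parents[3]` is the repository root. A quick demonstration of the indexing:

from pathlib import Path

p = Path('/repo/yolov5/utils/loggers/wandb/wandb_utils.py')
print(p.parents[0])  # /repo/yolov5/utils/loggers/wandb
print(p.parents[1])  # /repo/yolov5/utils/loggers
print(p.parents[2])  # /repo/yolov5/utils
print(p.parents[3])  # /repo/yolov5  <- what sys.path needs for the utils.* imports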
@@ -134,13 +138,11 @@ class WandbLogger():
             if not opt.resume:
                 wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
                 # Info useful for resuming from artifacts
-                self.wandb_run.config.update({'opt': vars(opt), 'data_dict': wandb_data_dict}, allow_val_change=True)
+                self.wandb_run.config.update({'opt': vars(opt), 'data_dict': wandb_data_dict},
+                                             allow_val_change=True)
             self.data_dict = self.setup_training(opt, data_dict)
             if self.job_type == 'Dataset Creation':
                 self.data_dict = self.check_and_upload_dataset(opt)
-        else:
-            prefix = colorstr('wandb: ')
-            print(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")

     def check_and_upload_dataset(self, opt):
         assert wandb, 'Install wandb to upload dataset'
@@ -177,7 +179,6 @@ class WandbLogger():
             val_path = Path(self.val_artifact_path) / 'data/images/'
             data_dict['val'] = str(val_path)

         if self.val_artifact is not None:
             self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
             self.result_table = wandb.Table(["epoch", "id", "ground truth", "prediction", "avg_confidence"])

@@ -328,7 +329,6 @@ class WandbLogger():
             boxes = {"predictions": {"box_data": box_data, "class_labels": names}}  # inference-space
             self.bbox_media_panel_images.append(wandb.Image(im, boxes=boxes, caption=path.name))

     def log(self, log_dict):
         if self.wandb_run:
             for key, value in log_dict.items():
utils/plots.py
@@ -327,9 +327,8 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
     plt.close()

     # loggers
-    for k, v in loggers.items() or {}:
-        if k == 'wandb' and v:
-            v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]}, commit=False)
+    if loggers:
+        loggers.log_images(save_dir.glob('*labels*.jpg'))


 def plot_evolution(yaml_file='data/hyp.finetune.yaml'):  # from utils.plots import *; plot_evolution()
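plot_labels no longer iterates a dict of backends; it calls `log_images` on whatever `loggers` object it is given, so any object with that one method satisfies the new contract. A hypothetical stub to illustrate:

from pathlib import Path

class StubLoggers:
    # Hypothetical stand-in: one log_images(paths) method is all plot_labels needs.
    def log_images(self, paths):
        for p in paths:
            print(f'would log {p}')

loggers = StubLoggers()
if loggers:  # mirrors the guard in plot_labels; loggers=None skips logging entirely
    loggers.log_images(Path('.').glob('*labels*.jpg'))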
val.py
@@ -26,6 +26,7 @@ from utils.general import coco80_to_coco91_class, check_dataset, check_file, che
 from utils.metrics import ap_per_class, ConfusionMatrix
 from utils.plots import plot_images, output_to_target, plot_study_txt
 from utils.torch_utils import select_device, time_sync
+from utils.loggers import Loggers


 def save_one_txt(predn, save_conf, shape, file):

@@ -97,7 +98,7 @@ def run(data,
         dataloader=None,
         save_dir=Path(''),
         plots=True,
-        wandb_logger=None,
+        loggers=Loggers(),
         compute_loss=None,
         ):
@@ -215,8 +216,7 @@ def run(data,
                 save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
             if save_json:
                 save_one_json(predn, jdict, path, class_map)  # append to COCO-JSON dictionary
-            if wandb_logger and wandb_logger.wandb_run:
-                wandb_logger.val_one_image(pred, predn, path, names, img[si])
+            loggers.on_val_batch_end(pred, predn, path, names, img[si])

         # Plot images
         if plots and batch_i < 3:

@@ -253,9 +253,7 @@ def run(data,
     # Plots
     if plots:
         confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
-        if wandb_logger and wandb_logger.wandb:
-            val_batches = [wandb_logger.wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('val*.jpg'))]
-            wandb_logger.log({"Validation": val_batches})
+        loggers.on_val_end()

     # Save JSON
     if save_json and len(jdict):