Unverified 提交 f0e5a608 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Add CSV logging to GenericLogger (#9128)

Enable CSV logging for Classify training. Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 f8816f58
...@@ -242,9 +242,10 @@ class GenericLogger: ...@@ -242,9 +242,10 @@ class GenericLogger:
def __init__(self, opt, console_logger, include=('tb', 'wandb')): def __init__(self, opt, console_logger, include=('tb', 'wandb')):
# init default loggers # init default loggers
self.save_dir = opt.save_dir self.save_dir = Path(opt.save_dir)
self.include = include self.include = include
self.console_logger = console_logger self.console_logger = console_logger
self.csv = self.save_dir / 'results.csv' # CSV logger
if 'tb' in self.include: if 'tb' in self.include:
prefix = colorstr('TensorBoard: ') prefix = colorstr('TensorBoard: ')
self.console_logger.info( self.console_logger.info(
...@@ -258,14 +259,21 @@ class GenericLogger: ...@@ -258,14 +259,21 @@ class GenericLogger:
else: else:
self.wandb = None self.wandb = None
def log_metrics(self, metrics_dict, epoch): def log_metrics(self, metrics, epoch):
# Log metrics dictionary to all loggers # Log metrics dictionary to all loggers
if self.csv:
keys, vals = list(metrics.keys()), list(metrics.values())
n = len(metrics) + 1 # number of cols
s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header
with open(self.csv, 'a') as f:
f.write(s + ('%23.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')
if self.tb: if self.tb:
for k, v in metrics_dict.items(): for k, v in metrics.items():
self.tb.add_scalar(k, v, epoch) self.tb.add_scalar(k, v, epoch)
if self.wandb: if self.wandb:
self.wandb.log(metrics_dict, step=epoch) self.wandb.log(metrics, step=epoch)
def log_images(self, files, name='Images', epoch=0): def log_images(self, files, name='Images', epoch=0):
# Log images to all loggers # Log images to all loggers
...@@ -291,6 +299,11 @@ class GenericLogger: ...@@ -291,6 +299,11 @@ class GenericLogger:
art.add_file(str(model_path)) art.add_file(str(model_path))
wandb.log_artifact(art) wandb.log_artifact(art)
def update_params(self, params):
# Update the paramters logged
if self.wandb:
wandb.run.config.update(params, allow_val_change=True)
def log_tensorboard_graph(tb, model, imgsz=(640, 640)): def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
# Log model graph to TensorBoard # Log model graph to TensorBoard
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论