提交 453acdec authored 作者: Glenn Jocher's avatar Glenn Jocher

Update tensorboard logging

上级 7f164069
...@@ -191,9 +191,9 @@ def test(data, ...@@ -191,9 +191,9 @@ def test(data,
# Plot images # Plot images
if plots and batch_i < 1: if plots and batch_i < 1:
f = save_dir / ('test_batch%g_gt.jpg' % batch_i) # filename f = save_dir / f'test_batch{batch_i}_gt.jpg' # filename
plot_images(img, targets, paths, str(f), names) # ground truth plot_images(img, targets, paths, str(f), names) # ground truth
f = save_dir / ('test_batch%g_pred.jpg' % batch_i) f = save_dir / f'test_batch{batch_i}_pred.jpg'
plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions
# Compute statistics # Compute statistics
......
...@@ -291,10 +291,10 @@ def train(hyp, opt, device, tb_writer=None): ...@@ -291,10 +291,10 @@ def train(hyp, opt, device, tb_writer=None):
# Plot # Plot
if ni < 3: if ni < 3:
f = str(log_dir / ('train_batch%g.jpg' % ni)) # filename f = str(log_dir / f'train_batch{ni}.jpg') # filename
result = plot_images(images=imgs, targets=targets, paths=paths, fname=f) result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
if tb_writer and result is not None: # if tb_writer and result is not None:
tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch) # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
# tb_writer.add_graph(model, imgs) # add model to tensorboard # tb_writer.add_graph(model, imgs) # add model to tensorboard
# end batch ------------------------------------------------------------------------------------------------ # end batch ------------------------------------------------------------------------------------------------
......
...@@ -19,6 +19,7 @@ import numpy as np ...@@ -19,6 +19,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import yaml import yaml
from PIL import Image
from scipy.cluster.vq import kmeans from scipy.cluster.vq import kmeans
from scipy.signal import butter, filtfilt from scipy.signal import butter, filtfilt
from tqdm import tqdm from tqdm import tqdm
...@@ -1096,8 +1097,8 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max ...@@ -1096,8 +1097,8 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
if fname is not None: if fname is not None:
mosaic = cv2.resize(mosaic, (int(ns * w * 0.5), int(ns * h * 0.5)), interpolation=cv2.INTER_AREA) mosaic = cv2.resize(mosaic, (int(ns * w * 0.5), int(ns * h * 0.5)), interpolation=cv2.INTER_AREA)
cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save
Image.fromarray(mosaic).save(fname) # PIL save
return mosaic return mosaic
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论