Unverified 提交 4aa29591 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Suppress jit trace warning + graph once (#3454)

* Suppress jit trace warning + graph once Suppress harmless jit trace warning on TensorBoard add_graph call. Also fix multiple add_graph() calls bug, now only on batch 0. * Update train.py
上级 af2bc3a1
...@@ -4,6 +4,7 @@ import math ...@@ -4,6 +4,7 @@ import math
import os import os
import random import random
import time import time
import warnings
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
from threading import Thread from threading import Thread
...@@ -323,18 +324,19 @@ def train(hyp, opt, device, tb_writer=None): ...@@ -323,18 +324,19 @@ def train(hyp, opt, device, tb_writer=None):
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB) mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
s = ('%10s' * 2 + '%10.4g' * 6) % ( s = ('%10s' * 2 + '%10.4g' * 6) % (
'%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1]) f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])
pbar.set_description(s) pbar.set_description(s)
# Plot # Plot
if plots and ni < 3: if plots and ni < 3:
f = save_dir / f'train_batch{ni}.jpg' # filename f = save_dir / f'train_batch{ni}.jpg' # filename
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start() Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
if tb_writer: if tb_writer and ni == 0:
tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), []) # model graph with warnings.catch_warnings():
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch) warnings.simplefilter('ignore') # suppress jit trace warning
tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), []) # graph
elif plots and ni == 10 and wandb_logger.wandb: elif plots and ni == 10 and wandb_logger.wandb:
wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in wandb_logger.log({'Mosaics': [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
save_dir.glob('train*.jpg') if x.exists()]}) save_dir.glob('train*.jpg') if x.exists()]})
# end batch ------------------------------------------------------------------------------------------------ # end batch ------------------------------------------------------------------------------------------------
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论