未验证的提交 aad99b63 作者:Glenn Jocher 提交者:GitHub

TensorBoard DP/DDP graph fix (#3325)

上级 407dc500
...@@ -32,7 +32,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima ...@@ -32,7 +32,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima
from utils.google_utils import attempt_download from utils.google_utils import attempt_download
from utils.loss import ComputeLoss from utils.loss import ComputeLoss
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -331,7 +331,7 @@ def train(hyp, opt, device, tb_writer=None): ...@@ -331,7 +331,7 @@ def train(hyp, opt, device, tb_writer=None):
f = save_dir / f'train_batch{ni}.jpg' # filename f = save_dir / f'train_batch{ni}.jpg' # filename
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start() Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
if tb_writer: if tb_writer:
tb_writer.add_graph(torch.jit.trace(model, imgs, strict=False), []) # add model graph tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), []) # model graph
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch) # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
elif plots and ni == 10 and wandb_logger.wandb: elif plots and ni == 10 and wandb_logger.wandb:
wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
...@@ -390,7 +390,7 @@ def train(hyp, opt, device, tb_writer=None): ...@@ -390,7 +390,7 @@ def train(hyp, opt, device, tb_writer=None):
ckpt = {'epoch': epoch, ckpt = {'epoch': epoch,
'best_fitness': best_fitness, 'best_fitness': best_fitness,
'training_results': results_file.read_text(), 'training_results': results_file.read_text(),
'model': deepcopy(model.module if is_parallel(model) else model).half(), 'model': deepcopy(de_parallel(model)).half(),
'ema': deepcopy(ema.ema).half(), 'ema': deepcopy(ema.ema).half(),
'updates': ema.updates, 'updates': ema.updates,
'optimizer': optimizer.state_dict(), 'optimizer': optimizer.state_dict(),
......
...@@ -134,9 +134,15 @@ def profile(x, ops, n=100, device=None): ...@@ -134,9 +134,15 @@ def profile(x, ops, n=100, device=None):
def is_parallel(model):
    # Returns True if model is of type DP or DDP.
    # The check is an exact type match (not isinstance), so only objects whose
    # type is precisely DataParallel or DistributedDataParallel qualify.
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
def de_parallel(model):
    # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
    # (i.e. the wrapped `model.module`); any other model is returned unchanged.
    return model.module if is_parallel(model) else model
def intersect_dicts(da, db, exclude=()):
    # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values.
    # A key of `da` is kept when it also exists in `db`, contains none of the
    # substrings listed in `exclude`, and its value's `.shape` equals that of `db[key]`.
    return {
        k: v
        for k, v in da.items()
        if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape
    }
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论