Unverified 提交 87e8dead authored 作者: 0zppd's avatar 0zppd 提交者: GitHub

zero-mAP fix remove `torch.empty()` forward pass in `.train()` mode (#9068)

上级 e6b4bf0b
...@@ -296,7 +296,7 @@ def log_tensorboard_graph(tb, model, imgsz=(640, 640)): ...@@ -296,7 +296,7 @@ def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
try: try:
p = next(model.parameters()) # for device, type p = next(model.parameters()) # for device, type
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz # expand imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz # expand
im = torch.empty((1, 3, *imgsz)).to(p.device).type_as(p) # input image im = torch.zeros((1, 3, *imgsz)).to(p.device).type_as(p) # input image (WARNING: must be zeros, not empty)
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress jit trace warning warnings.simplefilter('ignore') # suppress jit trace warning
tb.add_graph(torch.jit.trace(de_parallel(model), im, strict=False), []) tb.add_graph(torch.jit.trace(de_parallel(model), im, strict=False), [])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论