提交 09402a21 authored 作者: Glenn Jocher's avatar Glenn Jocher

torch.from_tensor() bug fix

上级 4fb8cb35
...@@ -225,7 +225,7 @@ def train(hyp, opt, device, tb_writer=None): ...@@ -225,7 +225,7 @@ def train(hyp, opt, device, tb_writer=None):
if rank != -1: if rank != -1:
indices = torch.zeros([dataset.n], dtype=torch.int) indices = torch.zeros([dataset.n], dtype=torch.int)
if rank == 0: if rank == 0:
indices[:] = torch.from_tensor(dataset.indices, dtype=torch.int) indices[:] = torch.tensor(dataset.indices, dtype=torch.int)
dist.broadcast(indices, 0) dist.broadcast(indices, 0)
if rank != 0: if rank != 0:
dataset.indices = indices.cpu().numpy() dataset.indices = indices.cpu().numpy()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论