Unverified 提交 a2d617ec authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Update loss for FP16 `tobj` (#7088)

上级 6f128031
......@@ -125,7 +125,7 @@ class ComputeLoss:
# Losses
for i, pi in enumerate(p): # layer index, layer predictions
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
tobj = torch.zeros(pi.shape[:4], device=self.device) # target obj
tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device) # target obj
n = b.shape[0] # number of targets
if n:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论