提交 7eaf225d authored 作者: Glenn Jocher's avatar Glenn Jocher

zero-target training bug fix (#609)

上级 d0d3dd10
...@@ -496,8 +496,7 @@ def compute_loss(p, targets, model): # predictions, targets, model ...@@ -496,8 +496,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
s = 3 / np # output count scaling s = 3 / np # output count scaling
lbox *= h['giou'] * s lbox *= h['giou'] * s
lobj *= h['obj'] * s * (1.4 if np == 4 else 1.) lobj *= h['obj'] * s * (1.4 if np == 4 else 1.)
if model.nc > 1: lcls *= h['cls'] * s
lcls *= h['cls'] * s
bs = tobj.shape[0] # batch size bs = tobj.shape[0] # batch size
loss = lbox + lobj + lcls loss = lbox + lobj + lcls
...@@ -524,7 +523,7 @@ def build_targets(p, targets, model): ...@@ -524,7 +523,7 @@ def build_targets(p, targets, model):
gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
# Match targets to anchors # Match targets to anchors
t, offsets = targets * gain, 0 t = targets * gain
if nt: if nt:
# Matches # Matches
r = t[:, :, 4:6] / anchors[:, None] # wh ratio r = t[:, :, 4:6] / anchors[:, None] # wh ratio
...@@ -540,6 +539,9 @@ def build_targets(p, targets, model): ...@@ -540,6 +539,9 @@ def build_targets(p, targets, model):
j = torch.stack((torch.ones_like(j), j, k, l, m)) j = torch.stack((torch.ones_like(j), j, k, l, m))
t = t.repeat((5, 1, 1))[j] t = t.repeat((5, 1, 1))[j]
offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j] offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
else:
t = targets[0]
offsets = 0
# Define # Define
b, c = t[:, :2].long().T # image, class b, c = t[:, :2].long().T # image, class
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论