Unverified 提交 7aa263c5 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Update DetectMultiBackend for tuple outputs 2 (#9275)

* Update DetectMultiBackend for tuple outputs 2 Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> * Update * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update * Update * Update Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: 's avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
上级 96c3c7f7
......@@ -457,7 +457,7 @@ class DetectMultiBackend(nn.Module):
self.__dict__.update(locals()) # assign all variables to self
def forward(self, im, augment=False, visualize=False, val=False):
def forward(self, im, augment=False, visualize=False):
# YOLOv5 MultiBackend inference
b, ch, h, w = im.shape # batch, channel, height, width
if self.fp16 and im.dtype != torch.float16:
......@@ -521,10 +521,12 @@ class DetectMultiBackend(nn.Module):
y[..., :4] *= [w, h, w, h] # xywh normalized to pixels
if isinstance(y, (list, tuple)):
y = y[0]
if isinstance(y, np.ndarray):
y = torch.from_numpy(y).to(self.device)
return (y, []) if val else y
return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
else:
return self.from_numpy(y)
def from_numpy(self, x):
return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x
def warmup(self, imgsz=(1, 3, 640, 640)):
# Warmup model by running inference once
......
......@@ -813,6 +813,9 @@ def non_max_suppression(prediction,
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
"""
if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out)
prediction = prediction[0] # select only inference output
bs = prediction.shape[0] # batch size
nc = prediction.shape[2] - 5 # number of classes
xc = prediction[..., 4] > conf_thres # candidates
......
......@@ -204,11 +204,11 @@ def run(
# Inference
with dt[1]:
out, train_out = model(im) if training else model(im, augment=augment, val=True) # inference, loss outputs
out, train_out = model(im) if compute_loss else (model(im, augment=augment), None)
# Loss
if compute_loss:
loss += compute_loss([x.float() for x in train_out], targets)[1] # box, obj, cls
loss += compute_loss(train_out, targets)[1] # box, obj, cls
# NMS
targets[:, 2:] *= torch.tensor((width, height, width, height), device=device) # to pixels
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论