Unverified 提交 8f354362 authored 作者: Yono Mittlefehldt's avatar Yono Mittlefehldt 提交者: GitHub

Fix Detections class `tolist()` method (#5945)

* Fix tolist() to add the file for each Detection * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix PEP8 requirement for 2 spaces before an inline comment * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Cleanup Co-authored-by: 's avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 8f875d93
...@@ -525,7 +525,7 @@ class AutoShape(nn.Module): ...@@ -525,7 +525,7 @@ class AutoShape(nn.Module):
class Detections: class Detections:
# YOLOv5 detections class for inference results # YOLOv5 detections class for inference results
def __init__(self, imgs, pred, files, times=None, names=None, shape=None): def __init__(self, imgs, pred, files, times=(0, 0, 0, 0), names=None, shape=None):
super().__init__() super().__init__()
d = pred[0].device # device d = pred[0].device # device
gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in imgs] # normalizations gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in imgs] # normalizations
...@@ -533,6 +533,7 @@ class Detections: ...@@ -533,6 +533,7 @@ class Detections:
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls) self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
self.names = names # class names self.names = names # class names
self.files = files # image filenames self.files = files # image filenames
self.times = times # profiling times
self.xyxy = pred # xyxy pixels self.xyxy = pred # xyxy pixels
self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
...@@ -612,10 +613,11 @@ class Detections: ...@@ -612,10 +613,11 @@ class Detections:
def tolist(self): def tolist(self):
# return a list of Detections objects, i.e. 'for result in results.tolist():' # return a list of Detections objects, i.e. 'for result in results.tolist():'
x = [Detections([self.imgs[i]], [self.pred[i]], names=self.names, shape=self.s) for i in range(self.n)] r = range(self.n) # iterable
for d in x: x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']: # for d in x:
setattr(d, k, getattr(d, k)[0]) # pop out of list # for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
# setattr(d, k, getattr(d, k)[0]) # pop out of list
return x return x
def __len__(self): def __len__(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论