Unverified 提交 fa8f1fb0 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Simplify autoshape() post-process (#1653)

* Simplify autoshape() post-process * cleanup * cleanup
上级 84f9bb5d
...@@ -108,7 +108,7 @@ def yolov5x(pretrained=False, channels=3, classes=80): ...@@ -108,7 +108,7 @@ def yolov5x(pretrained=False, channels=3, classes=80):
if __name__ == '__main__': if __name__ == '__main__':
model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # example model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # example
model = model.fuse().autoshape() # for PIL/cv2/np inputs and NMS model = model.autoshape() # for PIL/cv2/np inputs and NMS
# Verify inference # Verify inference
from PIL import Image from PIL import Image
......
...@@ -167,8 +167,7 @@ class autoShape(nn.Module): ...@@ -167,8 +167,7 @@ class autoShape(nn.Module):
# Post-process # Post-process
for i in batch: for i in batch:
if y[i] is not None: scale_coords(shape1, y[i][:, :4], shape0[i])
y[i][:, :4] = scale_coords(shape1, y[i][:, :4], shape0[i])
return Detections(imgs, y, self.names) return Detections(imgs, y, self.names)
...@@ -177,13 +176,13 @@ class Detections: ...@@ -177,13 +176,13 @@ class Detections:
# detections class for YOLOv5 inference results # detections class for YOLOv5 inference results
def __init__(self, imgs, pred, names=None): def __init__(self, imgs, pred, names=None):
super(Detections, self).__init__() super(Detections, self).__init__()
d = pred[0].device # device
gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs] # normalizations
self.imgs = imgs # list of images as numpy arrays self.imgs = imgs # list of images as numpy arrays
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls) self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
self.names = names # class names self.names = names # class names
self.xyxy = pred # xyxy pixels self.xyxy = pred # xyxy pixels
self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
d = pred[0].device # device
gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs] # normalizations
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
self.n = len(self.pred) self.n = len(self.pred)
......
...@@ -9,8 +9,8 @@ Pillow ...@@ -9,8 +9,8 @@ Pillow
PyYAML>=5.3 PyYAML>=5.3
scipy>=1.4.1 scipy>=1.4.1
tensorboard>=2.2 tensorboard>=2.2
torch>=1.6.0 torch>=1.7.0
torchvision>=0.7.0 torchvision>=0.8.1
tqdm>=4.41.0 tqdm>=4.41.0
# logging ------------------------------------- # logging -------------------------------------
...@@ -26,5 +26,5 @@ pandas ...@@ -26,5 +26,5 @@ pandas
# scikit-learn==0.19.2 # for coreml quantization # scikit-learn==0.19.2 # for coreml quantization
# extras -------------------------------------- # extras --------------------------------------
# thop # FLOPS computation thop # FLOPS computation
# pycocotools>=2.0 # COCO mAP pycocotools>=2.0 # COCO mAP
...@@ -258,7 +258,7 @@ def wh_iou(wh1, wh2): ...@@ -258,7 +258,7 @@ def wh_iou(wh1, wh2):
return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter) return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, classes=None, agnostic=False, labels=()): def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
"""Performs Non-Maximum Suppression (NMS) on inference results """Performs Non-Maximum Suppression (NMS) on inference results
Returns: Returns:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论