Unverified 提交 d9212140 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Add xywhn2xyxy() (#1983)

上级 1ca2d26b
...@@ -20,7 +20,7 @@ from PIL import Image, ExifTags ...@@ -20,7 +20,7 @@ from PIL import Image, ExifTags
from torch.utils.data import Dataset from torch.utils.data import Dataset
from tqdm import tqdm from tqdm import tqdm
from utils.general import xyxy2xywh, xywh2xyxy, clean_str from utils.general import xyxy2xywh, xywh2xyxy, xywhn2xyxy, clean_str
from utils.torch_utils import torch_distributed_zero_first from utils.torch_utils import torch_distributed_zero_first
# Parameters # Parameters
...@@ -515,16 +515,9 @@ class LoadImagesAndLabels(Dataset): # for training/testing ...@@ -515,16 +515,9 @@ class LoadImagesAndLabels(Dataset): # for training/testing
img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment) img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
# Load labels labels = self.labels[index].copy()
labels = [] if labels.size: # normalized xywh to pixel xyxy format
x = self.labels[index] labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
if x.size > 0:
# Normalized xywh to pixel xyxy format
labels = x.copy()
labels[:, 1] = ratio[0] * w * (x[:, 1] - x[:, 3] / 2) + pad[0] # pad width
labels[:, 2] = ratio[1] * h * (x[:, 2] - x[:, 4] / 2) + pad[1] # pad height
labels[:, 3] = ratio[0] * w * (x[:, 1] + x[:, 3] / 2) + pad[0]
labels[:, 4] = ratio[1] * h * (x[:, 2] + x[:, 4] / 2) + pad[1]
if self.augment: if self.augment:
# Augment imagespace # Augment imagespace
...@@ -674,13 +667,9 @@ def load_mosaic(self, index): ...@@ -674,13 +667,9 @@ def load_mosaic(self, index):
padh = y1a - y1b padh = y1a - y1b
# Labels # Labels
x = self.labels[index] labels = self.labels[index].copy()
labels = x.copy() if labels.size:
if x.size > 0: # Normalized xywh to pixel xyxy format labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh) # normalized xywh to pixel xyxy format
labels[:, 1] = w * (x[:, 1] - x[:, 3] / 2) + padw
labels[:, 2] = h * (x[:, 2] - x[:, 4] / 2) + padh
labels[:, 3] = w * (x[:, 1] + x[:, 3] / 2) + padw
labels[:, 4] = h * (x[:, 2] + x[:, 4] / 2) + padh
labels4.append(labels) labels4.append(labels)
# Concat/clip labels # Concat/clip labels
...@@ -737,13 +726,9 @@ def load_mosaic9(self, index): ...@@ -737,13 +726,9 @@ def load_mosaic9(self, index):
x1, y1, x2, y2 = [max(x, 0) for x in c] # allocate coords x1, y1, x2, y2 = [max(x, 0) for x in c] # allocate coords
# Labels # Labels
x = self.labels[index] labels = self.labels[index].copy()
labels = x.copy() if labels.size:
if x.size > 0: # Normalized xywh to pixel xyxy format labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady) # normalized xywh to pixel xyxy format
labels[:, 1] = w * (x[:, 1] - x[:, 3] / 2) + padx
labels[:, 2] = h * (x[:, 2] - x[:, 4] / 2) + pady
labels[:, 3] = w * (x[:, 1] + x[:, 3] / 2) + padx
labels[:, 4] = h * (x[:, 2] + x[:, 4] / 2) + pady
labels9.append(labels) labels9.append(labels)
# Image # Image
......
...@@ -223,6 +223,16 @@ def xywh2xyxy(x): ...@@ -223,6 +223,16 @@ def xywh2xyxy(x):
return y return y
def xywhn2xyxy(x, w=640, h=640, padw=32, padh=32):
# Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
return y
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None): def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
# Rescale coords (xyxy) from img1_shape to img0_shape # Rescale coords (xyxy) from img1_shape to img0_shape
if ratio_pad is None: # calculate from img0_shape if ratio_pad is None: # calculate from img0_shape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论