Unverified 提交 37495731 authored 作者: Yonghye Kwon's avatar Yonghye Kwon 提交者: GitHub

Add `xyxy2xywhn()` (#3765)

* Edit Comments for numpy2torch tensor process Edit Comments for numpy2torch tensor process * add xyxy2xywhn add xyxy2xywhn * add xyxy2xywhn * formatting * pass arguments pass arguments * edit comment for xyxy2xywhn() edit comment for xyxy2xywhn() * cleanup datasets.py Co-authored-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 03281f8c
...@@ -23,8 +23,8 @@ from PIL import Image, ExifTags ...@@ -23,8 +23,8 @@ from PIL import Image, ExifTags
from torch.utils.data import Dataset from torch.utils.data import Dataset
from tqdm import tqdm from tqdm import tqdm
from utils.general import check_requirements, check_file, check_dataset, xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, \ from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
segment2box, segments2boxes, resample_segments, clean_str xyn2xy, segment2box, segments2boxes, resample_segments, clean_str
from utils.torch_utils import torch_distributed_zero_first from utils.torch_utils import torch_distributed_zero_first
# Parameters # Parameters
...@@ -192,7 +192,7 @@ class LoadImages: # for inference ...@@ -192,7 +192,7 @@ class LoadImages: # for inference
img = letterbox(img0, self.img_size, stride=self.stride)[0] img = letterbox(img0, self.img_size, stride=self.stride)[0]
# Convert # Convert
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416 img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB and HWC to CHW
img = np.ascontiguousarray(img) img = np.ascontiguousarray(img)
return path, img, img0, self.cap return path, img, img0, self.cap
...@@ -255,7 +255,7 @@ class LoadWebcam: # for inference ...@@ -255,7 +255,7 @@ class LoadWebcam: # for inference
img = letterbox(img0, self.img_size, stride=self.stride)[0] img = letterbox(img0, self.img_size, stride=self.stride)[0]
# Convert # Convert
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416 img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB and HWC to CHW
img = np.ascontiguousarray(img) img = np.ascontiguousarray(img)
return img_path, img, img0, None return img_path, img, img0, None
...@@ -336,7 +336,7 @@ class LoadStreams: # multiple IP or RTSP cameras ...@@ -336,7 +336,7 @@ class LoadStreams: # multiple IP or RTSP cameras
img = np.stack(img, 0) img = np.stack(img, 0)
# Convert # Convert
img = img[:, :, :, ::-1].transpose(0, 3, 1, 2) # BGR to RGB, to bsx3x416x416 img = img[:, :, :, ::-1].transpose(0, 3, 1, 2) # BGR to RGB and BHWC to BCHW
img = np.ascontiguousarray(img) img = np.ascontiguousarray(img)
return self.sources, img, img0, None return self.sources, img, img0, None
...@@ -552,9 +552,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing ...@@ -552,9 +552,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
nL = len(labels) # number of labels nL = len(labels) # number of labels
if nL: if nL:
labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) # convert xyxy to xywh labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0]) # xyxy to xywh normalized
labels[:, [2, 4]] /= img.shape[0] # normalized height 0-1
labels[:, [1, 3]] /= img.shape[1] # normalized width 0-1
if self.augment: if self.augment:
# flip up-down # flip up-down
......
...@@ -393,6 +393,16 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0): ...@@ -393,6 +393,16 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
return y return y
def xyxy2xywhn(x, w=640, h=640):
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
return y
def xyn2xy(x, w=640, h=640, padw=0, padh=0): def xyn2xy(x, w=640, h=640, padw=0, padh=0):
# Convert normalized segments into pixel segments, shape (n,2) # Convert normalized segments into pixel segments, shape (n,2)
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论