Unverified commit 48e56d3c, authored by Glenn Jocher, committed by GitHub

Add optional `transforms` argument to LoadStreams() (#9105)

* Add optional `transforms` argument to LoadStreams(). Prepare for streaming classification support. Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
* Cleanup. Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
* fix
* batch size > 1 fix. Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
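A minimal usage sketch of the new argument follows. It is hedged: the `utils.dataloaders` import path and the torchvision pipeline are assumptions for illustration, not part of this commit. The loader converts each frame from BGR to RGB before invoking the callable, so any transform that accepts an RGB HWC numpy array and returns a stackable tensor fits.

```python
# Illustrative sketch only: module path and transform pipeline are assumptions.
import torchvision.transforms as T

from utils.dataloaders import LoadStreams  # assumed module path

transforms = T.Compose([
    T.ToTensor(),       # HWC uint8 RGB -> CHW float in [0, 1]
    T.Resize(256),      # shorter side to 256
    T.CenterCrop(224),  # classification input size
])

dataset = LoadStreams('rtsp://example.com/media.mp4', img_size=224, transforms=transforms)
for sources, im, im0, vid_cap, s in dataset:
    break  # im holds one stacked entry per stream, ready for a classifier
```

Because `__next__` stacks one transformed frame per open stream, the effective batch size equals the number of streams.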
Parent d0fa0042
@@ -251,7 +251,7 @@ class LoadImages:
             s = f'image {self.count}/{self.nf} {path}: '
 
         if self.transforms:
-            im = self.transforms(cv2.cvtColor(im0, cv2.COLOR_BGR2RGB))  # classify transforms
+            im = self.transforms(cv2.cvtColor(im0, cv2.COLOR_BGR2RGB))  # transforms
         else:
             im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
             im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
@@ -289,22 +289,20 @@ class LoadWebcam:  # for inference
             raise StopIteration
 
         # Read frame
-        ret_val, img0 = self.cap.read()
-        img0 = cv2.flip(img0, 1)  # flip left-right
+        ret_val, im0 = self.cap.read()
+        im0 = cv2.flip(im0, 1)  # flip left-right
 
         # Print
         assert ret_val, f'Camera Error {self.pipe}'
         img_path = 'webcam.jpg'
         s = f'webcam {self.count}: '
 
-        # Padded resize
-        img = letterbox(img0, self.img_size, stride=self.stride)[0]
-
-        # Convert
-        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
-        img = np.ascontiguousarray(img)
-
-        return img_path, img, img0, None, s
+        # Process
+        im = letterbox(im0, self.img_size, stride=self.stride)[0]  # resize
+        im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
+        im = np.ascontiguousarray(im)  # contiguous
+
+        return img_path, im, im0, None, s
 
     def __len__(self):
         return 0
@@ -312,7 +310,7 @@ class LoadWebcam:  # for inference
 
 class LoadStreams:
     # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP streams`
-    def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
+    def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True, transforms=None):
         self.mode = 'stream'
         self.img_size = img_size
         self.stride = stride
@@ -326,7 +324,6 @@ class LoadStreams:
         n = len(sources)
         self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
         self.sources = [clean_str(x) for x in sources]  # clean source names for later
-        self.auto = auto
         for i, s in enumerate(sources):  # index, source
             # Start thread to read frames from video stream
             st = f'{i + 1}/{n}: {s}... '
@@ -353,8 +350,10 @@ class LoadStreams:
         LOGGER.info('')  # newline
 
         # check for common shapes
-        s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs])
+        s = np.stack([letterbox(x, img_size, stride=stride, auto=auto)[0].shape for x in self.imgs])
         self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
+        self.auto = auto and self.rect
+        self.transforms = transforms  # optional
 
         if not self.rect:
             LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')
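A note on the reordered logic: `auto` (minimum-rectangle letterbox padding) now survives only when `self.rect` is true, i.e. when every stream letterboxes its first frame to the same shape; otherwise differently-shaped arrays could not be stacked into one batch. A small illustration of the check, with made-up shapes:

```python
# Illustration of the common-shape check (shape values are made up).
import numpy as np

shapes = np.stack([(384, 640, 3), (384, 640, 3), (480, 640, 3)])  # letterboxed shapes
rect = np.unique(shapes, axis=0).shape[0] == 1  # one unique row -> shapes match
print(rect)  # False -> auto padding disabled, fixed square letterbox used
```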
@@ -385,18 +384,15 @@ class LoadStreams:
             cv2.destroyAllWindows()
             raise StopIteration
 
-        # Letterbox
-        img0 = self.imgs.copy()
-        img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]
-
-        # Stack
-        img = np.stack(img, 0)
-
-        # Convert
-        img = img[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW
-        img = np.ascontiguousarray(img)
+        im0 = self.imgs.copy()
+        if self.transforms:
+            im = np.stack([self.transforms(cv2.cvtColor(x, cv2.COLOR_BGR2RGB)) for x in im0])  # transforms
+        else:
+            im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0])  # resize
+            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW
+            im = np.ascontiguousarray(im)  # contiguous
 
-        return self.sources, img, img0, None, ''
+        return self.sources, im, im0, None, ''
 
     def __len__(self):
         return len(self.sources)  # 1E12 frames = 32 streams at 30 FPS for 30 years
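The default (no-transforms) path now stacks first and converts once, instead of letterboxing into a list, stacking, and converting in three separate steps. A shape walk-through with illustrative sizes for two streams:

```python
# Shape walk-through of the default path (sizes are illustrative).
import numpy as np

frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(2)]  # BGR HWC frames
im = np.stack(frames)                       # (2, 480, 640, 3)  BHWC, BGR
im = im[..., ::-1].transpose((0, 3, 1, 2))  # (2, 3, 480, 640)  BCHW, RGB
im = np.ascontiguousarray(im)               # contiguous layout for torch.from_numpy
```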
@@ -836,7 +832,7 @@ class LoadImagesAndLabels(Dataset):
     @staticmethod
     def collate_fn4(batch):
-        img, label, path, shapes = zip(*batch)  # transposed
+        im, label, path, shapes = zip(*batch)  # transposed
         n = len(shapes) // 4
         im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
@@ -846,13 +842,13 @@ class LoadImagesAndLabels(Dataset):
         for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
             i *= 4
             if random.random() < 0.5:
-                im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
-                                   align_corners=False)[0].type(img[i].type())
+                im1 = F.interpolate(im[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
+                                    align_corners=False)[0].type(im[i].type())
                 lb = label[i]
             else:
-                im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
+                im1 = torch.cat((torch.cat((im[i], im[i + 1]), 1), torch.cat((im[i + 2], im[i + 3]), 1)), 2)
                 lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
-            im4.append(im)
+            im4.append(im1)
             label4.append(lb)
 
         for i, lb in enumerate(label4):
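The rename to `im1` keeps the inner mosaic image from shadowing the unpacked batch tuple, now called `im`. The else branch tiles four CHW images into a 2x2 mosaic at double the spatial size; a small sketch with illustrative sizes:

```python
# Sketch of the 2x2 tiling in collate_fn4 (sizes are illustrative).
import torch

a, b, c, d = (torch.zeros(3, 360, 640) for _ in range(4))
im1 = torch.cat((torch.cat((a, b), 1), torch.cat((c, d), 1)), 2)  # rows, then columns
print(im1.shape)  # torch.Size([3, 720, 1280])
```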