Unverified 提交 b57abb17 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Move trainloader functions to class methods (#6559)

* Move trainloader functions to class methods * results = ThreadPool(NUM_THREADS).imap(self.load_image, range(n)) * Cleanup
上级 dc7e0930
...@@ -484,7 +484,7 @@ class LoadImagesAndLabels(Dataset): ...@@ -484,7 +484,7 @@ class LoadImagesAndLabels(Dataset):
self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
# Cache images into memory for faster training (WARNING: large datasets may exceed system RAM) # Cache images into RAM/disk for faster training (WARNING: large datasets may exceed system resources)
self.imgs, self.img_npy = [None] * n, [None] * n self.imgs, self.img_npy = [None] * n, [None] * n
if cache_images: if cache_images:
if cache_images == 'disk': if cache_images == 'disk':
...@@ -493,14 +493,14 @@ class LoadImagesAndLabels(Dataset): ...@@ -493,14 +493,14 @@ class LoadImagesAndLabels(Dataset):
self.im_cache_dir.mkdir(parents=True, exist_ok=True) self.im_cache_dir.mkdir(parents=True, exist_ok=True)
gb = 0 # Gigabytes of cached images gb = 0 # Gigabytes of cached images
self.img_hw0, self.img_hw = [None] * n, [None] * n self.img_hw0, self.img_hw = [None] * n, [None] * n
results = ThreadPool(NUM_THREADS).imap(lambda x: load_image(*x), zip(repeat(self), range(n))) results = ThreadPool(NUM_THREADS).imap(self.load_image, range(n))
pbar = tqdm(enumerate(results), total=n) pbar = tqdm(enumerate(results), total=n)
for i, x in pbar: for i, x in pbar:
if cache_images == 'disk': if cache_images == 'disk':
if not self.img_npy[i].exists(): if not self.img_npy[i].exists():
np.save(self.img_npy[i].as_posix(), x[0]) np.save(self.img_npy[i].as_posix(), x[0])
gb += self.img_npy[i].stat().st_size gb += self.img_npy[i].stat().st_size
else: else: # 'ram'
self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i) self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
gb += self.imgs[i].nbytes gb += self.imgs[i].nbytes
pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})' pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})'
...@@ -558,16 +558,16 @@ class LoadImagesAndLabels(Dataset): ...@@ -558,16 +558,16 @@ class LoadImagesAndLabels(Dataset):
mosaic = self.mosaic and random.random() < hyp['mosaic'] mosaic = self.mosaic and random.random() < hyp['mosaic']
if mosaic: if mosaic:
# Load mosaic # Load mosaic
img, labels = load_mosaic(self, index) img, labels = self.load_mosaic(index)
shapes = None shapes = None
# MixUp augmentation # MixUp augmentation
if random.random() < hyp['mixup']: if random.random() < hyp['mixup']:
img, labels = mixup(img, labels, *load_mosaic(self, random.randint(0, self.n - 1))) img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))
else: else:
# Load image # Load image
img, (h0, w0), (h, w) = load_image(self, index) img, (h0, w0), (h, w) = self.load_image(index)
# Letterbox # Letterbox
shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
...@@ -624,63 +624,28 @@ class LoadImagesAndLabels(Dataset): ...@@ -624,63 +624,28 @@ class LoadImagesAndLabels(Dataset):
return torch.from_numpy(img), labels_out, self.img_files[index], shapes return torch.from_numpy(img), labels_out, self.img_files[index], shapes
@staticmethod def load_image(self, i):
def collate_fn(batch): # loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
img, label, path, shapes = zip(*batch) # transposed
for i, lb in enumerate(label):
lb[:, 0] = i # add target image index for build_targets()
return torch.stack(img, 0), torch.cat(label, 0), path, shapes
@staticmethod
def collate_fn4(batch):
img, label, path, shapes = zip(*batch) # transposed
n = len(shapes) // 4
img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]]) # scale
for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
i *= 4
if random.random() < 0.5:
im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear', align_corners=False)[
0].type(img[i].type())
lb = label[i]
else:
im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
img4.append(im)
label4.append(lb)
for i, lb in enumerate(label4):
lb[:, 0] = i # add target image index for build_targets()
return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4
# Ancillary functions --------------------------------------------------------------------------------------------------
def load_image(self, i):
# loads 1 image from dataset index 'i', returns im, original hw, resized hw
im = self.imgs[i] im = self.imgs[i]
if im is None: # not cached in ram if im is None: # not cached in RAM
npy = self.img_npy[i] npy = self.img_npy[i]
if npy and npy.exists(): # load npy if npy and npy.exists(): # load npy
im = np.load(npy) im = np.load(npy)
else: # read image else: # read image
path = self.img_files[i] f = self.img_files[i]
im = cv2.imread(path) # BGR im = cv2.imread(f) # BGR
assert im is not None, f'Image Not Found {path}' assert im is not None, f'Image Not Found {f}'
h0, w0 = im.shape[:2] # orig hw h0, w0 = im.shape[:2] # orig hw
r = self.img_size / max(h0, w0) # ratio r = self.img_size / max(h0, w0) # ratio
if r != 1: # if sizes are not equal if r != 1: # if sizes are not equal
im = cv2.resize(im, (int(w0 * r), int(h0 * r)), im = cv2.resize(im,
interpolation=cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR) (int(w0 * r), int(h0 * r)),
interpolation=cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA)
return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
else: else:
return self.imgs[i], self.img_hw0[i], self.img_hw[i] # im, hw_original, hw_resized return self.imgs[i], self.img_hw0[i], self.img_hw[i] # im, hw_original, hw_resized
def load_mosaic(self, index):
def load_mosaic(self, index):
# YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
labels4, segments4 = [], [] labels4, segments4 = [], []
s = self.img_size s = self.img_size
...@@ -689,7 +654,7 @@ def load_mosaic(self, index): ...@@ -689,7 +654,7 @@ def load_mosaic(self, index):
random.shuffle(indices) random.shuffle(indices)
for i, index in enumerate(indices): for i, index in enumerate(indices):
# Load image # Load image
img, _, (h, w) = load_image(self, index) img, _, (h, w) = self.load_image(index)
# place img in img4 # place img in img4
if i == 0: # top left if i == 0: # top left
...@@ -736,8 +701,7 @@ def load_mosaic(self, index): ...@@ -736,8 +701,7 @@ def load_mosaic(self, index):
return img4, labels4 return img4, labels4
def load_mosaic9(self, index):
def load_mosaic9(self, index):
# YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic # YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic
labels9, segments9 = [], [] labels9, segments9 = [], []
s = self.img_size s = self.img_size
...@@ -746,7 +710,7 @@ def load_mosaic9(self, index): ...@@ -746,7 +710,7 @@ def load_mosaic9(self, index):
hp, wp = -1, -1 # height, width previous hp, wp = -1, -1 # height, width previous
for i, index in enumerate(indices): for i, index in enumerate(indices):
# Load image # Load image
img, _, (h, w) = load_image(self, index) img, _, (h, w) = self.load_image(index)
# place img in img9 # place img in img9
if i == 0: # center if i == 0: # center
...@@ -811,7 +775,41 @@ def load_mosaic9(self, index): ...@@ -811,7 +775,41 @@ def load_mosaic9(self, index):
return img9, labels9 return img9, labels9
@staticmethod
def collate_fn(batch):
img, label, path, shapes = zip(*batch) # transposed
for i, lb in enumerate(label):
lb[:, 0] = i # add target image index for build_targets()
return torch.stack(img, 0), torch.cat(label, 0), path, shapes
@staticmethod
def collate_fn4(batch):
img, label, path, shapes = zip(*batch) # transposed
n = len(shapes) // 4
img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]]) # scale
for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
i *= 4
if random.random() < 0.5:
im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear', align_corners=False)[
0].type(img[i].type())
lb = label[i]
else:
im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
img4.append(im)
label4.append(lb)
for i, lb in enumerate(label4):
lb[:, 0] = i # add target image index for build_targets()
return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4
# Ancillary functions --------------------------------------------------------------------------------------------------
def create_folder(path='./new'): def create_folder(path='./new'):
# Create folder # Create folder
if os.path.exists(path): if os.path.exists(path):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论