Unverified commit 2d990632, authored by junji hashimoto, committed by GitHub

Feature `python train.py --cache disk` (#4049)

* Add cache-on-disk and cache-directory to cache images on disk
* Fix load_image with cache_on_disk
* Add no_cache flag for load_image
* Revert the parts ('logging' and a new line) that do not need to be modified
* Add the assertion for shapes of cached images
* Add a suffix string for cached images
* Fix boundary-error of letterbox for load_mosaic
* Add prefix as cache-key of cache-on-disk
* Update cache-function on disk
* Add psutil in requirements.txt
* Update train.py
* Cleanup1
* Cleanup2
* Skip existing npy
* Include re-space
* Export return character fix

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Parent 621caea5
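In short, where `--cache-images` used to be a boolean switch, training now takes `--cache [ram|disk]`: `python train.py --cache` (or `--cache ram`) keeps the old behaviour of caching decoded images in RAM, while `python train.py --cache disk` saves the resized images as `.npy` files in a sibling directory of the dataset. The invocations shown here are illustrative; all other arguments are unchanged.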
@@ -156,8 +156,8 @@ def run(weights='./yolov5s.pt',  # weights path

     # Finish
     print(f'\nExport complete ({time.time() - t:.2f}s)'
-          f"Results saved to {colorstr('bold', file.parent.resolve())}\n"
-          f'Visualize with https://netron.app')
+          f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
+          f'\nVisualize with https://netron.app')


 def parse_opt():
...
@@ -201,7 +201,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary

     # Trainloader
     train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
-                                              hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
+                                              hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=RANK,
                                               workers=workers, image_weights=opt.image_weights, quad=opt.quad,
                                               prefix=colorstr('train: '))
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
@@ -211,7 +211,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary

     # Process 0
     if RANK in [-1, 0]:
         val_loader = create_dataloader(val_path, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls,
-                                       hyp=hyp, cache=opt.cache_images and not noval, rect=True, rank=-1,
+                                       hyp=hyp, cache=None if noval else opt.cache, rect=True, rank=-1,
                                        workers=workers, pad=0.5,
                                        prefix=colorstr('val: '))[0]
@@ -389,7 +389,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training -----------------------------------------------------------------------------------------------------
     if RANK in [-1, 0]:
-        LOGGER.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
+        LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
         if not evolve:
             if is_coco:  # COCO dataset
                 for m in [last, best] if best.exists() else [last]:  # speed, mAP tests
@@ -430,7 +430,7 @@ def parse_opt(known=False):
     parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
     parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
     parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
-    parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
+    parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
     parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
...
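The new flag leans on argparse's optional-value semantics: with `nargs='?'` and `const='ram'`, omitting the flag leaves the value `None` (no caching), a bare `--cache` falls back to `'ram'`, and an explicit value such as `disk` is passed straight through to the dataloader. A minimal standalone sketch of just that behaviour (not the full `parse_opt`):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--cache', type=str, nargs='?', const='ram',
                    help='--cache images in "ram" (default) or "disk"')

print(parser.parse_args([]).cache)                   # None   -> images are not cached
print(parser.parse_args(['--cache']).cache)          # 'ram'  -> bare flag uses const
print(parser.parse_args(['--cache', 'disk']).cache)  # 'disk' -> cache resized images as .npy files
```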
@@ -455,16 +455,25 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
             self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride

         # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
-        self.imgs = [None] * n
+        self.imgs, self.img_npy = [None] * n, [None] * n
         if cache_images:
+            if cache_images == 'disk':
+                self.im_cache_dir = Path(Path(self.img_files[0]).parent.as_posix() + '_npy')
+                self.img_npy = [self.im_cache_dir / Path(f).with_suffix('.npy').name for f in self.img_files]
+                self.im_cache_dir.mkdir(parents=True, exist_ok=True)
             gb = 0  # Gigabytes of cached images
             self.img_hw0, self.img_hw = [None] * n, [None] * n
             results = ThreadPool(NUM_THREADS).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
             pbar = tqdm(enumerate(results), total=n)
             for i, x in pbar:
-                self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
-                gb += self.imgs[i].nbytes
-                pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
+                if cache_images == 'disk':
+                    if not self.img_npy[i].exists():
+                        np.save(self.img_npy[i].as_posix(), x[0])
+                    gb += self.img_npy[i].stat().st_size
+                else:
+                    self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
+                    gb += self.imgs[i].nbytes
+                pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})'
             pbar.close()

     def cache_labels(self, path=Path('./labels.cache'), prefix=''):
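Taken on its own, the disk branch above does three things: derive a `<images_dir>_npy` sibling directory, map every image path to a `.npy` path inside it, and save each resized image there once while summing file sizes for the progress bar (existing `.npy` files are reused). A condensed, standalone sketch of that logic; `load_resized` is a hypothetical stand-in for `load_image`:

```python
from pathlib import Path
import numpy as np

def build_disk_cache(img_files, load_resized):
    # Cache directory sits next to the image directory, e.g. images/ -> images_npy/
    cache_dir = Path(Path(img_files[0]).parent.as_posix() + '_npy')
    cache_dir.mkdir(parents=True, exist_ok=True)
    img_npy = [cache_dir / Path(f).with_suffix('.npy').name for f in img_files]

    nbytes = 0  # total size of cached files, mirrors the 'gb' counter above
    for i, npy in enumerate(img_npy):
        if not npy.exists():  # skip images already cached on a previous run
            np.save(npy.as_posix(), load_resized(i))
        nbytes += npy.stat().st_size
    return img_npy, nbytes
```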
@@ -618,21 +627,25 @@ class LoadImagesAndLabels(Dataset):  # for training/testing


 # Ancillary functions --------------------------------------------------------------------------------------------------
-def load_image(self, index):
-    # loads 1 image from dataset, returns img, original hw, resized hw
-    img = self.imgs[index]
-    if img is None:  # not cached
-        path = self.img_files[index]
-        img = cv2.imread(path)  # BGR
-        assert img is not None, 'Image Not Found ' + path
-        h0, w0 = img.shape[:2]  # orig hw
+def load_image(self, i):
+    # loads 1 image from dataset index 'i', returns im, original hw, resized hw
+    im = self.imgs[i]
+    if im is None:  # not cached in ram
+        npy = self.img_npy[i]
+        if npy and npy.exists():  # load npy
+            im = np.load(npy)
+        else:  # read image
+            path = self.img_files[i]
+            im = cv2.imread(path)  # BGR
+            assert im is not None, 'Image Not Found ' + path
+        h0, w0 = im.shape[:2]  # orig hw
         r = self.img_size / max(h0, w0)  # ratio
         if r != 1:  # if sizes are not equal
-            img = cv2.resize(img, (int(w0 * r), int(h0 * r)),
+            im = cv2.resize(im, (int(w0 * r), int(h0 * r)),
                              interpolation=cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR)
-        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
+        return im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
     else:
-        return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized
+        return self.imgs[i], self.img_hw0[i], self.img_hw[i]  # im, hw_original, hw_resized


 def load_mosaic(self, index):
...
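Putting the rewritten `load_image` in plain terms, the lookup order is: RAM cache first, then a `.npy` file on disk, and only then decoding the original file with OpenCV, resizing the longest side to `img_size`. A simplified sketch with illustrative names (the real method also returns the original and resized heights/widths):

```python
import cv2
import numpy as np

def lookup_image(imgs, img_npy, img_files, i, img_size, augment=False):
    im = imgs[i]
    if im is not None:                    # 1) cached in RAM (already resized)
        return im
    npy = img_npy[i]
    if npy is not None and npy.exists():  # 2) cached on disk as .npy
        im = np.load(npy)
    else:                                 # 3) decode the original image file
        im = cv2.imread(img_files[i])     # BGR
        assert im is not None, 'Image Not Found ' + img_files[i]
    h0, w0 = im.shape[:2]
    r = img_size / max(h0, w0)            # scale longest side to img_size
    if r != 1:
        interp = cv2.INTER_AREA if r < 1 and not augment else cv2.INTER_LINEAR
        im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp)
    return im
```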