Unverified commit 28bff22d, authored by Dean Mark, committed by GitHub

Use multi-threading in cache_labels (#3505)

* Use multi-threading in cache_labels
* PEP8 reformat
* Add num_threads
* Changed ThreadPool.imap_unordered to Pool.imap_unordered
* Remove in-place additions
* Update datasets.py: refactor initial desc

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Parent c058a61e
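Note: the core of this change is moving label verification from a thread pool to a process pool. Image decoding and numpy label parsing are CPU-bound, so Python's GIL keeps ThreadPool workers largely serialized, while multiprocessing.Pool runs them in separate processes. A minimal sketch of the pattern follows; the verify worker and file list are hypothetical stand-ins, not code from this commit:

    from multiprocessing.pool import Pool
    import os

    num_threads = min(8, os.cpu_count())  # same worker cap the commit introduces

    def verify(path):
        # hypothetical CPU-bound worker: decode, parse, validate one file
        return path, len(path)  # return the key with the result, since order is lost

    if __name__ == '__main__':  # guard required for process pools under spawn (Windows/macOS)
        files = [f'img_{i}.jpg' for i in range(100)]  # hypothetical inputs
        with Pool(num_threads) as pool:
            # imap_unordered yields each result as soon as its worker finishes,
            # which keeps a progress bar updating smoothly
            results = dict(pool.imap_unordered(verify, files))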
@@ -9,7 +9,7 @@ import random
 import shutil
 import time
 from itertools import repeat
-from multiprocessing.pool import ThreadPool
+from multiprocessing.pool import ThreadPool, Pool
 from pathlib import Path
 from threading import Thread
@@ -29,6 +29,7 @@ from utils.torch_utils import torch_distributed_zero_first
 help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
 img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']  # acceptable image suffixes
 vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # acceptable video suffixes
+num_threads = min(8, os.cpu_count())  # number of multiprocessing threads
 logger = logging.getLogger(__name__)

 # Get orientation exif tag
@@ -447,7 +448,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         if cache_images:
             gb = 0  # Gigabytes of cached images
             self.img_hw0, self.img_hw = [None] * n, [None] * n
-            results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))  # 8 threads
+            results = ThreadPool(num_threads).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
             pbar = tqdm(enumerate(results), total=n)
             for i, x in pbar:
                 self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
@@ -458,53 +459,24 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
     def cache_labels(self, path=Path('./labels.cache'), prefix=''):
         # Cache dataset labels, check images and read shapes
         x = {}  # dict
-        nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, duplicate
-        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
-        for i, (im_file, lb_file) in enumerate(pbar):
-            try:
-                # verify images
-                im = Image.open(im_file)
-                im.verify()  # PIL verify
-                shape = exif_size(im)  # image size
-                segments = []  # instance segments
-                assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
-                assert im.format.lower() in img_formats, f'invalid image format {im.format}'
-                # verify labels
-                if os.path.isfile(lb_file):
-                    nf += 1  # label found
-                    with open(lb_file, 'r') as f:
-                        l = [x.split() for x in f.read().strip().splitlines() if len(x)]
-                        if any([len(x) > 8 for x in l]):  # is segment
-                            classes = np.array([x[0] for x in l], dtype=np.float32)
-                            segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # (cls, xy1...)
-                            l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
-                        l = np.array(l, dtype=np.float32)
-                    if len(l):
-                        assert l.shape[1] == 5, 'labels require 5 columns each'
-                        assert (l >= 0).all(), 'negative labels'
-                        assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
-                        assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
-                    else:
-                        ne += 1  # label empty
-                        l = np.zeros((0, 5), dtype=np.float32)
-                else:
-                    nm += 1  # label missing
-                    l = np.zeros((0, 5), dtype=np.float32)
-                x[im_file] = [l, shape, segments]
-            except Exception as e:
-                nc += 1
-                logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
-            pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
-                        f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
+        nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, corrupt
+        desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
+        with Pool(num_threads) as pool:
+            pbar = tqdm(pool.imap_unordered(verify_image_label,
+                                            zip(self.img_files, self.label_files, repeat(prefix))),
+                        desc=desc, total=len(self.img_files))
+            for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f in pbar:
+                if im_file:
+                    x[im_file] = [l, shape, segments]
+                nm, nf, ne, nc = nm + nm_f, nf + nf_f, ne + ne_f, nc + nc_f
+                pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
         pbar.close()
         if nf == 0:
             logging.info(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
         x['hash'] = get_hash(self.label_files + self.img_files)
-        x['results'] = nf, nm, ne, nc, i + 1
+        x['results'] = nf, nm, ne, nc, len(self.img_files)
         x['version'] = 0.2  # cache version
         try:
             torch.save(x, path)  # save cache for next time
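For reference, the cache written above is an ordinary dict persisted with torch.save, so a consumer can reload and unpack it along these lines (the path is a hypothetical example, not code from this commit):

    import torch

    cache = torch.load('labels.cache')  # hypothetical cache path
    nf, nm, ne, nc, total = cache['results']  # found, missing, empty, corrupt, total images
    assert cache['version'] == 0.2, 'stale cache format'
    # remaining keys are image paths mapping to [labels, shape, segments]
    labels = {k: v for k, v in cache.items() if k not in ('hash', 'results', 'version')}
    print(f'{nf}/{total} images have labels; {nc} corrupted')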
@@ -1069,3 +1041,44 @@ def autosplit(path='../coco128', weights=(0.9, 0.1, 0.0), annotated_only=False):
         if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
             with open(path / txt[i], 'a') as f:
                 f.write(str(img) + '\n')  # add image to txt file
+
+
+def verify_image_label(params):
+    # Verify one image-label pair
+    im_file, lb_file, prefix = params
+    nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, corrupt
+    try:
+        # verify images
+        im = Image.open(im_file)
+        im.verify()  # PIL verify
+        shape = exif_size(im)  # image size
+        segments = []  # instance segments
+        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
+        assert im.format.lower() in img_formats, f'invalid image format {im.format}'
+
+        # verify labels
+        if os.path.isfile(lb_file):
+            nf = 1  # label found
+            with open(lb_file, 'r') as f:
+                l = [x.split() for x in f.read().strip().splitlines() if len(x)]
+                if any([len(x) > 8 for x in l]):  # is segment
+                    classes = np.array([x[0] for x in l], dtype=np.float32)
+                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # (cls, xy1...)
+                    l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
+                l = np.array(l, dtype=np.float32)
+            if len(l):
+                assert l.shape[1] == 5, 'labels require 5 columns each'
+                assert (l >= 0).all(), 'negative labels'
+                assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
+                assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
+            else:
+                ne = 1  # label empty
+                l = np.zeros((0, 5), dtype=np.float32)
+        else:
+            nm = 1  # label missing
+            l = np.zeros((0, 5), dtype=np.float32)
+        return im_file, l, shape, segments, nm, nf, ne, nc
+    except Exception as e:
+        nc = 1
+        logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
+        return [None] * 4 + [nm, nf, ne, nc]
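Because verify_image_label takes one packed params tuple (so it can be mapped directly by Pool.imap_unordered), it is also easy to call on a single pair when debugging; a small sketch, with hypothetical file paths:

    im_file, l, shape, segments, nm, nf, ne, nc = verify_image_label(
        ('images/train/0001.jpg', 'labels/train/0001.txt', ''))  # hypothetical pair
    if im_file is None:
        print('corrupt image or label')  # nc == 1 on failure
    else:
        print(f'{len(l)} label rows, image shape {shape}')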