提交 d61930e0 authored 作者: Glenn Jocher's avatar Glenn Jocher

Improved corruption handling during scan and cache (#999)

上级 0fda95aa
...@@ -328,6 +328,12 @@ class LoadStreams: # multiple IP or RTSP cameras ...@@ -328,6 +328,12 @@ class LoadStreams: # multiple IP or RTSP cameras
class LoadImagesAndLabels(Dataset): # for training/testing class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False, def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1): cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1):
def img2label_paths(img_paths):
    # Define label paths as a function of image paths:
    # swap the first '/images/' path component for '/labels/' and
    # replace the file extension with '.txt'.
    sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
    # Use splitext on the whole path instead of str.replace on the extension:
    # replace(ext, '.txt') substitutes EVERY occurrence of the extension substring
    # (e.g. 'a.jpg.jpg' -> 'a.txt.txt'), and with an empty extension it would
    # insert '.txt' between every character. splitext only strips the trailing one.
    return [os.path.splitext(x.replace(sa, sb, 1))[0] + '.txt' for x in img_paths]
try: try:
f = [] # image files f = [] # image files
for p in path if isinstance(path, list) else [path]: for p in path if isinstance(path, list) else [path]:
...@@ -362,11 +368,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing ...@@ -362,11 +368,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
self.mosaic_border = [-img_size // 2, -img_size // 2] self.mosaic_border = [-img_size // 2, -img_size // 2]
self.stride = stride self.stride = stride
# Define labels
sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep # /images/, /labels/ substrings
self.label_files = [x.replace(sa, sb, 1).replace(os.path.splitext(x)[-1], '.txt') for x in self.img_files]
# Check cache # Check cache
self.label_files = img2label_paths(self.img_files) # labels
cache_path = str(Path(self.label_files[0]).parent) + '.cache' # cached labels cache_path = str(Path(self.label_files[0]).parent) + '.cache' # cached labels
if os.path.isfile(cache_path): if os.path.isfile(cache_path):
cache = torch.load(cache_path) # load cache = torch.load(cache_path) # load
...@@ -375,12 +378,15 @@ class LoadImagesAndLabels(Dataset): # for training/testing ...@@ -375,12 +378,15 @@ class LoadImagesAndLabels(Dataset): # for training/testing
else: else:
cache = self.cache_labels(cache_path) # cache cache = self.cache_labels(cache_path) # cache
# Get labels # Read cache
labels, shapes = zip(*[cache[x] for x in self.img_files]) cache.pop('hash') # remove hash
self.shapes = np.array(shapes, dtype=np.float64) labels, shapes = zip(*cache.values())
self.labels = list(labels) self.labels = list(labels)
self.shapes = np.array(shapes, dtype=np.float64)
self.img_files = list(cache.keys()) # update
self.label_files = img2label_paths(cache.keys()) # update
# Rectangular Training https://github.com/ultralytics/yolov3/issues/232 # Rectangular Training
if self.rect: if self.rect:
# Sort by aspect ratio # Sort by aspect ratio
s = self.shapes # wh s = self.shapes # wh
...@@ -404,7 +410,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing ...@@ -404,7 +410,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
# Cache labels # Check labels
create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False
nm, nf, ne, ns, nd = 0, 0, 0, 0, 0 # number missing, found, empty, datasubset, duplicate nm, nf, ne, ns, nd = 0, 0, 0, 0, 0 # number missing, found, empty, datasubset, duplicate
pbar = enumerate(self.label_files) pbar = enumerate(self.label_files)
...@@ -483,10 +489,9 @@ class LoadImagesAndLabels(Dataset): # for training/testing ...@@ -483,10 +489,9 @@ class LoadImagesAndLabels(Dataset): # for training/testing
for (img, label) in pbar: for (img, label) in pbar:
try: try:
l = [] l = []
image = Image.open(img) im = Image.open(img)
image.verify() # PIL verify im.verify() # PIL verify
# _ = io.imread(img) # skimage verify (from skimage import io) shape = exif_size(im) # image size
shape = exif_size(image) # image size
assert (shape[0] > 9) & (shape[1] > 9), 'image size <10 pixels' assert (shape[0] > 9) & (shape[1] > 9), 'image size <10 pixels'
if os.path.isfile(label): if os.path.isfile(label):
with open(label, 'r') as f: with open(label, 'r') as f:
...@@ -495,8 +500,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing ...@@ -495,8 +500,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
l = np.zeros((0, 5), dtype=np.float32) l = np.zeros((0, 5), dtype=np.float32)
x[img] = [l, shape] x[img] = [l, shape]
except Exception as e: except Exception as e:
x[img] = [None, None] print('WARNING: Ignoring corrupted image and/or label:%s: %s' % (img, e))
print('WARNING: %s: %s' % (img, e))
x['hash'] = get_hash(self.label_files + self.img_files) x['hash'] = get_hash(self.label_files + self.img_files)
torch.save(x, path) # save for next time torch.save(x, path) # save for next time
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论