提交 481d46cf authored 作者: Glenn Jocher's avatar Glenn Jocher

Improved corruption handling during scan and cache (#999)

上级 d61930e0
......@@ -328,6 +328,14 @@ class LoadStreams: # multiple IP or RTSP cameras
class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1):
self.img_size = img_size
self.augment = augment
self.hyp = hyp
self.image_weights = image_weights
self.rect = False if image_weights else rect
self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
self.mosaic_border = [-img_size // 2, -img_size // 2]
self.stride = stride
def img2label_paths(img_paths):
# Define label paths as a function of image paths
......@@ -349,25 +357,10 @@ class LoadImagesAndLabels(Dataset): # for training/testing
raise Exception('%s does not exist' % p)
self.img_files = sorted(
[x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats])
assert len(self.img_files) > 0, 'No images found'
except Exception as e:
raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))
n = len(self.img_files)
assert n > 0, 'No images found in %s. See %s' % (path, help_url)
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
nb = bi[-1] + 1 # number of batches
self.n = n # number of images
self.batch = bi # batch index of image
self.img_size = img_size
self.augment = augment
self.hyp = hyp
self.image_weights = image_weights
self.rect = False if image_weights else rect
self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
self.mosaic_border = [-img_size // 2, -img_size // 2]
self.stride = stride
# Check cache
self.label_files = img2label_paths(self.img_files) # labels
cache_path = str(Path(self.label_files[0]).parent) + '.cache' # cached labels
......@@ -386,6 +379,12 @@ class LoadImagesAndLabels(Dataset): # for training/testing
self.img_files = list(cache.keys()) # update
self.label_files = img2label_paths(cache.keys()) # update
n = len(shapes) # number of images
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
nb = bi[-1] + 1 # number of batches
self.batch = bi # batch index of image
self.n = n
# Rectangular Training
if self.rect:
# Sort by aspect ratio
......@@ -500,7 +499,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
l = np.zeros((0, 5), dtype=np.float32)
x[img] = [l, shape]
except Exception as e:
print('WARNING: Ignoring corrupted image and/or label:%s: %s' % (img, e))
print('WARNING: Ignoring corrupted image and/or label %s: %s' % (img, e))
x['hash'] = get_hash(self.label_files + self.img_files)
torch.save(x, path) # save for next time
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论