Unverified 提交 c91d1db7 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Update dataloaders.py (#9250)

上级 223c59db
...@@ -40,6 +40,7 @@ IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', ...@@ -40,6 +40,7 @@ IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp',
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}' # tqdm bar format BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}' # tqdm bar format
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
# Get orientation exif tag # Get orientation exif tag
for orientation in ExifTags.TAGS.keys(): for orientation in ExifTags.TAGS.keys():
...@@ -83,7 +84,7 @@ def exif_transpose(image): ...@@ -83,7 +84,7 @@ def exif_transpose(image):
5: Image.TRANSPOSE, 5: Image.TRANSPOSE,
6: Image.ROTATE_270, 6: Image.ROTATE_270,
7: Image.TRANSVERSE, 7: Image.TRANSVERSE,
8: Image.ROTATE_90,}.get(orientation) 8: Image.ROTATE_90}.get(orientation)
if method is not None: if method is not None:
image = image.transpose(method) image = image.transpose(method)
del exif[0x0112] del exif[0x0112]
...@@ -144,7 +145,7 @@ def create_dataloader(path, ...@@ -144,7 +145,7 @@ def create_dataloader(path,
shuffle=shuffle and sampler is None, shuffle=shuffle and sampler is None,
num_workers=nw, num_workers=nw,
sampler=sampler, sampler=sampler,
pin_memory=True, pin_memory=PIN_MEMORY,
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn, collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,
worker_init_fn=seed_worker, worker_init_fn=seed_worker,
generator=generator), dataset generator=generator), dataset
...@@ -1152,6 +1153,6 @@ def create_classification_dataloader(path, ...@@ -1152,6 +1153,6 @@ def create_classification_dataloader(path,
shuffle=shuffle and sampler is None, shuffle=shuffle and sampler is None,
num_workers=nw, num_workers=nw,
sampler=sampler, sampler=sampler,
pin_memory=True, pin_memory=PIN_MEMORY,
worker_init_fn=seed_worker, worker_init_fn=seed_worker,
generator=generator) # or DataLoader(persistent_workers=True) generator=generator) # or DataLoader(persistent_workers=True)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论