Unverified commit 09d17038 authored by Werner Duvaud, committed by GitHub

Default DataLoader `shuffle=True` for training (#5623)

* Fix shuffle DataLoader argument
* Add shuffle argument
* Disable shuffle when rect
* Cleanup, add rect warning
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Cleanup2
* Cleanup3

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent 7473f0f9
train.py
@@ -212,7 +212,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
                                               hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
                                               workers=workers, image_weights=opt.image_weights, quad=opt.quad,
-                                              prefix=colorstr('train: '))
+                                              prefix=colorstr('train: '), shuffle=True)
     mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max())  # max label class
     nb = len(train_loader)  # number of batches
     assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
...
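Note that with DDP active, the new shuffle=True is not handed to the DataLoader directly; as the utils/datasets.py hunks below show, it is forwarded to the DistributedSampler, and the DataLoader only shuffles when no sampler exists. One consequence worth spelling out: a DistributedSampler reshuffles per epoch only if it is told the epoch number via set_epoch(). A minimal, self-contained sketch of that behavior (toy dataset; all names here are illustrative, not taken from this commit):

import torch
from torch.utils.data import DataLoader, TensorDataset, distributed

dataset = TensorDataset(torch.arange(8))
# num_replicas/rank given explicitly so no process group is needed for this demo
sampler = distributed.DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)  # shuffle must stay False when a sampler is set

for epoch in range(2):
    sampler.set_epoch(epoch)  # omit this and every epoch repeats the same order
    print(epoch, [int(v) for (batch,) in loader for v in batch])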
utils/datasets.py
@@ -22,7 +22,7 @@ import torch
 import torch.nn.functional as F
 import yaml
 from PIL import ExifTags, Image, ImageOps
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset, dataloader, distributed
 from tqdm import tqdm

 from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
@@ -93,13 +93,15 @@ def exif_transpose(image):

 def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
-                      rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
-    # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
-    with torch_distributed_zero_first(rank):
+                      rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False):
+    if rect and shuffle:
+        LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
+        shuffle = False
+    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
         dataset = LoadImagesAndLabels(path, imgsz, batch_size,
-                                      augment=augment,  # augment images
-                                      hyp=hyp,  # augmentation hyperparameters
-                                      rect=rect,  # rectangular training
+                                      augment=augment,  # augmentation
+                                      hyp=hyp,  # hyperparameters
+                                      rect=rect,  # rectangular batches
                                       cache_images=cache,
                                       single_cls=single_cls,
                                       stride=int(stride),
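Why the new guard: rectangular training (rect=True) sorts the dataset by aspect ratio so each batch can share one minimally padded letterbox shape, and a shuffled sampler would scatter that ordering, pairing images with the wrong batch shapes. A rough, self-contained sketch of the idea (not the LoadImagesAndLabels code itself; the function name and exact rounding are assumptions):

import numpy as np

def rect_batch_shapes(wh, batch_size=16, img_size=640, stride=32, pad=0.0):
    # wh: (n, 2) array of per-image (width, height)
    ar = wh[:, 1] / wh[:, 0]                 # aspect ratio h/w
    order = ar.argsort()                     # sort so similar shapes share a batch
    ar = ar[order]
    nb = int(np.ceil(len(wh) / batch_size))  # number of batches
    shapes = np.ones((nb, 2))
    for b in range(nb):
        ari = ar[b * batch_size:(b + 1) * batch_size]
        mini, maxi = ari.min(), ari.max()
        if maxi < 1:
            shapes[b] = [maxi, 1]            # batch of wide images: shrink height
        elif mini > 1:
            shapes[b] = [1, 1 / mini]        # batch of tall images: shrink width
    # round each batch shape up to a stride multiple, as the model requires
    shapes = np.ceil(shapes * img_size / stride + pad).astype(int) * stride
    return order, shapes  # shuffling after this would break the image/shape pairing

The next hunk wires shuffle into the sampler and the loader itself.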
@@ -109,19 +111,18 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None,
     batch_size = min(batch_size, len(dataset))
     nw = min([os.cpu_count() // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers])  # number of workers
-    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
-    loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
-    # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
-    dataloader = loader(dataset,
-                        batch_size=batch_size,
-                        num_workers=nw,
-                        sampler=sampler,
-                        pin_memory=True,
-                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
-    return dataloader, dataset
+    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
+    loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
+    return loader(dataset,
+                  batch_size=batch_size,
+                  shuffle=shuffle and sampler is None,
+                  num_workers=nw,
+                  sampler=sampler,
+                  pin_memory=True,
+                  collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset


-class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
+class InfiniteDataLoader(dataloader.DataLoader):
     """ Dataloader that reuses workers

     Uses same syntax as vanilla DataLoader
...
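Two details in the hunk above are worth spelling out. First, PyTorch raises an error if a DataLoader receives both a sampler and shuffle=True, hence shuffle=shuffle and sampler is None. Second, InfiniteDataLoader keeps its worker processes alive across epochs, which is why the plain DataLoader is still chosen when image_weights mutates the dataset between epochs: persistent workers would not see those attribute updates. The class body is collapsed in the diff; a minimal sketch of the worker-reusing pattern it follows (inferred from the pattern, not copied from this commit):

from torch.utils.data import dataloader


class _RepeatSampler:
    # Wraps a batch sampler so iteration never terminates
    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)


class InfiniteDataLoader(dataloader.DataLoader):
    # Reuses workers: one never-ending worker iterator is created once,
    # and __iter__ slices epoch-sized chunks out of it
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self.iterator)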