Unverified 提交 09d17038 authored 作者: Werner Duvaud's avatar Werner Duvaud 提交者: GitHub

Default DataLoader `shuffle=True` for training (#5623)

* Fix shuffle DataLoader argument * Add shuffle argument * Disable shuffle when rect * Cleanup, add rect warning * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Cleanup2 * Cleanup3 Co-authored-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: 's avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
上级 7473f0f9
......@@ -212,7 +212,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
workers=workers, image_weights=opt.image_weights, quad=opt.quad,
prefix=colorstr('train: '))
prefix=colorstr('train: '), shuffle=True)
mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
nb = len(train_loader) # number of batches
assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
......
......@@ -22,7 +22,7 @@ import torch
import torch.nn.functional as F
import yaml
from PIL import ExifTags, Image, ImageOps
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
from tqdm import tqdm
from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
......@@ -93,13 +93,15 @@ def exif_transpose(image):
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
with torch_distributed_zero_first(rank):
rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False):
if rect and shuffle:
LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
shuffle = False
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
augment=augment, # augment images
hyp=hyp, # augmentation hyperparameters
rect=rect, # rectangular training
augment=augment, # augmentation
hyp=hyp, # hyperparameters
rect=rect, # rectangular batches
cache_images=cache,
single_cls=single_cls,
stride=int(stride),
......@@ -109,19 +111,18 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non
batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count() // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers]) # number of workers
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
# Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
dataloader = loader(dataset,
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
return loader(dataset,
batch_size=batch_size,
shuffle=shuffle and sampler is None,
num_workers=nw,
sampler=sampler,
pin_memory=True,
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
return dataloader, dataset
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
class InfiniteDataLoader(dataloader.DataLoader):
""" Dataloader that reuses workers
Uses same syntax as vanilla DataLoader
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论