提交 d3f9bf2b authored 作者: Glenn Jocher's avatar Glenn Jocher

Update datasets.py

上级 901243c7
...@@ -62,26 +62,25 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa ...@@ -62,26 +62,25 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
batch_size = min(batch_size, len(dataset)) batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
dataloader = InfiniteDataLoader (dataset, dataloader = InfiniteDataLoader(dataset,
batch_size=batch_size, batch_size=batch_size,
num_workers=nw, num_workers=nw,
sampler=train_sampler, sampler=sampler,
pin_memory=True, pin_memory=True,
collate_fn=LoadImagesAndLabels.collate_fn) collate_fn=LoadImagesAndLabels.collate_fn)
return dataloader, dataset return dataloader, dataset
class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader): class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
''' """ Dataloader that reuses workers.
Dataloader that reuses workers.
Uses same syntax as vanilla DataLoader. Uses same syntax as vanilla DataLoader.
''' """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) object.__setattr__(self, 'batch_sampler', self._RepeatSampler(self.batch_sampler))
self.iterator = super().__iter__() self.iterator = super().__iter__()
def __len__(self): def __len__(self):
...@@ -91,14 +90,12 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader): ...@@ -91,14 +90,12 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
for i in range(len(self)): for i in range(len(self)):
yield next(self.iterator) yield next(self.iterator)
class _RepeatSampler(object):
class _RepeatSampler(object): """ Sampler that repeats forever.
'''
Sampler that repeats forever.
Args: Args:
sampler (Sampler) sampler (Sampler)
''' """
def __init__(self, sampler): def __init__(self, sampler):
self.sampler = sampler self.sampler = sampler
...@@ -684,14 +681,10 @@ def load_mosaic(self, index): ...@@ -684,14 +681,10 @@ def load_mosaic(self, index):
# Concat/clip labels # Concat/clip labels
if len(labels4): if len(labels4):
labels4 = np.concatenate(labels4, 0) labels4 = np.concatenate(labels4, 0)
# np.clip(labels4[:, 1:] - s / 2, 0, s, out=labels4[:, 1:]) # use with center crop np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:]) # use with random_perspective
np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:]) # use with random_affine # img4, labels4 = replicate(img4, labels4) # replicate
# Replicate
# img4, labels4 = replicate(img4, labels4)
# Augment # Augment
# img4 = img4[s // 2: int(s * 1.5), s // 2:int(s * 1.5)] # center crop (WARNING, requires box pruning)
img4, labels4 = random_perspective(img4, labels4, img4, labels4 = random_perspective(img4, labels4,
degrees=self.hyp['degrees'], degrees=self.hyp['degrees'],
translate=self.hyp['translate'], translate=self.hyp['translate'],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论