提交 d49c52ee authored 作者: Glenn Jocher's avatar Glenn Jocher

_RepeatSampler outside of InfiniteDataLoader

上级 bb8872ea
...@@ -68,7 +68,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa ...@@ -68,7 +68,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
num_workers=nw, num_workers=nw,
sampler=sampler, sampler=sampler,
pin_memory=True, pin_memory=True,
collate_fn=LoadImagesAndLabels.collate_fn) collate_fn=LoadImagesAndLabels.collate_fn) # torch.utils.data.DataLoader()
return dataloader, dataset return dataloader, dataset
...@@ -80,7 +80,7 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader): ...@@ -80,7 +80,7 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
object.__setattr__(self, 'batch_sampler', self._RepeatSampler(self.batch_sampler)) object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
self.iterator = super().__iter__() self.iterator = super().__iter__()
def __len__(self): def __len__(self):
...@@ -90,19 +90,20 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader): ...@@ -90,19 +90,20 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
for i in range(len(self)): for i in range(len(self)):
yield next(self.iterator) yield next(self.iterator)
class _RepeatSampler(object):
""" Sampler that repeats forever.
Args: class _RepeatSampler(object):
sampler (Sampler) """ Sampler that repeats forever.
"""
def __init__(self, sampler): Args:
self.sampler = sampler sampler (Sampler)
"""
def __iter__(self): def __init__(self, sampler):
while True: self.sampler = sampler
yield from iter(self.sampler)
def __iter__(self):
while True:
yield from iter(self.sampler)
class LoadImages: # for inference class LoadImages: # for inference
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论