Unverified 提交 10e93d29 authored 作者: Yonghye Kwon's avatar Yonghye Kwon 提交者: GitHub

Set a seed of generator with an option for more randomness when training several…

Set a seed of generator with an option for more randomness when training several models with different seeds (#10486) * set seed with parameter Signed-off-by: 's avatarYonghye Kwon <developer.0hye@gmail.com> * make seed to be a large number * set seed with a parameter * set a seed of dataloader with opt for more randomness Signed-off-by: 's avatarYonghye Kwon <developer.0hye@gmail.com> Co-authored-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 b2a0f1cd
...@@ -198,7 +198,8 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio ...@@ -198,7 +198,8 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
image_weights=opt.image_weights, image_weights=opt.image_weights,
quad=opt.quad, quad=opt.quad,
prefix=colorstr('train: '), prefix=colorstr('train: '),
shuffle=True) shuffle=True,
seed=opt.seed)
labels = np.concatenate(dataset.labels, 0) labels = np.concatenate(dataset.labels, 0)
mlc = int(labels[:, 0].max()) # max label class mlc = int(labels[:, 0].max()) # max label class
assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}' assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
......
...@@ -115,7 +115,8 @@ def create_dataloader(path, ...@@ -115,7 +115,8 @@ def create_dataloader(path,
image_weights=False, image_weights=False,
quad=False, quad=False,
prefix='', prefix='',
shuffle=False): shuffle=False,
seed=0):
if rect and shuffle: if rect and shuffle:
LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False') LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False')
shuffle = False shuffle = False
...@@ -140,7 +141,7 @@ def create_dataloader(path, ...@@ -140,7 +141,7 @@ def create_dataloader(path,
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle) sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
generator = torch.Generator() generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK) generator.manual_seed(6148914691236517205 + seed + RANK)
return loader(dataset, return loader(dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=shuffle and sampler is None, shuffle=shuffle and sampler is None,
......
...@@ -37,7 +37,8 @@ def create_dataloader(path, ...@@ -37,7 +37,8 @@ def create_dataloader(path,
prefix='', prefix='',
shuffle=False, shuffle=False,
mask_downsample_ratio=1, mask_downsample_ratio=1,
overlap_mask=False): overlap_mask=False,
seed=0):
if rect and shuffle: if rect and shuffle:
LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False') LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False')
shuffle = False shuffle = False
...@@ -64,7 +65,7 @@ def create_dataloader(path, ...@@ -64,7 +65,7 @@ def create_dataloader(path,
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle) sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
generator = torch.Generator() generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK) generator.manual_seed(6148914691236517205 + seed + RANK)
return loader( return loader(
dataset, dataset,
batch_size=batch_size, batch_size=batch_size,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论