Unverified 提交 10e93d29 authored 作者: Yonghye Kwon's avatar Yonghye Kwon 提交者: GitHub

Set a seed of generator with an option for more randomness when training several…

Set a seed of generator with an option for more randomness when training several models with different seeds (#10486) * set seed with parameter Signed-off-by: 's avatarYonghye Kwon <developer.0hye@gmail.com> * make seed to be a large number * set seed with a parameter * set a seed of dataloader with opt for more randomness Signed-off-by: 's avatarYonghye Kwon <developer.0hye@gmail.com> Co-authored-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 b2a0f1cd
......@@ -198,7 +198,8 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
image_weights=opt.image_weights,
quad=opt.quad,
prefix=colorstr('train: '),
shuffle=True)
shuffle=True,
seed=opt.seed)
labels = np.concatenate(dataset.labels, 0)
mlc = int(labels[:, 0].max()) # max label class
assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
......
......@@ -115,7 +115,8 @@ def create_dataloader(path,
image_weights=False,
quad=False,
prefix='',
shuffle=False):
shuffle=False,
seed=0):
if rect and shuffle:
LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False')
shuffle = False
......@@ -140,7 +141,7 @@ def create_dataloader(path,
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK)
generator.manual_seed(6148914691236517205 + seed + RANK)
return loader(dataset,
batch_size=batch_size,
shuffle=shuffle and sampler is None,
......
......@@ -37,7 +37,8 @@ def create_dataloader(path,
prefix='',
shuffle=False,
mask_downsample_ratio=1,
overlap_mask=False):
overlap_mask=False,
seed=0):
if rect and shuffle:
LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False')
shuffle = False
......@@ -64,7 +65,7 @@ def create_dataloader(path,
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK)
generator.manual_seed(6148914691236517205 + seed + RANK)
return loader(
dataset,
batch_size=batch_size,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论