提交 22d60882 authored 作者: Glenn Jocher's avatar Glenn Jocher

speed-reproducibility fix #17

上级 55ca5c74
...@@ -63,7 +63,7 @@ def train(hyp): ...@@ -63,7 +63,7 @@ def train(hyp):
weights = opt.weights # initial training weights weights = opt.weights # initial training weights
# Configure # Configure
init_seeds() init_seeds(1)
with open(opt.data) as f: with open(opt.data) as f:
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
train_path = data_dict['train'] train_path = data_dict['train']
......
...@@ -12,8 +12,11 @@ import torch.nn.functional as F ...@@ -12,8 +12,11 @@ import torch.nn.functional as F
def init_seeds(seed=0): def init_seeds(seed=0):
torch.manual_seed(seed) torch.manual_seed(seed)
# Reduce randomness (may be slower on Tesla GPUs) # https://pytorch.org/docs/stable/notes/randomness.html # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
if seed == 0: if seed == 0: # slower, more reproducible
cudnn.deterministic = True
cudnn.benchmark = False
else: # faster, less reproducible
cudnn.deterministic = False cudnn.deterministic = False
cudnn.benchmark = True cudnn.benchmark = True
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论