Unverified 提交 5d4787ba authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

`cudnn.benchmark = True` on Seed 0 (#9259)

* `cudnn.benchmark = True` on Seed 0 Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> * Update general.py Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> * Update general.py Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 ffdb58b0
...@@ -217,20 +217,17 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False): ...@@ -217,20 +217,17 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
def init_seeds(seed=0, deterministic=False): def init_seeds(seed=0, deterministic=False):
# Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
# cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
import torch.backends.cudnn as cudnn
if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
torch.use_deterministic_algorithms(True)
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
torch.backends.cudnn.benchmark = True # for faster training
if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
os.environ['PYTHONHASHSEED'] = str(seed)
def intersect_dicts(da, db, exclude=()): def intersect_dicts(da, db, exclude=()):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论