Unverified 提交 5e1a9553 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

AutoBatch `cudnn.benchmark=True` fix (#9448)

* AutoBatch `cudnn.benchmark=True` fix May resolve https://github.com/ultralytics/yolov5/issues/9287Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> * Update autobatch.py Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> * Update autobatch.py Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> * Update general.py Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 db06f495
...@@ -33,6 +33,9 @@ def autobatch(model, imgsz=640, fraction=0.8, batch_size=16): ...@@ -33,6 +33,9 @@ def autobatch(model, imgsz=640, fraction=0.8, batch_size=16):
if device.type == 'cpu': if device.type == 'cpu':
LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}') LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
return batch_size return batch_size
if torch.backends.cudnn.benchmark:
LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}')
return batch_size
# Inspect CUDA memory # Inspect CUDA memory
gb = 1 << 30 # bytes to GiB (1024 ** 3) gb = 1 << 30 # bytes to GiB (1024 ** 3)
......
...@@ -223,7 +223,7 @@ def init_seeds(seed=0, deterministic=False): ...@@ -223,7 +223,7 @@ def init_seeds(seed=0, deterministic=False):
torch.manual_seed(seed) torch.manual_seed(seed)
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
torch.backends.cudnn.benchmark = True # for faster training # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213 if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
torch.use_deterministic_algorithms(True) torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论