Unverified commit fad27c00, authored by Glenn Jocher, committed by GitHub

Update DDP for `torch.distributed.run` with `gloo` backend (#3680)

* Update DDP for `torch.distributed.run`
* Add LOCAL_RANK
* remove opt.local_rank
* backend="gloo|nccl"
* print
* print
* debug
* debug
* os.getenv
* gloo
* gloo
* gloo
* cleanup
* fix getenv
* cleanup
* cleanup destroy
* try nccl
* return opt
* add --local_rank
* add timeout
* add init_method
* gloo
* move destroy
* move destroy
* move print(opt) under if RANK
* destroy only RANK 0
* move destroy inside train()
* restore destroy outside train()
* update print(opt)
* cleanup
* nccl
* gloo with 60 second timeout
* update namespace printing
Parent 5bab9a28
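The commit messages above mention adding `LOCAL_RANK` and removing `opt.local_rank`: under `torch.distributed.run`, each worker receives its rank through environment variables rather than an `--local_rank` argparse flag. A minimal sketch of that pattern, with names following the convention in the commit messages:

```python
# Example launch (2 processes): python -m torch.distributed.run --nproc_per_node 2 train.py
import os

# torch.distributed.run exports these for every worker process, which is why
# the old --local_rank argparse flag can be dropped
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # process index on this node
RANK = int(os.getenv('RANK', -1))              # global process index
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))   # total number of processes
```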
detect.py
@@ -8,8 +8,8 @@ import torch.backends.cudnn as cudnn
 from models.experimental import attempt_load
 from utils.datasets import LoadStreams, LoadImages
-from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
-    scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
+from utils.general import check_img_size, check_requirements, check_imshow, colorstr, non_max_suppression, \
+    apply_classifier, scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
 from utils.plots import colors, plot_one_box
 from utils.torch_utils import select_device, load_classifier, time_synchronized
@@ -202,7 +202,7 @@ def parse_opt():
 def main(opt):
-    print(opt)
+    print(colorstr('detect: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
     check_requirements(exclude=('tensorboard', 'thop'))
     detect(**vars(opt))
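For illustration, a standalone sketch of what the new one-line namespace printing produces; the `Namespace` fields here are hypothetical, and `colorstr` only adds ANSI styling to the prefix:

```python
from argparse import Namespace

opt = Namespace(weights='yolov5s.pt', imgsz=640)  # hypothetical options
print('detect: ' + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
# -> detect: weights=yolov5s.pt, imgsz=640
```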
export.py
@@ -163,8 +163,8 @@ def parse_opt():
 def main(opt):
-    print(opt)
     set_logging()
+    print(colorstr('export: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
     export(**vars(opt))
test.py
@@ -51,7 +51,6 @@ def test(data,
         device = next(model.parameters()).device  # get model device
     else:  # called directly
-        set_logging()
         device = select_device(device, batch_size=batch_size)

     # Directories
@@ -323,7 +322,8 @@ def parse_opt():
 def main(opt):
-    print(opt)
+    set_logging()
+    print(colorstr('test: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
     check_requirements(exclude=('tensorboard', 'thop'))
     if opt.task in ('train', 'val', 'test'):  # run normally
This diff is collapsed.
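The collapsed diff is presumably train.py, where the process group is initialized. The commit messages mention `backend="gloo|nccl"`, a timeout, and an `init_method`; a speculative sketch of such an initialization, consistent with those messages but not necessarily matching the actual collapsed change:

```python
import os
from datetime import timedelta

import torch
import torch.distributed as dist

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))

if LOCAL_RANK != -1:  # DDP mode under torch.distributed.run
    torch.cuda.set_device(LOCAL_RANK)
    # Prefer nccl when available, fall back to gloo; 60 s timeout per the
    # "gloo with 60 second timeout" commit message
    dist.init_process_group(backend='nccl' if dist.is_nccl_available() else 'gloo',
                            timeout=timedelta(seconds=60))
```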
utils/datasets.py
@@ -64,7 +64,7 @@ def exif_size(img):
 def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
-                      rect=False, rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
+                      rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
     # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
     with torch_distributed_zero_first(rank):
         dataset = LoadImagesAndLabels(path, imgsz, batch_size,
@@ -79,7 +79,7 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non
                                       prefix=prefix)

     batch_size = min(batch_size, len(dataset))
-    nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
+    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, workers])  # number of workers
     sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
     loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
     # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
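With `world_size` dropped from the signature, each process now caps its dataloader workers by its own CPU count rather than dividing CPUs across processes. A minimal sketch of the new calculation, with illustrative argument values:

```python
import os

batch_size, workers = 16, 8  # illustrative values for the function arguments

# Cap workers at: machine CPU count, the batch size (0 workers when
# batch_size == 1), and the user-requested --workers limit
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, workers])
```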
utils/torch_utils.py
@@ -13,6 +13,7 @@ from pathlib import Path
 import torch
 import torch.backends.cudnn as cudnn
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision
@@ -30,10 +31,10 @@ def torch_distributed_zero_first(local_rank: int):
     Decorator to make all processes in distributed training wait for each local_master to do something.
     """
     if local_rank not in [-1, 0]:
-        torch.distributed.barrier()
+        dist.barrier()
     yield
     if local_rank == 0:
-        torch.distributed.barrier()
+        dist.barrier()


 def init_torch_seeds(seed=0):
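For context, the full context manager as it reads after this hunk (the `@contextmanager` decorator comes from `contextlib` in the surrounding file):

```python
from contextlib import contextmanager

import torch.distributed as dist


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    # Non-master processes block here until the local master has run the body
    if local_rank not in [-1, 0]:
        dist.barrier()
    yield
    # The master reaches the barrier after the body, releasing the waiting processes
    if local_rank == 0:
        dist.barrier()
```

This is what the `with torch_distributed_zero_first(rank):` block in utils/datasets.py above relies on: rank 0 scans and caches the dataset while the other ranks wait, then everyone reads the cache.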
"""Utilities and tools for tracking runs with Weights & Biases.""" """Utilities and tools for tracking runs with Weights & Biases."""
import logging import logging
import os
import sys import sys
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
...@@ -18,6 +19,7 @@ try: ...@@ -18,6 +19,7 @@ try:
except ImportError: except ImportError:
wandb = None wandb = None
RANK = int(os.getenv('RANK', -1))
WANDB_ARTIFACT_PREFIX = 'wandb-artifact://' WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
...@@ -42,10 +44,10 @@ def get_run_info(run_path): ...@@ -42,10 +44,10 @@ def get_run_info(run_path):
def check_wandb_resume(opt): def check_wandb_resume(opt):
process_wandb_config_ddp_mode(opt) if opt.global_rank not in [-1, 0] else None process_wandb_config_ddp_mode(opt) if RANK not in [-1, 0] else None
if isinstance(opt.resume, str): if isinstance(opt.resume, str):
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX): if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
if opt.global_rank not in [-1, 0]: # For resuming DDP runs if RANK not in [-1, 0]: # For resuming DDP runs
entity, project, run_id, model_artifact_name = get_run_info(opt.resume) entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
api = wandb.Api() api = wandb.Api()
artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest') artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest')
......
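The switch from `opt.global_rank` to a module-level `RANK` read from the environment means the W&B helpers no longer depend on a fully-populated `opt` object. A minimal sketch of the resulting guard pattern (the body here is an illustrative placeholder, not the library's actual calls):

```python
import os

RANK = int(os.getenv('RANK', -1))  # -1 when not running under DDP

if RANK in [-1, 0]:
    # Only the master process (or a single-process run) talks to W&B
    print('setting up W&B run')  # placeholder for the real wandb setup
```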