Unverified 提交 fad27c00 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Update DDP for `torch.distributed.run` with `gloo` backend (#3680)

* Update DDP for `torch.distributed.run` * Add LOCAL_RANK * remove opt.local_rank * backend="gloo|nccl" * print * print * debug * debug * os.getenv * gloo * gloo * gloo * cleanup * fix getenv * cleanup * cleanup destroy * try nccl * return opt * add --local_rank * add timeout * add init_method * gloo * move destroy * move destroy * move print(opt) under if RANK * destroy only RANK 0 * move destroy inside train() * restore destroy outside train() * update print(opt) * cleanup * nccl * gloo with 60 second timeout * update namespace printing
上级 5bab9a28
......@@ -8,8 +8,8 @@ import torch.backends.cudnn as cudnn
from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
from utils.general import check_img_size, check_requirements, check_imshow, colorstr, non_max_suppression, \
apply_classifier, scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
from utils.plots import colors, plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized
......@@ -202,7 +202,7 @@ def parse_opt():
def main(opt):
print(opt)
print(colorstr('detect: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
check_requirements(exclude=('tensorboard', 'thop'))
detect(**vars(opt))
......
......@@ -163,8 +163,8 @@ def parse_opt():
def main(opt):
print(opt)
set_logging()
print(colorstr('export: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
export(**vars(opt))
......
......@@ -51,7 +51,6 @@ def test(data,
device = next(model.parameters()).device # get model device
else: # called directly
set_logging()
device = select_device(device, batch_size=batch_size)
# Directories
......@@ -323,7 +322,8 @@ def parse_opt():
def main(opt):
print(opt)
set_logging()
print(colorstr('test: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
check_requirements(exclude=('tensorboard', 'thop'))
if opt.task in ('train', 'val', 'test'): # run normally
......
差异被折叠。
......@@ -64,7 +64,7 @@ def exif_size(img):
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
rect=False, rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
with torch_distributed_zero_first(rank):
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
......@@ -79,7 +79,7 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non
prefix=prefix)
batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, workers]) # number of workers
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
# Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
......
......@@ -13,6 +13,7 @@ from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision
......@@ -30,10 +31,10 @@ def torch_distributed_zero_first(local_rank: int):
Decorator to make all processes in distributed training wait for each local_master to do something.
"""
if local_rank not in [-1, 0]:
torch.distributed.barrier()
dist.barrier()
yield
if local_rank == 0:
torch.distributed.barrier()
dist.barrier()
def init_torch_seeds(seed=0):
......
"""Utilities and tools for tracking runs with Weights & Biases."""
import logging
import os
import sys
from contextlib import contextmanager
from pathlib import Path
......@@ -18,6 +19,7 @@ try:
except ImportError:
wandb = None
RANK = int(os.getenv('RANK', -1))
WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
......@@ -42,10 +44,10 @@ def get_run_info(run_path):
def check_wandb_resume(opt):
process_wandb_config_ddp_mode(opt) if opt.global_rank not in [-1, 0] else None
process_wandb_config_ddp_mode(opt) if RANK not in [-1, 0] else None
if isinstance(opt.resume, str):
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
if opt.global_rank not in [-1, 0]: # For resuming DDP runs
if RANK not in [-1, 0]: # For resuming DDP runs
entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
api = wandb.Api()
artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论