Unverified commit fad27c00, authored by Glenn Jocher, committed by GitHub

Update DDP for `torch.distributed.run` with `gloo` backend (#3680)

* Update DDP for `torch.distributed.run`
* Add LOCAL_RANK
* remove opt.local_rank
* backend="gloo|nccl"
* print
* print
* debug
* debug
* os.getenv
* gloo
* gloo
* gloo
* cleanup
* fix getenv
* cleanup
* cleanup destroy
* try nccl
* return opt
* add --local_rank
* add timeout
* add init_method
* gloo
* move destroy
* move destroy
* move print(opt) under if RANK
* destroy only RANK 0
* move destroy inside train()
* restore destroy outside train()
* update print(opt)
* cleanup
* nccl
* gloo with 60 second timeout
* update namespace printing
Parent 5bab9a28
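With this change, multi-GPU training is launched through the `torch.distributed.run` module, which sets the LOCAL_RANK, RANK and WORLD_SIZE environment variables for each worker process; train.py now reads these with os.getenv() instead of receiving a --local_rank argument from torch.distributed.launch. A minimal launch sketch (the 2-GPU flags below are illustrative, not part of this commit):

    # assumed example: 2 GPUs on one machine
    python -m torch.distributed.run --nproc_per_node 2 train.py --batch-size 64 --data coco128.yaml --weights yolov5s.pt --device 0,1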
@@ -8,8 +8,8 @@ import torch.backends.cudnn as cudnn
 from models.experimental import attempt_load
 from utils.datasets import LoadStreams, LoadImages
-from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
-    scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
+from utils.general import check_img_size, check_requirements, check_imshow, colorstr, non_max_suppression, \
+    apply_classifier, scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
 from utils.plots import colors, plot_one_box
 from utils.torch_utils import select_device, load_classifier, time_synchronized
@@ -202,7 +202,7 @@ def parse_opt():
 def main(opt):
-    print(opt)
+    print(colorstr('detect: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
     check_requirements(exclude=('tensorboard', 'thop'))
     detect(**vars(opt))
......
@@ -163,8 +163,8 @@ def parse_opt():
 def main(opt):
-    print(opt)
     set_logging()
+    print(colorstr('export: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
     export(**vars(opt))
......
@@ -51,7 +51,6 @@ def test(data,
         device = next(model.parameters()).device  # get model device
     else:  # called directly
-        set_logging()
         device = select_device(device, batch_size=batch_size)

         # Directories
@@ -323,7 +322,8 @@ def parse_opt():
 def main(opt):
-    print(opt)
+    set_logging()
+    print(colorstr('test: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
     check_requirements(exclude=('tensorboard', 'thop'))

     if opt.task in ('train', 'val', 'test'):  # run normally
......
@@ -37,15 +37,17 @@ from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_di
 from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume

 logger = logging.getLogger(__name__)
+LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
+RANK = int(os.getenv('RANK', -1))
+WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))


 def train(hyp,  # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
           ):
-    save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
-        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
-        opt.single_cls
+    save_dir, epochs, batch_size, total_batch_size, weights, single_cls = \
+        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.single_cls

     # Directories
     wdir = save_dir / 'weights'
@@ -69,13 +71,13 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # Configure
     plots = not opt.evolve  # create plots
     cuda = device.type != 'cpu'
-    init_seeds(2 + rank)
+    init_seeds(2 + RANK)
     with open(opt.data) as f:
         data_dict = yaml.safe_load(f)  # data dict

     # Loggers
     loggers = {'wandb': None, 'tb': None}  # loggers dict
-    if rank in [-1, 0]:
+    if RANK in [-1, 0]:
         # TensorBoard
         if not opt.evolve:
             prefix = colorstr('tensorboard: ')
@@ -99,7 +101,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # Model
     pretrained = weights.endswith('.pt')
     if pretrained:
-        with torch_distributed_zero_first(rank):
+        with torch_distributed_zero_first(RANK):
             weights = attempt_download(weights)  # download if not found locally
         ckpt = torch.load(weights, map_location=device)  # load checkpoint
         model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
@@ -110,7 +112,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
     else:
         model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
-    with torch_distributed_zero_first(rank):
+    with torch_distributed_zero_first(RANK):
         check_dataset(data_dict)  # check
     train_path = data_dict['train']
     test_path = data_dict['val']
@@ -158,7 +160,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # plot_lr_scheduler(optimizer, scheduler, epochs)

     # EMA
-    ema = ModelEMA(model) if rank in [-1, 0] else None
+    ema = ModelEMA(model) if RANK in [-1, 0] else None

     # Resume
     start_epoch, best_fitness = 0, 0.0
@@ -194,28 +196,28 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size]  # verify imgsz are gs-multiples

     # DP mode
-    if cuda and rank == -1 and torch.cuda.device_count() > 1:
+    if cuda and RANK == -1 and torch.cuda.device_count() > 1:
         model = torch.nn.DataParallel(model)

     # SyncBatchNorm
-    if opt.sync_bn and cuda and rank != -1:
+    if opt.sync_bn and cuda and RANK != -1:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
         logger.info('Using SyncBatchNorm()')

     # Trainloader
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls,
-                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
-                                            world_size=opt.world_size, workers=opt.workers,
+                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
+                                            workers=opt.workers,
                                             image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
     nb = len(dataloader)  # number of batches
     assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)

     # Process 0
-    if rank in [-1, 0]:
+    if RANK in [-1, 0]:
         testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls,
                                        hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
-                                       world_size=opt.world_size, workers=opt.workers,
+                                       workers=opt.workers,
                                        pad=0.5, prefix=colorstr('val: '))[0]

         if not opt.resume:
@@ -234,8 +236,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             model.half().float()  # pre-reduce anchor precision

     # DDP mode
-    if cuda and rank != -1:
-        model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank,
+    if cuda and RANK != -1:
+        model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK,
                     # nn.MultiheadAttention incompatibility with DDP https://github.com/pytorch/pytorch/issues/26698
                     find_unused_parameters=any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules()))
@@ -269,15 +271,15 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         # Update image weights (optional)
         if opt.image_weights:
             # Generate indices
-            if rank in [-1, 0]:
+            if RANK in [-1, 0]:
                 cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
                 iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
                 dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx
             # Broadcast if DDP
-            if rank != -1:
-                indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
+            if RANK != -1:
+                indices = (torch.tensor(dataset.indices) if RANK == 0 else torch.zeros(dataset.n)).int()
                 dist.broadcast(indices, 0)
-                if rank != 0:
+                if RANK != 0:
                     dataset.indices = indices.cpu().numpy()

         # Update mosaic border
@@ -285,11 +287,11 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

         mloss = torch.zeros(4, device=device)  # mean losses
-        if rank != -1:
+        if RANK != -1:
             dataloader.sampler.set_epoch(epoch)
         pbar = enumerate(dataloader)
         logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
-        if rank in [-1, 0]:
+        if RANK in [-1, 0]:
             pbar = tqdm(pbar, total=nb)  # progress bar
         optimizer.zero_grad()
         for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
@@ -319,8 +321,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             with amp.autocast(enabled=cuda):
                 pred = model(imgs)  # forward
                 loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
-                if rank != -1:
-                    loss *= opt.world_size  # gradient averaged between devices in DDP mode
+                if RANK != -1:
+                    loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
                 if opt.quad:
                     loss *= 4.
@@ -336,7 +338,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                     ema.update(model)

             # Print
-            if rank in [-1, 0]:
+            if RANK in [-1, 0]:
                 mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                 mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
                 s = ('%10s' * 2 + '%10.4g' * 6) % (
@@ -362,7 +364,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         scheduler.step()

         # DDP process 0 or single-GPU
-        if rank in [-1, 0]:
+        if RANK in [-1, 0]:
             # mAP
             ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
@@ -424,7 +426,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training -----------------------------------------------------------------------------------------------------
-    if rank in [-1, 0]:
+    if RANK in [-1, 0]:
         logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
         if plots:
             plot_results(save_dir=save_dir)  # save as results.png
@@ -457,8 +459,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                                             name='run_' + wandb_logger.wandb_run.id + '_model',
                                             aliases=['latest', 'best', 'stripped'])
         wandb_logger.finish_run()
-    else:
-        dist.destroy_process_group()

     torch.cuda.empty_cache()
     return results
@@ -486,7 +487,6 @@ def parse_opt():
     parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
     parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
     parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
-    parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
     parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
     parser.add_argument('--project', default='runs/train', help='save to project/name')
     parser.add_argument('--entity', default=None, help='W&B entity')
@@ -499,18 +499,15 @@ def parse_opt():
     parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
     parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
     parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
+    parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
     opt = parser.parse_args()
-
-    # Set DDP variables
-    opt.world_size = int(getattr(os.environ, 'WORLD_SIZE', 1))
-    opt.global_rank = int(getattr(os.environ, 'RANK', -1))
     return opt


 def main(opt):
-    print(opt)
-    set_logging(opt.global_rank)
-    if opt.global_rank in [-1, 0]:
+    set_logging(RANK)
+    if RANK in [-1, 0]:
+        print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
         check_git_status()
         check_requirements(exclude=['thop'])
@@ -519,11 +516,9 @@ def main(opt):
     if opt.resume and not wandb_run:  # resume an interrupted run
         ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
         assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
-        apriori = opt.global_rank, opt.local_rank
         with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
             opt = argparse.Namespace(**yaml.safe_load(f))  # replace
-        opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = \
-            '', ckpt, True, opt.total_batch_size, *apriori  # reinstate
+        opt.cfg, opt.weights, opt.resume, opt.batch_size = '', ckpt, True, opt.total_batch_size  # reinstate
         logger.info('Resuming training from %s' % ckpt)
     else:
         # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
@@ -536,19 +531,21 @@ def main(opt):
     # DDP mode
     opt.total_batch_size = opt.batch_size
     device = select_device(opt.device, batch_size=opt.batch_size)
-    if opt.local_rank != -1:
-        assert torch.cuda.device_count() > opt.local_rank
-        torch.cuda.set_device(opt.local_rank)
-        device = torch.device('cuda', opt.local_rank)
-        dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend
-        assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
+    if LOCAL_RANK != -1:
+        from datetime import timedelta
+        assert torch.cuda.device_count() > LOCAL_RANK, 'too few GPUS for DDP command'
+        torch.cuda.set_device(LOCAL_RANK)
+        device = torch.device('cuda', LOCAL_RANK)
+        dist.init_process_group(backend="gloo", timeout=timedelta(seconds=60))
+        assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
         assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
-        opt.batch_size = opt.total_batch_size // opt.world_size
+        opt.batch_size = opt.total_batch_size // WORLD_SIZE

     # Train
-    logger.info(opt)
     if not opt.evolve:
         train(opt.hyp, opt, device)
+        if WORLD_SIZE > 1 and RANK == 0:
+            _ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]

     # Evolve hyperparameters (optional)
     else:
@@ -584,7 +581,7 @@ def main(opt):
         with open(opt.hyp) as f:
             hyp = yaml.safe_load(f)  # load hyps dict
-        assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
+        assert LOCAL_RANK == -1, 'DDP mode not implemented for --evolve'
         opt.notest, opt.nosave = True, True  # only test/save final epoch
         # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
         yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml'  # save best result here
......
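Putting the train.py hunks above together, the process-group lifecycle after this commit looks roughly like the sketch below. This is a simplified standalone outline, assuming one GPU per process and a launch via torch.distributed.run; it is not code from the repository:

    # hypothetical minimal outline of the new DDP setup/teardown
    import os
    from datetime import timedelta

    import torch
    import torch.distributed as dist

    LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # set per process by torch.distributed.run
    RANK = int(os.getenv('RANK', -1))
    WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))

    if LOCAL_RANK != -1:  # DDP mode
        torch.cuda.set_device(LOCAL_RANK)
        dist.init_process_group(backend="gloo", timeout=timedelta(seconds=60))  # env:// rendezvous from the launcher

    # ... build model, wrap with DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK), train ...

    if WORLD_SIZE > 1 and RANK == 0:
        dist.destroy_process_group()  # destroyed once, by the rank-0 process, after training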
@@ -64,7 +64,7 @@ def exif_size(img):

 def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
-                      rect=False, rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
+                      rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
     # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
     with torch_distributed_zero_first(rank):
         dataset = LoadImagesAndLabels(path, imgsz, batch_size,
@@ -79,7 +79,7 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non
                                       prefix=prefix)

     batch_size = min(batch_size, len(dataset))
-    nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
+    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, workers])  # number of workers
     sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
     loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
     # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
......
@@ -13,6 +13,7 @@ from pathlib import Path

 import torch
 import torch.backends.cudnn as cudnn
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision
@@ -30,10 +31,10 @@ def torch_distributed_zero_first(local_rank: int):
     Decorator to make all processes in distributed training wait for each local_master to do something.
     """
     if local_rank not in [-1, 0]:
-        torch.distributed.barrier()
+        dist.barrier()
     yield
     if local_rank == 0:
-        torch.distributed.barrier()
+        dist.barrier()


 def init_torch_seeds(seed=0):
......
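For context on the torch_utils.py hunk above: torch_distributed_zero_first is a context manager that lets rank 0 (or rank -1, i.e. no DDP) do some expensive work first, such as downloading weights or building a dataset cache, while every other rank waits at a barrier, and then releases them once rank 0 is done. A usage sketch; the prepare_cache() helper is hypothetical:

    # ranks other than -1/0 block at the first barrier; rank 0 runs the body,
    # then hits the second barrier, releasing everyone with the cache already built
    with torch_distributed_zero_first(RANK):
        dataset = prepare_cache()  # hypothetical cache-once operation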
"""Utilities and tools for tracking runs with Weights & Biases.""" """Utilities and tools for tracking runs with Weights & Biases."""
import logging import logging
import os
import sys import sys
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
@@ -18,6 +19,7 @@ try:
 except ImportError:
     wandb = None

+RANK = int(os.getenv('RANK', -1))
 WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
@@ -42,10 +44,10 @@ def get_run_info(run_path):

 def check_wandb_resume(opt):
-    process_wandb_config_ddp_mode(opt) if opt.global_rank not in [-1, 0] else None
+    process_wandb_config_ddp_mode(opt) if RANK not in [-1, 0] else None
     if isinstance(opt.resume, str):
         if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
-            if opt.global_rank not in [-1, 0]:  # For resuming DDP runs
+            if RANK not in [-1, 0]:  # For resuming DDP runs
                 entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
                 api = wandb.Api()
                 artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest')
......