Commit c687d5c1 authored by Glenn Jocher

reorganize train initialization steps

Parent bc1fd13a
@@ -161,7 +161,7 @@ def train(hyp, opt, device, tb_writer=None):
 
     # DDP mode
     if cuda and rank != -1:
-        model = DDP(model, device_ids=[opt.local_rank], output_device=(opt.local_rank))
+        model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
 
     # Trainloader
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
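Aside: the only change in this hunk drops the redundant parentheses around output_device; both arguments name the single GPU owned by this process. A minimal runnable sketch of the same wrapping pattern, assuming a torchrun-style launch that sets RANK, WORLD_SIZE, MASTER_ADDR/PORT and LOCAL_RANK (wrap_model is a hypothetical helper, not part of train.py):

# Minimal sketch of the DDP wrapping pattern above, one process per GPU.
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_model(model, local_rank):
    torch.cuda.set_device(local_rank)          # bind this process to its GPU
    model = model.cuda(local_rank)
    # device_ids and output_device both name the one GPU this process owns;
    # output_device takes a plain int, so (opt.local_rank) vs opt.local_rank
    # is purely cosmetic -- hence the one-character change in this commit.
    return DDP(model, device_ids=[local_rank], output_device=local_rank)


if __name__ == '__main__':
    dist.init_process_group(backend='nccl')    # env:// rendezvous from torchrun
    ddp = wrap_model(torch.nn.Linear(10, 2), int(os.environ['LOCAL_RANK']))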
@@ -171,23 +171,14 @@ def train(hyp, opt, device, tb_writer=None):
     nb = len(dataloader)  # number of batches
     assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
 
-    # Testloader
+    # Process 0
     if rank in [-1, 0]:
         ema.updates = start_epoch * nb // accumulate  # set EMA updates
         testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt,
                                        hyp=hyp, augment=False, cache=opt.cache_images, rect=True, rank=-1,
-                                       world_size=opt.world_size, workers=opt.workers)[0]  # only runs on process 0
+                                       world_size=opt.world_size, workers=opt.workers)[0]  # testloader
 
-    # Model parameters
-    hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
-    model.nc = nc  # attach number of classes to model
-    model.hyp = hyp  # attach hyperparameters to model
-    model.gr = 1.0  # giou loss ratio (obj_loss = 1.0 or giou)
-    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
-    model.names = names
-
-    # Classes and Anchors
-    if rank in [-1, 0] and not opt.resume:
-        labels = np.concatenate(dataset.labels, 0)
-        c = torch.tensor(labels[:, 0])  # classes
-        # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
+        if not opt.resume:
+            labels = np.concatenate(dataset.labels, 0)
+            c = torch.tensor(labels[:, 0])  # classes
+            # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
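Note on `ema.updates = start_epoch * nb // accumulate`: when resuming, the EMA update counter is re-derived from the number of optimizer steps already taken. A worked sketch, assuming one EMA update per `accumulate` batches; the 2000-step decay ramp below is my reading of ModelEMA in this repo's utils/torch_utils.py, so treat the exact constant as an assumption:

# Worked sketch of the resume arithmetic behind ema.updates.
import math

start_epoch, nb, accumulate = 10, 500, 4      # hypothetical resume state
ema_updates = start_epoch * nb // accumulate  # optimizer steps taken so far
# ModelEMA ramps its decay with the update count, roughly
# d = decay * (1 - exp(-updates / 2000)), so restoring the counter keeps the
# decay schedule consistent across a resumed run.
d = 0.9999 * (1 - math.exp(-ema_updates / 2000))
print(ema_updates, d)  # 1250 updates -> decay ~0.4647, still ramping toward 0.9999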
@@ -201,6 +192,14 @@ def train(hyp, opt, device, tb_writer=None):
             if not opt.noautoanchor:
                 check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
 
+    # Model parameters
+    hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
+    model.nc = nc  # attach number of classes to model
+    model.hyp = hyp  # attach hyperparameters to model
+    model.gr = 1.0  # giou loss ratio (obj_loss = 1.0 or giou)
+    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
+    model.names = names
+
     # Start training
     t0 = time.time()
     nw = max(3 * nb, 1e3)  # number of warmup iterations, max(3 epochs, 1k iterations)
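The relocated "# Model parameters" block attaches dataset-dependent state to the model: the cls loss gain is rescaled from its COCO (80-class) tuning to the current class count, and labels_to_class_weights(...) supplies per-class weights. An illustrative sketch of what such a helper plausibly computes (inverse class-frequency weights; class_weights_from_labels is hypothetical, not the repo's exact function):

# Illustrative inverse class-frequency weights, normalized to sum to 1.
import numpy as np
import torch


def class_weights_from_labels(labels, nc):
    # labels: (n, 5) array with rows [class, x, y, w, h], one row per object
    counts = np.bincount(labels[:, 0].astype(int), minlength=nc).astype(float)
    counts[counts == 0] = 1                   # guard empty classes
    weights = 1.0 / counts                    # rarer class -> larger weight
    return torch.from_numpy(weights / weights.sum())


labels = np.array([[0, .5, .5, .1, .1],
                   [0, .2, .2, .1, .1],
                   [2, .7, .7, .2, .2]])
print(class_weights_from_labels(labels, nc=3))  # tensor([0.2, 0.4, 0.4])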
@@ -209,10 +208,8 @@ def train(hyp, opt, device, tb_writer=None):
     results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
     scheduler.last_epoch = start_epoch - 1  # do not move
     scaler = amp.GradScaler(enabled=cuda)
-    logger.info('Image sizes %g train, %g test' % (imgsz, imgsz_test))
-    logger.info('Using %g dataloader workers' % dataloader.num_workers)
-    logger.info('Starting training for %g epochs...' % epochs)
-    # torch.autograd.set_detect_anomaly(True)
+    logger.info('Image sizes %g train, %g test\nUsing %g dataloader workers\nLogging results to %s\n'
+                'Starting training for %g epochs...' % (imgsz, imgsz_test, dataloader.num_workers, log_dir, epochs))
 
     for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
         model.train()
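For reference, the scaler = amp.GradScaler(enabled=cuda) created in this hunk feeds the standard torch.cuda.amp recipe used later in the training loop. A self-contained sketch of that pattern (simplified; the real loop adds warmup, gradient accumulation, EMA, etc.):

# Standard mixed-precision step with GradScaler.
import torch
from torch.cuda import amp

cuda = torch.cuda.is_available()
device = torch.device('cuda' if cuda else 'cpu')
model = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = amp.GradScaler(enabled=cuda)        # passthrough no-op on CPU

x, y = torch.randn(8, 10, device=device), torch.randn(8, 1, device=device)
with amp.autocast(enabled=cuda):             # forward pass in mixed precision
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()                # scale loss to avoid fp16 underflow
scaler.step(optimizer)                       # unscales grads, then optimizer.step()
scaler.update()                              # adjust scale for the next iteration
optimizer.zero_grad()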