Unverified commit 6d6e2ca6, authored by Glenn Jocher, committed by GitHub

Update train.py (#3667)

Parent: ac348345
@@ -22,7 +22,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
-import test  # import test.py to get mAP after each epoch
+import test  # for end-of-epoch mAP
 from models.experimental import attempt_load
 from models.yolo import Model
 from utils.autoanchor import check_anchors
@@ -39,7 +39,11 @@ from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
 logger = logging.getLogger(__name__)
 
 
-def train(hyp, opt, device, tb_writer=None):
+def train(hyp,
+          opt,
+          device,
+          tb_writer=None
+          ):
     logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
     save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
         Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
@@ -341,7 +345,7 @@ def train(hyp, opt, device, tb_writer=None):
                                                   save_dir.glob('train*.jpg') if x.exists()]})
 
             # end batch ------------------------------------------------------------------------------------------------
 
         # Scheduler
         lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
         scheduler.step()
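
In the unchanged context above, the learning rates are read from optimizer.param_groups before scheduler.step() is called, so the value logged to TensorBoard is the one the just-finished epoch actually trained with. A minimal standalone sketch of that ordering (illustrative only, not part of this commit):

    import torch

    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

    for epoch in range(3):
        # ... per-batch forward/backward/optimizer.step() would run here ...
        lr = [g['lr'] for g in optimizer.param_groups]  # lr used during this epoch
        scheduler.step()                                # decay lr for the next epoch
        print(epoch, lr)                                # 0 [0.1], 1 [0.05], 2 [0.025]
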
@@ -404,12 +408,11 @@ def train(hyp, opt, device, tb_writer=None):
                     torch.save(ckpt, best)
                 if wandb_logger.wandb:
                     if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
-                        wandb_logger.log_model(
-                            last.parent, opt, epoch, fi, best_model=best_fitness == fi)
+                        wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
                 del ckpt
 
         # end epoch ----------------------------------------------------------------------------------------------------
-    # end training
+    # end training -----------------------------------------------------------------------------------------------------
     if rank in [-1, 0]:
         logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
         if plots:
...
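
The save-period gate in the last hunk logs a checkpoint to W&B every opt.save_period epochs, skips the final epoch, and is disabled entirely when save_period is -1. A standalone sketch of that condition (hypothetical helper name, not part of the commit):

    def should_log_model(epoch: int, final_epoch: bool, save_period: int) -> bool:
        # Log every `save_period` epochs, but never on the final epoch and
        # never when periodic logging is disabled (save_period == -1).
        return ((epoch + 1) % save_period == 0 and not final_epoch) and save_period != -1

    epochs = 10
    print([e for e in range(epochs) if should_log_model(e, e == epochs - 1, save_period=3)])
    # -> [2, 5, 8]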