Unverified 提交 c67e7220 authored 作者: Jirka Borovec's avatar Jirka Borovec 提交者: GitHub

fix compatibility for hyper config (#1146)

* fix/hyper * Hyp giou check to train.py * restore general.py * train.py overwrite fix * restore general.py and pep8 update Co-authored-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 4d3680c8
...@@ -5,6 +5,7 @@ import random ...@@ -5,6 +5,7 @@ import random
import shutil import shutil
import time import time
from pathlib import Path from pathlib import Path
from warnings import warn
import math import math
import numpy as np import numpy as np
...@@ -430,9 +431,8 @@ if __name__ == '__main__': ...@@ -430,9 +431,8 @@ if __name__ == '__main__':
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test) opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
log_dir = increment_dir(Path(opt.logdir) / 'exp', opt.name) # runs/exp1 log_dir = increment_dir(Path(opt.logdir) / 'exp', opt.name) # runs/exp1
device = select_device(opt.device, batch_size=opt.batch_size)
# DDP mode # DDP mode
device = select_device(opt.device, batch_size=opt.batch_size)
if opt.local_rank != -1: if opt.local_rank != -1:
assert torch.cuda.device_count() > opt.local_rank assert torch.cuda.device_count() > opt.local_rank
torch.cuda.set_device(opt.local_rank) torch.cuda.set_device(opt.local_rank)
...@@ -441,11 +441,16 @@ if __name__ == '__main__': ...@@ -441,11 +441,16 @@ if __name__ == '__main__':
assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count' assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
opt.batch_size = opt.total_batch_size // opt.world_size opt.batch_size = opt.total_batch_size // opt.world_size
logger.info(opt) # Hyperparameters
with open(opt.hyp) as f: with open(opt.hyp) as f:
hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
if 'box' not in hyp:
warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' %
(opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120'))
hyp['box'] = hyp.pop('giou')
# Train # Train
logger.info(opt)
if not opt.evolve: if not opt.evolve:
tb_writer = None tb_writer = None
if opt.global_rank in [-1, 0]: if opt.global_rank in [-1, 0]:
......
import glob import glob
import logging import logging
import math
import os import os
import platform import platform
import random import random
import re
import shutil import shutil
import subprocess import subprocess
import time import time
import re
from contextlib import contextmanager from contextlib import contextmanager
from copy import copy from copy import copy
from pathlib import Path from pathlib import Path
import cv2 import cv2
import math
import matplotlib import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论