提交 25e51bce authored 作者: Alex Stoken's avatar Alex Stoken

add util function to get most recent last.pt file

added logic in train.py __main__ to handle resuming from a run
上级 490f1e7b
...@@ -198,10 +198,10 @@ def train(hyp): ...@@ -198,10 +198,10 @@ def train(hyp):
model.names = data_dict['names'] model.names = data_dict['names']
#save hyperparamter and training options in run folder #save hyperparamter and training options in run folder
with open(os.path.join(log_dir, 'hyp.yaml', 'w')) as f: with open(os.path.join(log_dir, 'hyp.yaml'), 'w') as f:
yaml.dump(hyp, f) yaml.dump(hyp, f)
with open(os.path.join(log_dir, 'opt.yaml', 'w')) as f: with open(os.path.join(log_dir, 'opt.yaml'), 'w') as f:
yaml.dump(opt, f) yaml.dump(opt, f)
# Class frequency # Class frequency
...@@ -294,7 +294,7 @@ def train(hyp): ...@@ -294,7 +294,7 @@ def train(hyp):
# Plot # Plot
if ni < 3: if ni < 3:
f = 'train_batch%g.jpg' % i # filename f = os.path.join(log_dir, 'train_batch%g.jpg' % i) # filename
res = plot_images(images=imgs, targets=targets, paths=paths, fname=f) res = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
if tb_writer: if tb_writer:
tb_writer.add_image(f, res, dataformats='HWC', global_step=epoch) tb_writer.add_image(f, res, dataformats='HWC', global_step=epoch)
...@@ -385,6 +385,7 @@ if __name__ == '__main__': ...@@ -385,6 +385,7 @@ if __name__ == '__main__':
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes') parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
parser.add_argument('--rect', action='store_true', help='rectangular training') parser.add_argument('--rect', action='store_true', help='rectangular training')
parser.add_argument('--resume', action='store_true', help='resume training from last.pt') parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
parser.add_argument('--resume_from_run', type=str, default='', 'resume training from last.pt in this dir')
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint') parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
parser.add_argument('--notest', action='store_true', help='only test final epoch') parser.add_argument('--notest', action='store_true', help='only test final epoch')
parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters') parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
...@@ -398,6 +399,12 @@ if __name__ == '__main__': ...@@ -398,6 +399,12 @@ if __name__ == '__main__':
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset') parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
parser.add_argument('--hyp', type=str, default='', help ='path to hyp yaml file') parser.add_argument('--hyp', type=str, default='', help ='path to hyp yaml file')
opt = parser.parse_args() opt = parser.parse_args()
if opt.resume and not opt.resume_from_run:
last = get_latest_run()
print(f'WARNING: No run provided to resume from. Resuming from most recent run found at {last}')
else:
last = opt.resume_from_run
opt.weights = last if opt.resume else opt.weights opt.weights = last if opt.resume else opt.weights
opt.cfg = check_file(opt.cfg) # check file opt.cfg = check_file(opt.cfg) # check file
opt.data = check_file(opt.data) # check file opt.data = check_file(opt.data) # check file
......
...@@ -36,6 +36,12 @@ def init_seeds(seed=0): ...@@ -36,6 +36,12 @@ def init_seeds(seed=0):
np.random.seed(seed) np.random.seed(seed)
torch_utils.init_seeds(seed=seed) torch_utils.init_seeds(seed=seed)
def get_latest_run(search_dir = './runs/'):
# get path to most recent 'last.pt' in run dirs
# assumes most recently saved 'last.pt' is the desired weights to --resume from
last_list = glob.glob('runs/*/last.pt')
latest = max(last_list, key = os.path.getctime)
return latest
def check_git_status(): def check_git_status():
# Suggest 'git pull' if repo is out of date # Suggest 'git pull' if repo is out of date
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论