Unverified commit a75a1105, authored by Glenn Jocher, committed by GitHub

Self-contained checkpoint `--resume` (#8839)

* Single checkpoint resume
* Update train.py
* Add hyp
* Add hyp
* Add hyp
* FIX
* avoid resume on url data
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* avoid resume on url data
* avoid resume on url data
* Update

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent 4d8d84b0
train.py
@@ -43,7 +43,7 @@ from utils.autoanchor import check_anchors
 from utils.autobatch import check_train_batch_size
 from utils.callbacks import Callbacks
 from utils.dataloaders import create_dataloader
-from utils.downloads import attempt_download
+from utils.downloads import attempt_download, is_url
 from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size,
                            check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path,
                            init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
@@ -77,6 +77,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
     with open(hyp, errors='ignore') as f:
         hyp = yaml.safe_load(f)  # load hyps dict
     LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
+    opt.hyp = hyp.copy()  # for saving hyps to checkpoints

     # Save run settings
     if not evolve:
@@ -377,6 +378,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
                 'updates': ema.updates,
                 'optimizer': optimizer.state_dict(),
                 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,
+                'opt': vars(opt),
                 'date': datetime.now().isoformat()}

             # Save last, best and delete
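Together with opt.hyp = hyp.copy() above, storing 'opt': vars(opt) means each checkpoint carries its full training configuration. A minimal sketch of inspecting one (the path and printed keys are illustrative of a standard run):

import torch

# Load a training checkpoint on CPU; the 'opt' entry was added by this commit
ckpt = torch.load('runs/train/exp/weights/last.pt', map_location='cpu')  # illustrative path
opt = ckpt['opt']   # vars(opt): all argparse options as a plain dict
hyp = opt['hyp']    # hyperparameters dict, embedded via opt.hyp = hyp.copy()
print(opt['epochs'], opt['imgsz'], hyp['lr0'])  # e.g. total epochs, train image size, initial LR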
@@ -472,8 +474,7 @@ def parse_opt(known=False):
     parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
     parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')
-    opt = parser.parse_known_args()[0] if known else parser.parse_args()
-    return opt
+    return parser.parse_known_args()[0] if known else parser.parse_args()


 def main(opt, callbacks=Callbacks()):
@@ -484,12 +485,20 @@ def main(opt, callbacks=Callbacks()):
     check_requirements(exclude=['thop'])

     # Resume
-    if opt.resume and not check_wandb_resume(opt) and not opt.evolve:  # resume an interrupted run
-        ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
-        assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
-        with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
-            opt = argparse.Namespace(**yaml.safe_load(f))  # replace
-        opt.cfg, opt.weights, opt.resume = '', ckpt, True  # reinstate
+    if opt.resume and not (check_wandb_resume(opt) or opt.evolve):  # resume an interrupted run
+        last = Path(opt.resume if isinstance(opt.resume, str) else get_latest_run())  # specified or most recent last.pt
+        assert last.is_file(), f'ERROR: --resume checkpoint {last} does not exist'
+        opt_yaml = last.parent.parent / 'opt.yaml'  # train options yaml
+        opt_data = opt.data  # original dataset
+        if opt_yaml.is_file():
+            with open(opt_yaml, errors='ignore') as f:
+                d = yaml.safe_load(f)
+        else:
+            d = torch.load(last, map_location='cpu')['opt']
+        opt = argparse.Namespace(**d)  # replace
+        opt.cfg, opt.weights, opt.resume = '', str(last), True  # reinstate
+        if is_url(opt.data):
+            opt.data = str(opt_data)  # avoid HUB resume auth timeout
     else:
         opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
             check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project)  # checks
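The practical effect: opt.yaml is still preferred when the run directory is intact, but a bare checkpoint now resumes cleanly because the options travel inside it. A quick way to verify, assuming a previous run exists at runs/train/exp and the working directory is the repo root (paths illustrative):

import shutil
import subprocess

# Copy the checkpoint out of its run directory: /tmp has no ../../opt.yaml,
# so train.py must fall back to the options embedded in the checkpoint.
shutil.copy('runs/train/exp/weights/last.pt', '/tmp/last.pt')
subprocess.run(['python', 'train.py', '--resume', '/tmp/last.pt'], check=True)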
......
utils/downloads.py
@@ -16,12 +16,14 @@ import requests
 import torch


-def is_url(url):
+def is_url(url, check_online=True):
     # Check if online file exists
     try:
-        r = urllib.request.urlopen(url)  # response
-        return r.getcode() == 200
-    except urllib.request.HTTPError:
+        url = str(url)
+        result = urllib.parse.urlparse(url)
+        assert all([result.scheme, result.netloc, result.path])  # check if is url
+        return (urllib.request.urlopen(url).getcode() == 200) if check_online else True  # check if exists online
+    except (AssertionError, urllib.request.HTTPError):
         return False
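is_url() now validates the URL's structure with urllib.parse before touching the network, and check_online=False skips the request entirely. A few illustrative calls against the new behavior:

from utils.downloads import is_url

is_url('https://ultralytics.com/images/zidane.jpg')   # True if well-formed and it answers HTTP 200
is_url('https://ultralytics.com/images/zidane.jpg', check_online=False)  # True: syntax check only, no request
is_url('https://ultralytics.com', check_online=False)  # False: empty path fails the all([...]) assert
is_url('runs/train/exp/weights/last.pt')               # False: no scheme or netloc, so not a URL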
......
utils/torch_utils.py
@@ -317,8 +317,9 @@ def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, resume=True):
         ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
         ema.updates = ckpt['updates']
     if resume:
-        assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
-        LOGGER.info(f'Resuming training from {weights} for {epochs - start_epoch} more epochs to {epochs} total epochs')
+        assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.\n' \
+                                f"Start a new training without --resume, i.e. 'python train.py --weights {weights}'"
+        LOGGER.info(f'Resuming training from {weights} from epoch {start_epoch} to {epochs} total epochs')
     if epochs < start_epoch:
         LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
         epochs += ckpt['epoch']  # finetune additional epochs
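For context on the assert and the fine-tune branch: start_epoch is ckpt['epoch'] + 1, and a checkpoint from a completed run (stripped by strip_optimizer) stores epoch = -1, so resuming it gives start_epoch == 0 and trips the new, more helpful message. A worked sketch of the arithmetic with illustrative values:

ckpt = {'epoch': 100}             # checkpoint saved after completing epoch 100
epochs = 10                       # user runs with --weights last.pt --epochs 10 (no --resume)
start_epoch = ckpt['epoch'] + 1   # 101

if epochs < start_epoch:          # 10 < 101: interpret --epochs as additional epochs
    epochs += ckpt['epoch']       # 10 + 100 = 110 total; training continues from epoch 101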
......