Unverified 提交 ba6f3f97 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Enable direct `--weights URL` definition (#3373)

* Enable direct `--weights URL` definition @KalenMike this PR will enable direct --weights URL definition. Example use case: ``` python train.py --weights https://storage.googleapis.com/bucket/dir/model.pt ``` * cleanup * bug fixes * weights = attempt_download(weights) * Update experimental.py * Update hubconf.py * return bug fix * comment mirror * min_bytes
上级 b78e30dd
......@@ -41,8 +41,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
model = Model(cfg, channels, classes) # create model
if pretrained:
attempt_download(fname) # download if not found locally
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
ckpt = torch.load(attempt_download(fname), map_location=torch.device('cpu')) # load
msd = model.state_dict() # model state_dict
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
......
......@@ -116,8 +116,7 @@ def attempt_load(weights, map_location=None, inplace=True):
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
attempt_download(w)
ckpt = torch.load(w, map_location=map_location) # load
ckpt = torch.load(attempt_download(w), map_location=map_location) # load
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
# Compatibility updates
......
......@@ -83,7 +83,7 @@ def train(hyp, opt, device, tb_writer=None):
pretrained = weights.endswith('.pt')
if pretrained:
with torch_distributed_zero_first(rank):
attempt_download(weights) # download if not found locally
weights = attempt_download(weights) # download if not found locally
ckpt = torch.load(weights, map_location=device) # load checkpoint
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys
......
......@@ -16,11 +16,37 @@ def gsutil_getsize(url=''):
return eval(s.split(' ')[0]) if len(s) else 0 # bytes
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
file = Path(file)
try: # GitHub
print(f'Downloading {url} to {file}...')
torch.hub.download_url_to_file(url, str(file))
assert file.exists() and file.stat().st_size > min_bytes # check
except Exception as e: # GCP
file.unlink(missing_ok=True) # remove partial downloads
print(f'Download error: {e}\nRe-attempting {url2 or url} to {file}...')
os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
finally:
if not file.exists() or file.stat().st_size < min_bytes: # check
file.unlink(missing_ok=True) # remove partial downloads
print(f'ERROR: Download failure: {error_msg or url}')
print('')
def attempt_download(file, repo='ultralytics/yolov5'):
# Attempt file download if does not exist
file = Path(str(file).strip().replace("'", ''))
if not file.exists():
# URL specified
name = file.name
if str(file).startswith(('http:/', 'https:/')): # download
url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
safe_download(file=name, url=url, min_bytes=1E5)
return name
# GitHub assets
file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
try:
response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json() # github api
......@@ -34,27 +60,14 @@ def attempt_download(file, repo='ultralytics/yolov5'):
except:
tag = 'v5.0' # current release
name = file.name
if name in assets:
msg = f'{file} missing, try downloading from https://github.com/{repo}/releases/'
redundant = False # second download option
try: # GitHub
url = f'https://github.com/{repo}/releases/download/{tag}/{name}'
print(f'Downloading {url} to {file}...')
torch.hub.download_url_to_file(url, file)
assert file.exists() and file.stat().st_size > 1E6 # check
except Exception as e: # GCP
print(f'Download error: {e}')
assert redundant, 'No secondary mirror'
url = f'https://storage.googleapis.com/{repo}/ckpt/{name}'
print(f'Downloading {url} to {file}...')
os.system(f"curl -L '{url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
finally:
if not file.exists() or file.stat().st_size < 1E6: # check
file.unlink(missing_ok=True) # remove partial downloads
print(f'ERROR: Download failure: {msg}')
print('')
return
safe_download(file,
url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
# url2=f'https://storage.googleapis.com/{repo}/ckpt/{name}', # backup url (optional)
min_bytes=1E5,
error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/')
return str(file)
def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论