Unverified 提交 6a3ee7cf authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Hub models `map_location=device` (#3894)

* Hub models `map_location=device` * cleanup
上级 8930e22c
...@@ -36,13 +36,15 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo ...@@ -36,13 +36,15 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
fname = Path(name).with_suffix('.pt') # checkpoint filename fname = Path(name).with_suffix('.pt') # checkpoint filename
try: try:
device = select_device(('0' if torch.cuda.is_available() else 'cpu') if device is None else device)
if pretrained and channels == 3 and classes == 80: if pretrained and channels == 3 and classes == 80:
model = attempt_load(fname, map_location=torch.device('cpu')) # download/load FP32 model model = attempt_load(fname, map_location=device) # download/load FP32 model
else: else:
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
model = Model(cfg, channels, classes) # create model model = Model(cfg, channels, classes) # create model
if pretrained: if pretrained:
ckpt = torch.load(attempt_download(fname), map_location=torch.device('cpu')) # load ckpt = torch.load(attempt_download(fname), map_location=device) # load
msd = model.state_dict() # model state_dict msd = model.state_dict() # model state_dict
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
...@@ -51,7 +53,6 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo ...@@ -51,7 +53,6 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
model.names = ckpt['model'].names # set class names attribute model.names = ckpt['model'].names # set class names attribute
if autoshape: if autoshape:
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
device = select_device('0' if torch.cuda.is_available() else 'cpu') if device is None else torch.device(device)
return model.to(device) return model.to(device)
except Exception as e: except Exception as e:
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
import datetime import datetime
import logging import logging
import math
import os import os
import platform import platform
import subprocess import subprocess
...@@ -11,6 +10,7 @@ from contextlib import contextmanager ...@@ -11,6 +10,7 @@ from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
import math
import torch import torch
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
import torch.distributed as dist import torch.distributed as dist
...@@ -64,7 +64,8 @@ def git_describe(path=Path(__file__).parent): # path must be a directory ...@@ -64,7 +64,8 @@ def git_describe(path=Path(__file__).parent): # path must be a directory
def select_device(device='', batch_size=None): def select_device(device='', batch_size=None):
# device = 'cpu' or '0' or '0,1,2,3' # device = 'cpu' or '0' or '0,1,2,3'
s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string
cpu = device.lower() == 'cpu' device = str(device).strip().lower().replace('cuda:', '') # to string, 'cuda:0' to '0'
cpu = device == 'cpu'
if cpu: if cpu:
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
elif device: # non-cpu device requested elif device: # non-cpu device requested
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论