Unverified 提交 eb359c3a authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Add PyTorch Hub classification CI checks (#9027)

* Add PyTorch Hub classification CI checks Add PyTorch Hub loading of official and custom trained classification models to CI checks. May help resolve https://github.com/ultralytics/yolov5/issues/8790#issuecomment-1219840718Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> * Update hubconf.py Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 a5a47c52
...@@ -133,3 +133,8 @@ jobs: ...@@ -133,3 +133,8 @@ jobs:
python classify/predict.py --imgsz 32 --weights $b --source ../datasets/mnist2560/test/7/60.png # predict python classify/predict.py --imgsz 32 --weights $b --source ../datasets/mnist2560/test/7/60.png # predict
python classify/predict.py --imgsz 32 --weights $m --source data/images/bus.jpg # predict python classify/predict.py --imgsz 32 --weights $m --source data/images/bus.jpg # predict
python export.py --weights $b --img 64 --imgsz 224 --include torchscript # export python export.py --weights $b --img 64 --imgsz 224 --include torchscript # export
python - <<EOF
import torch
for path in '$m', '$b':
model = torch.hub.load('.', 'custom', path=path, source='local')
EOF
...@@ -30,7 +30,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo ...@@ -30,7 +30,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
from models.common import AutoShape, DetectMultiBackend from models.common import AutoShape, DetectMultiBackend
from models.experimental import attempt_load from models.experimental import attempt_load
from models.yolo import Model from models.yolo import DetectionModel
from utils.downloads import attempt_download from utils.downloads import attempt_download
from utils.general import LOGGER, check_requirements, intersect_dicts, logging from utils.general import LOGGER, check_requirements, intersect_dicts, logging
from utils.torch_utils import select_device from utils.torch_utils import select_device
...@@ -45,13 +45,13 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo ...@@ -45,13 +45,13 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
if pretrained and channels == 3 and classes == 80: if pretrained and channels == 3 and classes == 80:
try: try:
model = DetectMultiBackend(path, device=device, fuse=autoshape) # detection model model = DetectMultiBackend(path, device=device, fuse=autoshape) # detection model
if autoshape: if autoshape and isinstance(model.model, DetectionModel):
model = AutoShape(model) # for file/URI/PIL/cv2/np inputs and NMS model = AutoShape(model) # for file/URI/PIL/cv2/np inputs and NMS
except Exception: except Exception:
model = attempt_load(path, device=device, fuse=False) # arbitrary model model = attempt_load(path, device=device, fuse=False) # arbitrary model
else: else:
cfg = list((Path(__file__).parent / 'models').rglob(f'{path.stem}.yaml'))[0] # model.yaml path cfg = list((Path(__file__).parent / 'models').rglob(f'{path.stem}.yaml'))[0] # model.yaml path
model = Model(cfg, channels, classes) # create model model = DetectionModel(cfg, channels, classes) # create model
if pretrained: if pretrained:
ckpt = torch.load(attempt_download(path), map_location=device) # load ckpt = torch.load(attempt_download(path), map_location=device) # load
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论