Unverified 提交 2e109099 authored 作者: Jackson Argo's avatar Jackson Argo 提交者: GitHub

Fix missing attr model.model when loading custom yolov model (#8830)

* Update hubconf.py Loading a custom yolov model causes this line to fail. Adding a test to check if the model actually has a model.model field. With this check, I'm able to load the model no prob. Loading model via ```py model = torch.hub.load( 'ultralytics/yolov5', 'custom', 'models/frozen_backbone_coco_unlabeled_best.onnx', autoshape=True, force_reload=False ) ``` Causes traceback: ``` Traceback (most recent call last): File "/Users/jackson/Documents/GitHub/w210-capstone/.venv/lib/python3.10/site-packages/flask/app.py", line 2077, in wsgi_app response = self.full_dispatch_request() File "/Users/jackson/Documents/GitHub/w210-capstone/.venv/lib/python3.10/site-packages/flask/app.py", line 1525, in full_dispatch_request rv = self.handle_user_exception(e) File "/Users/jackson/Documents/GitHub/w210-capstone/.venv/lib/python3.10/site-packages/flask/app.py", line 1523, in full_dispatch_request rv = self.dispatch_request() File "/Users/jackson/Documents/GitHub/w210-capstone/.venv/lib/python3.10/site-packages/flask/app.py", line 1509, in dispatch_request return self.ensure_sync(self.view_functions[rule.endpoint])(**req.view_args) File "/Users/jackson/Documents/GitHub/w210-capstone/api/endpoints/predictions.py", line 26, in post_predictions yolov_predictions = predict_bounding_boxes_for_collection(collection_id) File "/Users/jackson/Documents/GitHub/w210-capstone/api/predictions/predict_bounding_boxes.py", line 43, in predict_bounding_boxes_for_collection model = torch.hub.load( File "/Users/jackson/Documents/GitHub/w210-capstone/.venv/lib/python3.10/site-packages/torch/hub.py", line 404, in load model = _load_local(repo_or_dir, model, *args, **kwargs) File "/Users/jackson/Documents/GitHub/w210-capstone/.venv/lib/python3.10/site-packages/torch/hub.py", line 433, in _load_local model = entry(*args, **kwargs) File "/Users/jackson/.cache/torch/hub/ultralytics_yolov5_master/hubconf.py", line 72, in custom return _create(path, autoshape=autoshape, verbose=_verbose, device=device) File "/Users/jackson/.cache/torch/hub/ultralytics_yolov5_master/hubconf.py", line 67, in _create raise Exception(s) from e Exception: 'DetectMultiBackend' object has no attribute 'model'. Cache may be out of date, try `force_reload=True` or see https://github.com/ultralytics/yolov5/issues/36 for help. Exception on /api/v1/predictions [POST] Traceback (most recent call last): File "/Users/jackson/.cache/torch/hub/ultralytics_yolov5_master/hubconf.py", line 58, in _create model.model.model[-1].inplace = False # Detect.inplace=False for safe multithread inference File "/Users/jackson/Documents/GitHub/w210-capstone/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1185, in __getattr__ raise AttributeError("'{}' object has no attribute '{}'".format( AttributeError: 'DetectMultiBackend' object has no attribute 'model' ``` * Update hubconf.py * Update common.py Co-authored-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 b7635efb
......@@ -29,6 +29,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
from pathlib import Path
from models.common import AutoShape, DetectMultiBackend
from models.experimental import attempt_load
from models.yolo import Model
from utils.downloads import attempt_download
from utils.general import LOGGER, check_requirements, intersect_dicts, logging
......@@ -42,8 +43,12 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
try:
device = select_device(device)
if pretrained and channels == 3 and classes == 80:
model = DetectMultiBackend(path, device=device, fuse=autoshape) # download/load FP32 model
# model = models.experimental.attempt_load(path, map_location=device) # download/load FP32 model
try:
model = DetectMultiBackend(path, device=device, fuse=autoshape) # detection model
if autoshape:
model = AutoShape(model) # for file/URI/PIL/cv2/np inputs and NMS
except Exception:
model = attempt_load(path, device=device, fuse=False) # arbitrary model
else:
cfg = list((Path(__file__).parent / 'models').rglob(f'{path.stem}.yaml'))[0] # model.yaml path
model = Model(cfg, channels, classes) # create model
......@@ -54,9 +59,6 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
model.load_state_dict(csd, strict=False) # load
if len(ckpt['model'].names) == classes:
model.names = ckpt['model'].names # set class names attribute
if autoshape:
model.model.model[-1].inplace = False # Detect.inplace=False for safe multithread inference
model = AutoShape(model) # for file/URI/PIL/cv2/np inputs and NMS
if not verbose:
LOGGER.setLevel(logging.INFO) # reset to default
return model.to(device)
......
......@@ -562,6 +562,9 @@ class AutoShape(nn.Module):
self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
self.pt = not self.dmb or model.pt # PyTorch model
self.model = model.eval()
if self.pt:
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
m.inplace = False # Detect.inplace=False for safe multithread inference
def _apply(self, fn):
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论