Unverified commit 262187e9 authored by Glenn Jocher, committed by GitHub

Two dimensional `size=(h,w)` AutoShape support (#9072)

* Two dimensional `size=(h,w)` AutoShape support

May resolve https://github.com/ultralytics/yolov5/issues/9039

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* Update hubconf.py

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent db2ee5fa
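For context, this commit lets AutoShape's `size` argument be either an int or an `(h, w)` tuple. A minimal usage sketch under that assumption (the hub entrypoint and example image URL are the standard ones from the repo docs, not part of this diff):

```python
import torch

# Load a pretrained detection checkpoint; _create() below wraps it in AutoShape.
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)

# Rectangular inference: size may now be a (height, width) tuple instead of a single int.
results = model('https://ultralytics.com/images/zidane.jpg', size=(640, 1280))
results.print()

# A plain int still works and is expanded internally to (size, size).
results = model('https://ultralytics.com/images/zidane.jpg', size=640)
```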
--- a/hubconf.py
+++ b/hubconf.py
@@ -30,7 +30,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
     from models.common import AutoShape, DetectMultiBackend
     from models.experimental import attempt_load
-    from models.yolo import DetectionModel
+    from models.yolo import ClassificationModel, DetectionModel
     from utils.downloads import attempt_download
     from utils.general import LOGGER, check_requirements, intersect_dicts, logging
     from utils.torch_utils import select_device
@@ -45,8 +45,12 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
         if pretrained and channels == 3 and classes == 80:
             try:
                 model = DetectMultiBackend(path, device=device, fuse=autoshape)  # detection model
-                if autoshape and isinstance(model.model, DetectionModel):
-                    model = AutoShape(model)  # for file/URI/PIL/cv2/np inputs and NMS
+                if autoshape:
+                    if model.pt and isinstance(model.model, ClassificationModel):
+                        LOGGER.warning('WARNING: YOLOv5 v6.2 ClassificationModel is not yet AutoShape compatible. '
+                                       'You must pass torch tensors in BCHW to this model, i.e. shape(1,3,224,224).')
+                    else:
+                        model = AutoShape(model)  # for file/URI/PIL/cv2/np inputs and NMS
             except Exception:
                 model = attempt_load(path, device=device, fuse=False)  # arbitrary model
         else:
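Because classification checkpoints are now returned without the AutoShape wrapper, callers must do their own preprocessing. A hedged sketch of what the warning above implies (the `yolov5s-cls.pt` checkpoint name and the `custom` entrypoint call are assumptions for illustration, not part of this diff):

```python
import torch

# Hypothetical: load a v6.2 classification checkpoint through the same hub entrypoint.
# Per the warning above, AutoShape is skipped, so resizing/normalization is up to the caller.
model = torch.hub.load('ultralytics/yolov5', 'custom', path='yolov5s-cls.pt')  # assumed local/downloadable weights

im = torch.zeros(1, 3, 224, 224)  # BCHW float tensor, i.e. shape(1,3,224,224) as the warning suggests
with torch.no_grad():
    logits = model(im)  # raw class scores; no NMS or letterboxing is applied for you
```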
--- a/models/common.py
+++ b/models/common.py
@@ -589,7 +589,7 @@ class AutoShape(nn.Module):
     @smart_inference_mode()
     def forward(self, ims, size=640, augment=False, profile=False):
-        # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
+        # Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:
         #   file:        ims = 'data/images/zidane.jpg'  # str or PosixPath
         #   URI:             = 'https://ultralytics.com/images/zidane.jpg'
         #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
@@ -600,6 +600,8 @@ class AutoShape(nn.Module):
         dt = (Profile(), Profile(), Profile())
         with dt[0]:
+            if isinstance(size, int):  # expand
+                size = (size, size)
             p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device)  # param
             autocast = self.amp and (p.device.type != 'cpu')  # Automatic Mixed Precision (AMP) inference
             if isinstance(ims, torch.Tensor):  # torch
@@ -622,10 +624,10 @@ class AutoShape(nn.Module):
                 im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)  # enforce 3ch input
                 s = im.shape[:2]  # HWC
                 shape0.append(s)  # image shape
-                g = (size / max(s))  # gain
+                g = max(size) / max(s)  # gain
                 shape1.append([y * g for y in s])
                 ims[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
-            shape1 = [make_divisible(x, self.stride) if self.pt else size for x in np.array(shape1).max(0)]  # inf shape
+            shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size  # inf shape
             x = [letterbox(im, shape1, auto=False)[0] for im in ims]  # pad
             x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2)))  # stack and BHWC to BCHW
             x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32
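To make the effect of a tuple-valued `size` concrete, here is a small self-contained sketch of the updated gain and inference-shape computation for a single image (`make_divisible` is re-implemented here to mirror the repo helper; the numbers are illustrative):

```python
import math


def make_divisible(x, divisor=32):
    # Simplified mirror of utils.general.make_divisible: round up to the nearest multiple of divisor.
    return math.ceil(x / divisor) * divisor


def inference_shape(image_hw, size=640, stride=32, pt=True):
    # Single-image sketch of the updated AutoShape pre-processing above.
    size = (size, size) if isinstance(size, int) else size  # expand int -> (h, w), as in the diff
    g = max(size) / max(image_hw)                           # gain, now taken from the larger requested edge
    shape1 = [y * g for y in image_hw]                      # scaled (h, w) before letterbox padding
    return [make_divisible(x, stride) for x in shape1] if pt else size  # inference shape


print(inference_shape((720, 1280), size=(640, 1280)))  # -> [736, 1280] for a 720x1280 image
print(inference_shape((720, 1280), size=640))          # -> [384, 640]; int size behaves as before
```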