Unverified 提交 14b0abe2 authored 作者: NanoCode012's avatar NanoCode012 提交者: GitHub

autoShape() default for PyTorch Hub models (#1692)

* Add autoshape parameter * Remove autoshape call in ReadMe * Update hubconf.py * file/URI inputs and autoshape check passthrough Co-authored-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 c0ffcdf9
......@@ -106,7 +106,7 @@ import torch
from PIL import Image
# Model
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True).autoshape() # for PIL/cv2/np inputs and NMS
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True) # for PIL/cv2/np inputs and NMS
# Images
img1 = Image.open('zidane.jpg')
......
......@@ -17,7 +17,7 @@ dependencies = ['torch', 'yaml']
set_logging()
def create(name, pretrained, channels, classes):
def create(name, pretrained, channels, classes, autoshape):
"""Creates a specified YOLOv5 model
Arguments:
......@@ -41,7 +41,8 @@ def create(name, pretrained, channels, classes):
model.load_state_dict(state_dict, strict=False) # load
if len(ckpt['model'].names) == classes:
model.names = ckpt['model'].names # set class names attribute
# model = model.autoshape() # for PIL/cv2/np inputs and NMS
if autoshape:
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
return model
except Exception as e:
......@@ -50,7 +51,7 @@ def create(name, pretrained, channels, classes):
raise Exception(s) from e
def yolov5s(pretrained=False, channels=3, classes=80):
def yolov5s(pretrained=False, channels=3, classes=80, autoshape=True):
"""YOLOv5-small model from https://github.com/ultralytics/yolov5
Arguments:
......@@ -61,10 +62,10 @@ def yolov5s(pretrained=False, channels=3, classes=80):
Returns:
pytorch model
"""
return create('yolov5s', pretrained, channels, classes)
return create('yolov5s', pretrained, channels, classes, autoshape)
def yolov5m(pretrained=False, channels=3, classes=80):
def yolov5m(pretrained=False, channels=3, classes=80, autoshape=True):
"""YOLOv5-medium model from https://github.com/ultralytics/yolov5
Arguments:
......@@ -75,10 +76,10 @@ def yolov5m(pretrained=False, channels=3, classes=80):
Returns:
pytorch model
"""
return create('yolov5m', pretrained, channels, classes)
return create('yolov5m', pretrained, channels, classes, autoshape)
def yolov5l(pretrained=False, channels=3, classes=80):
def yolov5l(pretrained=False, channels=3, classes=80, autoshape=True):
"""YOLOv5-large model from https://github.com/ultralytics/yolov5
Arguments:
......@@ -89,10 +90,10 @@ def yolov5l(pretrained=False, channels=3, classes=80):
Returns:
pytorch model
"""
return create('yolov5l', pretrained, channels, classes)
return create('yolov5l', pretrained, channels, classes, autoshape)
def yolov5x(pretrained=False, channels=3, classes=80):
def yolov5x(pretrained=False, channels=3, classes=80, autoshape=True):
"""YOLOv5-xlarge model from https://github.com/ultralytics/yolov5
Arguments:
......@@ -103,10 +104,10 @@ def yolov5x(pretrained=False, channels=3, classes=80):
Returns:
pytorch model
"""
return create('yolov5x', pretrained, channels, classes)
return create('yolov5x', pretrained, channels, classes, autoshape)
def custom(path_or_model='path/to/model.pt'):
def custom(path_or_model='path/to/model.pt', autoshape=True):
"""YOLOv5-custom model from https://github.com/ultralytics/yolov5
Arguments (3 options):
......@@ -124,13 +125,12 @@ def custom(path_or_model='path/to/model.pt'):
hub_model = Model(model.yaml).to(next(model.parameters()).device) # create
hub_model.load_state_dict(model.float().state_dict()) # load state_dict
hub_model.names = model.names # class names
return hub_model
return hub_model.autoshape() if autoshape else hub_model
if __name__ == '__main__':
model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # pretrained example
model = create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True) # pretrained example
# model = custom(path_or_model='path/to/model.pt') # custom example
model = model.autoshape() # for PIL/cv2/np inputs and NMS
# Verify inference
from PIL import Image
......
......@@ -2,6 +2,7 @@
import math
import numpy as np
import requests
import torch
import torch.nn as nn
from PIL import Image, ImageDraw
......@@ -143,35 +144,42 @@ class autoShape(nn.Module):
super(autoShape, self).__init__()
self.model = model.eval()
def autoshape(self):
print('autoShape already enabled, skipping... ') # model already converted to model.autoshape()
return self
def forward(self, imgs, size=640, augment=False, profile=False):
# supports inference from various sources. For height=720, width=1280, RGB images example inputs are:
# opencv: imgs = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
# PIL: imgs = Image.open('image.jpg') # HWC x(720,1280,3)
# numpy: imgs = np.zeros((720,1280,3)) # HWC
# torch: imgs = torch.zeros(16,3,720,1280) # BCHW
# multiple: imgs = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
# Inference from various sources. For height=720, width=1280, RGB images example inputs are:
# filename: imgs = 'data/samples/zidane.jpg'
# URI: = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
# PIL: = Image.open('image.jpg') # HWC x(720,1280,3)
# numpy: = np.zeros((720,1280,3)) # HWC
# torch: = torch.zeros(16,3,720,1280) # BCHW
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
p = next(self.model.parameters()) # for device and type
if isinstance(imgs, torch.Tensor): # torch
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
# Pre-process
if not isinstance(imgs, list):
imgs = [imgs]
n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
shape0, shape1 = [], [] # image and inference shapes
batch = range(len(imgs)) # batch size
for i in batch:
imgs[i] = np.array(imgs[i]) # to numpy
if imgs[i].shape[0] < 5: # image in CHW
imgs[i] = imgs[i].transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
imgs[i] = imgs[i][:, :, :3] if imgs[i].ndim == 3 else np.tile(imgs[i][:, :, None], 3) # enforce 3ch input
s = imgs[i].shape[:2] # HWC
for i, im in enumerate(imgs):
if isinstance(im, str): # filename or uri
im = Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im) # open
im = np.array(im) # to numpy
if im.shape[0] < 5: # image in CHW
im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3) # enforce 3ch input
s = im.shape[:2] # HWC
shape0.append(s) # image shape
g = (size / max(s)) # gain
shape1.append([y * g for y in s])
imgs[i] = im # update
shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
x = [letterbox(imgs[i], new_shape=shape1, auto=False)[0] for i in batch] # pad
x = np.stack(x, 0) if batch[-1] else x[0][None] # stack
x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad
x = np.stack(x, 0) if n > 1 else x[0][None] # stack
x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
......@@ -181,7 +189,7 @@ class autoShape(nn.Module):
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
# Post-process
for i in batch:
for i in range(n):
scale_coords(shape1, y[i][:, :4], shape0[i])
return Detections(imgs, y, self.names)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论