Unverified 提交 a18efc3a authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Add variable-stride inference support (#2091)

上级 aa02b948
......@@ -31,7 +31,8 @@ def detect(save_img=False):
# Load model
model = attempt_load(weights, map_location=device) # load FP32 model
imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size
stride = int(model.stride.max()) # model stride
imgsz = check_img_size(imgsz, s=stride) # check img_size
if half:
model.half() # to FP16
......@@ -46,10 +47,10 @@ def detect(save_img=False):
if webcam:
view_img = True
cudnn.benchmark = True # set True to speed up constant image size inference
dataset = LoadStreams(source, img_size=imgsz)
dataset = LoadStreams(source, img_size=imgsz, stride=stride)
else:
save_img = True
dataset = LoadImages(source, img_size=imgsz)
dataset = LoadImages(source, img_size=imgsz, stride=stride)
# Get names and colors
names = model.module.names if hasattr(model, 'module') else model.names
......
......@@ -119,7 +119,7 @@ class _RepeatSampler(object):
class LoadImages: # for inference
def __init__(self, path, img_size=640):
def __init__(self, path, img_size=640, stride=32):
p = str(Path(path)) # os-agnostic
p = os.path.abspath(p) # absolute path
if '*' in p:
......@@ -136,6 +136,7 @@ class LoadImages: # for inference
ni, nv = len(images), len(videos)
self.img_size = img_size
self.stride = stride
self.files = images + videos
self.nf = ni + nv # number of files
self.video_flag = [False] * ni + [True] * nv
......@@ -181,7 +182,7 @@ class LoadImages: # for inference
print(f'image {self.count}/{self.nf} {path}: ', end='')
# Padded resize
img = letterbox(img0, new_shape=self.img_size)[0]
img = letterbox(img0, self.img_size, stride=self.stride)[0]
# Convert
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
......@@ -199,8 +200,9 @@ class LoadImages: # for inference
class LoadWebcam: # for inference
def __init__(self, pipe='0', img_size=640):
def __init__(self, pipe='0', img_size=640, stride=32):
self.img_size = img_size
self.stride = stride
if pipe.isnumeric():
pipe = eval(pipe) # local camera
......@@ -243,7 +245,7 @@ class LoadWebcam: # for inference
print(f'webcam {self.count}: ', end='')
# Padded resize
img = letterbox(img0, new_shape=self.img_size)[0]
img = letterbox(img0, self.img_size, stride=self.stride)[0]
# Convert
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
......@@ -256,9 +258,10 @@ class LoadWebcam: # for inference
class LoadStreams: # multiple IP or RTSP cameras
def __init__(self, sources='streams.txt', img_size=640):
def __init__(self, sources='streams.txt', img_size=640, stride=32):
self.mode = 'stream'
self.img_size = img_size
self.stride = stride
if os.path.isfile(sources):
with open(sources, 'r') as f:
......@@ -284,7 +287,7 @@ class LoadStreams: # multiple IP or RTSP cameras
print('') # newline
# check for common shapes
s = np.stack([letterbox(x, new_shape=self.img_size)[0].shape for x in self.imgs], 0) # inference shapes
s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0) # shapes
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
if not self.rect:
print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')
......@@ -313,7 +316,7 @@ class LoadStreams: # multiple IP or RTSP cameras
raise StopIteration
# Letterbox
img = [letterbox(x, new_shape=self.img_size, auto=self.rect)[0] for x in img0]
img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]
# Stack
img = np.stack(img, 0)
......@@ -784,8 +787,8 @@ def replicate(img, labels):
return img, labels
def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
# Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
# Resize and pad image while meeting stride-multiple constraints
shape = img.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
......@@ -800,7 +803,7 @@ def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
if auto: # minimum rectangle
dw, dh = np.mod(dw, 32), np.mod(dh, 32) # wh padding
dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
elif scaleFill: # stretch
dw, dh = 0.0, 0.0
new_unpad = (new_shape[1], new_shape[0])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论