Unverified 提交 a18efc3a 作者: Glenn Jocher 提交者: GitHub

Add variable-stride inference support (#2091)

上级 aa02b948
...@@ -31,7 +31,8 @@ def detect(save_img=False): ...@@ -31,7 +31,8 @@ def detect(save_img=False):
# Load model # Load model
model = attempt_load(weights, map_location=device) # load FP32 model model = attempt_load(weights, map_location=device) # load FP32 model
imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size stride = int(model.stride.max()) # model stride
imgsz = check_img_size(imgsz, s=stride) # check img_size
if half: if half:
model.half() # to FP16 model.half() # to FP16
...@@ -46,10 +47,10 @@ def detect(save_img=False): ...@@ -46,10 +47,10 @@ def detect(save_img=False):
if webcam: if webcam:
view_img = True view_img = True
cudnn.benchmark = True # set True to speed up constant image size inference cudnn.benchmark = True # set True to speed up constant image size inference
dataset = LoadStreams(source, img_size=imgsz) dataset = LoadStreams(source, img_size=imgsz, stride=stride)
else: else:
save_img = True save_img = True
dataset = LoadImages(source, img_size=imgsz) dataset = LoadImages(source, img_size=imgsz, stride=stride)
# Get names and colors # Get names and colors
names = model.module.names if hasattr(model, 'module') else model.names names = model.module.names if hasattr(model, 'module') else model.names
......
...@@ -119,7 +119,7 @@ class _RepeatSampler(object): ...@@ -119,7 +119,7 @@ class _RepeatSampler(object):
class LoadImages: # for inference class LoadImages: # for inference
def __init__(self, path, img_size=640): def __init__(self, path, img_size=640, stride=32):
p = str(Path(path)) # os-agnostic p = str(Path(path)) # os-agnostic
p = os.path.abspath(p) # absolute path p = os.path.abspath(p) # absolute path
if '*' in p: if '*' in p:
...@@ -136,6 +136,7 @@ class LoadImages: # for inference ...@@ -136,6 +136,7 @@ class LoadImages: # for inference
ni, nv = len(images), len(videos) ni, nv = len(images), len(videos)
self.img_size = img_size self.img_size = img_size
self.stride = stride
self.files = images + videos self.files = images + videos
self.nf = ni + nv # number of files self.nf = ni + nv # number of files
self.video_flag = [False] * ni + [True] * nv self.video_flag = [False] * ni + [True] * nv
...@@ -181,7 +182,7 @@ class LoadImages: # for inference ...@@ -181,7 +182,7 @@ class LoadImages: # for inference
print(f'image {self.count}/{self.nf} {path}: ', end='') print(f'image {self.count}/{self.nf} {path}: ', end='')
# Padded resize # Padded resize
img = letterbox(img0, new_shape=self.img_size)[0] img = letterbox(img0, self.img_size, stride=self.stride)[0]
# Convert # Convert
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416 img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
...@@ -199,8 +200,9 @@ class LoadImages: # for inference ...@@ -199,8 +200,9 @@ class LoadImages: # for inference
class LoadWebcam: # for inference class LoadWebcam: # for inference
def __init__(self, pipe='0', img_size=640): def __init__(self, pipe='0', img_size=640, stride=32):
self.img_size = img_size self.img_size = img_size
self.stride = stride
if pipe.isnumeric(): if pipe.isnumeric():
pipe = eval(pipe) # local camera pipe = eval(pipe) # local camera
...@@ -243,7 +245,7 @@ class LoadWebcam: # for inference ...@@ -243,7 +245,7 @@ class LoadWebcam: # for inference
print(f'webcam {self.count}: ', end='') print(f'webcam {self.count}: ', end='')
# Padded resize # Padded resize
img = letterbox(img0, new_shape=self.img_size)[0] img = letterbox(img0, self.img_size, stride=self.stride)[0]
# Convert # Convert
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416 img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
...@@ -256,9 +258,10 @@ class LoadWebcam: # for inference ...@@ -256,9 +258,10 @@ class LoadWebcam: # for inference
class LoadStreams: # multiple IP or RTSP cameras class LoadStreams: # multiple IP or RTSP cameras
def __init__(self, sources='streams.txt', img_size=640): def __init__(self, sources='streams.txt', img_size=640, stride=32):
self.mode = 'stream' self.mode = 'stream'
self.img_size = img_size self.img_size = img_size
self.stride = stride
if os.path.isfile(sources): if os.path.isfile(sources):
with open(sources, 'r') as f: with open(sources, 'r') as f:
...@@ -284,7 +287,7 @@ class LoadStreams: # multiple IP or RTSP cameras ...@@ -284,7 +287,7 @@ class LoadStreams: # multiple IP or RTSP cameras
print('') # newline print('') # newline
# check for common shapes # check for common shapes
s = np.stack([letterbox(x, new_shape=self.img_size)[0].shape for x in self.imgs], 0) # inference shapes s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0) # shapes
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
if not self.rect: if not self.rect:
print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.') print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')
...@@ -313,7 +316,7 @@ class LoadStreams: # multiple IP or RTSP cameras ...@@ -313,7 +316,7 @@ class LoadStreams: # multiple IP or RTSP cameras
raise StopIteration raise StopIteration
# Letterbox # Letterbox
img = [letterbox(x, new_shape=self.img_size, auto=self.rect)[0] for x in img0] img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]
# Stack # Stack
img = np.stack(img, 0) img = np.stack(img, 0)
...@@ -784,8 +787,8 @@ def replicate(img, labels): ...@@ -784,8 +787,8 @@ def replicate(img, labels):
return img, labels return img, labels
def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True): def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
# Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232 # Resize and pad image while meeting stride-multiple constraints
shape = img.shape[:2] # current shape [height, width] shape = img.shape[:2] # current shape [height, width]
if isinstance(new_shape, int): if isinstance(new_shape, int):
new_shape = (new_shape, new_shape) new_shape = (new_shape, new_shape)
...@@ -800,7 +803,7 @@ def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale ...@@ -800,7 +803,7 @@ def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
if auto: # minimum rectangle if auto: # minimum rectangle
dw, dh = np.mod(dw, 32), np.mod(dh, 32) # wh padding dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
elif scaleFill: # stretch elif scaleFill: # stretch
dw, dh = 0.0, 0.0 dw, dh = 0.0, 0.0
new_unpad = (new_shape[1], new_shape[0]) new_unpad = (new_shape[1], new_shape[0])
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论