Unverified commit 3b57cb56, authored by Glenn Jocher, committed by GitHub

Simplified inference (#1153)

Parent c67e7220
@@ -149,8 +149,8 @@ if __name__ == '__main__':
     parser.add_argument('--source', type=str, default='inference/images', help='source')  # file/folder, 0 for webcam
     parser.add_argument('--output', type=str, default='inference/output', help='output folder')  # output folder
     parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
-    parser.add_argument('--conf-thres', type=float, default=0.4, help='object confidence threshold')
-    parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
+    parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
+    parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--view-img', action='store_true', help='display results')
     parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
......
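Note (not part of the diff): the new --conf-thres / --iou-thres defaults (0.25 / 0.45) match the thresholds used by the NMS and autoShape modules introduced further down in this commit. A minimal sketch of calling non_max_suppression directly with the same values; the dummy prediction tensor is purely illustrative:

import torch
from utils.general import non_max_suppression

pred = torch.rand(1, 1000, 85)  # dummy raw output: (batch, candidates, xywh + objectness + 80 class scores)
det = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, classes=None)
if det[0] is not None:  # a per-image entry may be None/empty if nothing survives NMS
    print(det[0].shape)  # (n, 6): x1, y1, x2, y2, conf, cls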
@@ -10,7 +10,6 @@ import os
 import torch
-from models.common import NMS
 from models.yolo import Model
 from utils.google_utils import attempt_download
@@ -36,9 +35,7 @@ def create(name, pretrained, channels, classes):
             state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].float().state_dict()  # to FP32
             state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape}  # filter
             model.load_state_dict(state_dict, strict=False)  # load
-            model.add_nms()  # add NMS module
-            model.eval()
+        # model = model.autoshape()  # cv2/PIL/np/torch inference: predictions = model(Image.open('image.jpg'))
         return model
     except Exception as e:
......
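For context (not part of the diff), a sketch of how the hub entry point is typically consumed after this change; the repo string, weights name and image path are placeholders, and autoshape() is the Model method added later in this commit:

import torch
from PIL import Image

model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80)
model = model.autoshape()  # wrap with autoShape: accepts cv2/PIL/np/torch inputs and applies NMS
predictions = model(Image.open('image.jpg'))  # list of per-image (n, 6) tensors: x1, y1, x2, y2, conf, cls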
 # This file contains modules common to various models
 import math
+import numpy as np
 import torch
 import torch.nn as nn
-from utils.general import non_max_suppression
+from utils.datasets import letterbox
+from utils.general import non_max_suppression, make_divisible, scale_coords
 
 def autopad(k, p=None):  # kernel, padding
@@ -101,17 +104,68 @@ class Concat(nn.Module):
 class NMS(nn.Module):
     # Non-Maximum Suppression (NMS) module
-    conf = 0.3  # confidence threshold
-    iou = 0.6  # IoU threshold
+    conf = 0.25  # confidence threshold
+    iou = 0.45  # IoU threshold
     classes = None  # (optional list) filter by class
 
-    def __init__(self, dimension=1):
+    def __init__(self):
         super(NMS, self).__init__()
 
     def forward(self, x):
         return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
 
+class autoShape(nn.Module):
+    # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
+    img_size = 640  # inference size (pixels)
+    conf = 0.25  # NMS confidence threshold
+    iou = 0.45  # NMS IoU threshold
+    classes = None  # (optional list) filter by class
+
+    def __init__(self, model):
+        super(autoShape, self).__init__()
+        self.model = model
+
+    def forward(self, x, size=640, augment=False, profile=False):
+        # supports inference from various sources. For height=720, width=1280, RGB images example inputs are:
+        #   opencv: x = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(720,1280,3)
+        #   PIL: x = Image.open('image.jpg')  # HWC x(720,1280,3)
+        #   numpy: x = np.zeros((720,1280,3))  # HWC
+        #   torch: x = torch.zeros(16,3,720,1280)  # BCHW
+        #   multiple: x = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
+        p = next(self.model.parameters())  # for device and type
+        if isinstance(x, torch.Tensor):  # torch
+            return self.model(x.to(p.device).type_as(p), augment, profile)  # inference
+
+        # Pre-process
+        if not isinstance(x, list):
+            x = [x]
+        shape0, shape1 = [], []  # image and inference shapes
+        batch = range(len(x))  # batch size
+        for i in batch:
+            x[i] = np.array(x[i])[:, :, :3]  # up to 3 channels if png
+            s = x[i].shape[:2]  # HWC
+            shape0.append(s)  # image shape
+            g = (size / max(s))  # gain
+            shape1.append([y * g for y in s])
+        shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)]  # inference shape
+        x = [letterbox(x[i], new_shape=shape1, auto=False)[0] for i in batch]  # pad
+        x = np.stack(x, 0) if batch[-1] else x[0][None]  # stack
+        x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
+        x = torch.from_numpy(x).to(p.device).type_as(p) / 255.  # uint8 to fp16/32
+
+        # Inference
+        x = self.model(x, augment, profile)  # forward
+        x = non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS
+
+        # Post-process
+        for i in batch:
+            if x[i] is not None:
+                x[i][:, :4] = scale_coords(shape1, x[i][:, :4], shape0[i])
+        return x
+
 class Flatten(nn.Module):
     # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
     @staticmethod
......
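Usage sketch for the new autoShape wrapper (not part of the diff; the checkpoint and image paths are placeholders, and attempt_load comes from models/experimental.py):

import cv2
from models.experimental import attempt_load

model = attempt_load('yolov5s.pt', map_location='cpu')  # assumed local checkpoint
model = model.autoshape()  # returns autoShape(model): pre-processing + forward + NMS + coordinate rescaling
img = cv2.imread('image.jpg')[:, :, ::-1]  # HWC, BGR to RGB, as in the forward() docstring
pred = model(img, size=640)  # list of per-image (n, 6) tensors in original-image coordinates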
 import argparse
 import logging
+import math
 import sys
 from copy import deepcopy
 from pathlib import Path
-import math
 
 sys.path.append('./')  # to run '$ python *.py' files in subdirectories
 logger = logging.getLogger(__name__)
 
 import torch
 import torch.nn as nn
 
-from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS
+from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape
 from models.experimental import MixConv2d, CrossConv, C3
 from utils.general import check_anchor_order, make_divisible, check_file, set_logging
-from utils.torch_utils import (
-    time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, select_device)
+from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
+    select_device, copy_attr
 
 class Detect(nn.Module):
@@ -140,6 +141,7 @@ class Model(nn.Module):
         return x
 
     def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
+        # https://arxiv.org/abs/1708.02002 section 3.3
         # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
         m = self.model[-1]  # Detect() module
         for mi, s in zip(m.m, m.stride):  # from
@@ -170,15 +172,26 @@ class Model(nn.Module):
         self.info()
         return self
 
-    def add_nms(self):  # fuse model Conv2d() + BatchNorm2d() layers
-        if type(self.model[-1]) is not NMS:  # if missing NMS
-            print('Adding NMS module... ')
+    def nms(self, mode=True):  # add or remove NMS module
+        present = type(self.model[-1]) is NMS  # last layer is NMS
+        if mode and not present:
+            print('Adding NMS... ')
             m = NMS()  # module
             m.f = -1  # from
             m.i = self.model[-1].i + 1  # index
             self.model.add_module(name='%s' % m.i, module=m)  # add
+            self.eval()
+        elif not mode and present:
+            print('Removing NMS... ')
+            self.model = self.model[:-1]  # remove
         return self
 
+    def autoshape(self):  # add autoShape module
+        print('Adding autoShape... ')
+        m = autoShape(self)  # wrap model
+        copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=())  # copy attributes
+        return m
+
     def info(self, verbose=False):  # print model information
         model_info(self, verbose)
 
@@ -263,10 +276,6 @@ if __name__ == '__main__':
     # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
     # y = model(img, profile=True)
 
-    # ONNX export
-    # model.model[-1].export = True
-    # torch.onnx.export(model, img, opt.cfg.replace('.yaml', '.onnx'), verbose=True, opset_version=11)
 
     # Tensorboard
     # from torch.utils.tensorboard import SummaryWriter
     # tb_writer = SummaryWriter()
......
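A brief sketch of the two new Model methods in use (not part of the diff; the .yaml path is illustrative):

from models.yolo import Model

model = Model('models/yolov5s.yaml')  # build from config
model = model.nms()  # append an NMS() layer and switch to eval(); model.nms(False) removes it again
hub_model = model.autoshape()  # autoShape wrapper with yaml/nc/hyp/names/stride copied via copy_attr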
 import argparse
 import glob
-import json
 import os
 import shutil
 from pathlib import Path
@@ -8,19 +7,17 @@ from pathlib import Path
 import numpy as np
 import torch
 import yaml
+from sotabencheval.object_detection import COCOEvaluator
+from sotabencheval.utils import is_server
 from tqdm import tqdm
 
 from models.experimental import attempt_load
 from utils.datasets import create_dataloader
 from utils.general import (
     coco80_to_coco91_class, check_dataset, check_file, check_img_size, compute_loss, non_max_suppression, scale_coords,
-    xyxy2xywh, clip_coords, plot_images, xywh2xyxy, box_iou, output_to_target, ap_per_class, set_logging)
+    xyxy2xywh, clip_coords, set_logging)
 from utils.torch_utils import select_device, time_synchronized
-from sotabencheval.object_detection import COCOEvaluator
-from sotabencheval.utils import is_server
 
 DATA_ROOT = './.data/vision/coco' if is_server() else '../coco'  # sotabench data dir
......
 import glob
-import math
 import os
 import random
 import shutil
@@ -8,6 +7,7 @@ from pathlib import Path
 from threading import Thread
 
 import cv2
+import math
 import numpy as np
 import torch
 from PIL import Image, ExifTags
......
 import logging
-import math
 import os
 import time
 from copy import deepcopy
 
+import math
 import torch
 import torch.backends.cudnn as cudnn
 import torch.nn as nn
......