Unverified 提交 3b57cb56 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Simplified inference (#1153)

上级 c67e7220
......@@ -149,8 +149,8 @@ if __name__ == '__main__':
parser.add_argument('--source', type=str, default='inference/images', help='source') # file/folder, 0 for webcam
parser.add_argument('--output', type=str, default='inference/output', help='output folder') # output folder
parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
parser.add_argument('--conf-thres', type=float, default=0.4, help='object confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--view-img', action='store_true', help='display results')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
......
......@@ -10,7 +10,6 @@ import os
import torch
from models.common import NMS
from models.yolo import Model
from utils.google_utils import attempt_download
......@@ -36,9 +35,7 @@ def create(name, pretrained, channels, classes):
state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].float().state_dict() # to FP32
state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter
model.load_state_dict(state_dict, strict=False) # load
model.add_nms() # add NMS module
model.eval()
# model = model.autoshape() # cv2/PIL/np/torch inference: predictions = model(Image.open('image.jpg'))
return model
except Exception as e:
......
# This file contains modules common to various models
import math
import math
import numpy as np
import torch
import torch.nn as nn
from utils.general import non_max_suppression
from utils.datasets import letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords
def autopad(k, p=None): # kernel, padding
......@@ -101,17 +104,68 @@ class Concat(nn.Module):
class NMS(nn.Module):
# Non-Maximum Suppression (NMS) module
conf = 0.3 # confidence threshold
iou = 0.6 # IoU threshold
conf = 0.25 # confidence threshold
iou = 0.45 # IoU threshold
classes = None # (optional list) filter by class
def __init__(self, dimension=1):
def __init__(self):
super(NMS, self).__init__()
def forward(self, x):
return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
class autoShape(nn.Module):
# input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
img_size = 640 # inference size (pixels)
conf = 0.25 # NMS confidence threshold
iou = 0.45 # NMS IoU threshold
classes = None # (optional list) filter by class
def __init__(self, model):
super(autoShape, self).__init__()
self.model = model
def forward(self, x, size=640, augment=False, profile=False):
# supports inference from various sources. For height=720, width=1280, RGB images example inputs are:
# opencv: x = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
# PIL: x = Image.open('image.jpg') # HWC x(720,1280,3)
# numpy: x = np.zeros((720,1280,3)) # HWC
# torch: x = torch.zeros(16,3,720,1280) # BCHW
# multiple: x = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
p = next(self.model.parameters()) # for device and type
if isinstance(x, torch.Tensor): # torch
return self.model(x.to(p.device).type_as(p), augment, profile) # inference
# Pre-process
if not isinstance(x, list):
x = [x]
shape0, shape1 = [], [] # image and inference shapes
batch = range(len(x)) # batch size
for i in batch:
x[i] = np.array(x[i])[:, :, :3] # up to 3 channels if png
s = x[i].shape[:2] # HWC
shape0.append(s) # image shape
g = (size / max(s)) # gain
shape1.append([y * g for y in s])
shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
x = [letterbox(x[i], new_shape=shape1, auto=False)[0] for i in batch] # pad
x = np.stack(x, 0) if batch[-1] else x[0][None] # stack
x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
# Inference
x = self.model(x, augment, profile) # forward
x = non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
# Post-process
for i in batch:
if x[i] is not None:
x[i][:, :4] = scale_coords(shape1, x[i][:, :4], shape0[i])
return x
class Flatten(nn.Module):
# Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
@staticmethod
......
import argparse
import logging
import math
import sys
from copy import deepcopy
from pathlib import Path
import math
sys.path.append('./') # to run '$ python *.py' files in subdirectories
logger = logging.getLogger(__name__)
import torch
import torch.nn as nn
from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS
from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape
from models.experimental import MixConv2d, CrossConv, C3
from utils.general import check_anchor_order, make_divisible, check_file, set_logging
from utils.torch_utils import (
time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, select_device)
from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
select_device, copy_attr
class Detect(nn.Module):
......@@ -140,6 +141,7 @@ class Model(nn.Module):
return x
def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
# https://arxiv.org/abs/1708.02002 section 3.3
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
m = self.model[-1] # Detect() module
for mi, s in zip(m.m, m.stride): # from
......@@ -170,15 +172,26 @@ class Model(nn.Module):
self.info()
return self
def add_nms(self): # fuse model Conv2d() + BatchNorm2d() layers
if type(self.model[-1]) is not NMS: # if missing NMS
print('Adding NMS module... ')
def nms(self, mode=True): # add or remove NMS module
present = type(self.model[-1]) is NMS # last layer is NMS
if mode and not present:
print('Adding NMS... ')
m = NMS() # module
m.f = -1 # from
m.i = self.model[-1].i + 1 # index
self.model.add_module(name='%s' % m.i, module=m) # add
self.eval()
elif not mode and present:
print('Removing NMS... ')
self.model = self.model[:-1] # remove
return self
def autoshape(self): # add autoShape module
print('Adding autoShape... ')
m = autoShape(self) # wrap model
copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
return m
def info(self, verbose=False): # print model information
model_info(self, verbose)
......@@ -263,10 +276,6 @@ if __name__ == '__main__':
# img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
# y = model(img, profile=True)
# ONNX export
# model.model[-1].export = True
# torch.onnx.export(model, img, opt.cfg.replace('.yaml', '.onnx'), verbose=True, opset_version=11)
# Tensorboard
# from torch.utils.tensorboard import SummaryWriter
# tb_writer = SummaryWriter()
......
import argparse
import glob
import json
import os
import shutil
from pathlib import Path
......@@ -8,19 +7,17 @@ from pathlib import Path
import numpy as np
import torch
import yaml
from sotabencheval.object_detection import COCOEvaluator
from sotabencheval.utils import is_server
from tqdm import tqdm
from models.experimental import attempt_load
from utils.datasets import create_dataloader
from utils.general import (
coco80_to_coco91_class, check_dataset, check_file, check_img_size, compute_loss, non_max_suppression, scale_coords,
xyxy2xywh, clip_coords, plot_images, xywh2xyxy, box_iou, output_to_target, ap_per_class, set_logging)
xyxy2xywh, clip_coords, set_logging)
from utils.torch_utils import select_device, time_synchronized
from sotabencheval.object_detection import COCOEvaluator
from sotabencheval.utils import is_server
DATA_ROOT = './.data/vision/coco' if is_server() else '../coco' # sotabench data dir
......
import glob
import math
import os
import random
import shutil
......@@ -8,6 +7,7 @@ from pathlib import Path
from threading import Thread
import cv2
import math
import numpy as np
import torch
from PIL import Image, ExifTags
......
import logging
import math
import os
import time
from copy import deepcopy
import math
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论