提交 c4cb7857 authored 作者: Glenn Jocher's avatar Glenn Jocher

add NMS to pretrained pytorch hub models

上级 5a9c5c1d
...@@ -10,6 +10,7 @@ import os ...@@ -10,6 +10,7 @@ import os
import torch import torch
from models.common import NMS
from models.yolo import Model from models.yolo import Model
from utils.google_utils import attempt_download from utils.google_utils import attempt_download
...@@ -35,6 +36,12 @@ def create(name, pretrained, channels, classes): ...@@ -35,6 +36,12 @@ def create(name, pretrained, channels, classes):
state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].float().state_dict() # to FP32 state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].float().state_dict() # to FP32
state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter
model.load_state_dict(state_dict, strict=False) # load model.load_state_dict(state_dict, strict=False) # load
m = NMS()
m.f = -1 # from
m.i = model.model[-1].i + 1 # index
model.model.add_module(name='%s' % m.i, module=m) # add NMS
model.eval()
return model return model
except Exception as e: except Exception as e:
......
...@@ -3,6 +3,7 @@ import math ...@@ -3,6 +3,7 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from utils.general import non_max_suppression
def autopad(k, p=None): # kernel, padding def autopad(k, p=None): # kernel, padding
...@@ -98,6 +99,19 @@ class Concat(nn.Module): ...@@ -98,6 +99,19 @@ class Concat(nn.Module):
return torch.cat(x, self.d) return torch.cat(x, self.d)
class NMS(nn.Module):
# Non-Maximum Suppression (NMS) module
conf = 0.3 # confidence threshold
iou = 0.6 # IoU threshold
classes = None # (optional list) filter by class
def __init__(self, dimension=1):
super(NMS, self).__init__()
def forward(self, x):
return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
class Flatten(nn.Module): class Flatten(nn.Module):
# Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
@staticmethod @staticmethod
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论