Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
Y
yolov5
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
Administrator
yolov5
Commits
3b57cb56
Unverified
提交
3b57cb56
authored
10月 15, 2020
作者:
Glenn Jocher
提交者:
GitHub
10月 15, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Simplified inference (#1153)
上级
c67e7220
隐藏空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
87 行增加
和
30 行删除
+87
-30
detect.py
detect.py
+2
-2
hubconf.py
hubconf.py
+1
-4
common.py
models/common.py
+59
-5
yolo.py
models/yolo.py
+20
-11
sotabench.py
sotabench.py
+3
-6
datasets.py
utils/datasets.py
+1
-1
torch_utils.py
utils/torch_utils.py
+1
-1
没有找到文件。
detect.py
浏览文件 @
3b57cb56
...
...
@@ -149,8 +149,8 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--source'
,
type
=
str
,
default
=
'inference/images'
,
help
=
'source'
)
# file/folder, 0 for webcam
parser
.
add_argument
(
'--output'
,
type
=
str
,
default
=
'inference/output'
,
help
=
'output folder'
)
# output folder
parser
.
add_argument
(
'--img-size'
,
type
=
int
,
default
=
640
,
help
=
'inference size (pixels)'
)
parser
.
add_argument
(
'--conf-thres'
,
type
=
float
,
default
=
0.
4
,
help
=
'object confidence threshold'
)
parser
.
add_argument
(
'--iou-thres'
,
type
=
float
,
default
=
0.5
,
help
=
'IOU threshold for NMS'
)
parser
.
add_argument
(
'--conf-thres'
,
type
=
float
,
default
=
0.
25
,
help
=
'object confidence threshold'
)
parser
.
add_argument
(
'--iou-thres'
,
type
=
float
,
default
=
0.
4
5
,
help
=
'IOU threshold for NMS'
)
parser
.
add_argument
(
'--device'
,
default
=
''
,
help
=
'cuda device, i.e. 0 or 0,1,2,3 or cpu'
)
parser
.
add_argument
(
'--view-img'
,
action
=
'store_true'
,
help
=
'display results'
)
parser
.
add_argument
(
'--save-txt'
,
action
=
'store_true'
,
help
=
'save results to *.txt'
)
...
...
hubconf.py
浏览文件 @
3b57cb56
...
...
@@ -10,7 +10,6 @@ import os
import
torch
from
models.common
import
NMS
from
models.yolo
import
Model
from
utils.google_utils
import
attempt_download
...
...
@@ -36,9 +35,7 @@ def create(name, pretrained, channels, classes):
state_dict
=
torch
.
load
(
ckpt
,
map_location
=
torch
.
device
(
'cpu'
))[
'model'
]
.
float
()
.
state_dict
()
# to FP32
state_dict
=
{
k
:
v
for
k
,
v
in
state_dict
.
items
()
if
model
.
state_dict
()[
k
]
.
shape
==
v
.
shape
}
# filter
model
.
load_state_dict
(
state_dict
,
strict
=
False
)
# load
model
.
add_nms
()
# add NMS module
model
.
eval
()
# model = model.autoshape() # cv2/PIL/np/torch inference: predictions = model(Image.open('image.jpg'))
return
model
except
Exception
as
e
:
...
...
models/common.py
浏览文件 @
3b57cb56
# This file contains modules common to various models
import
math
import
math
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
utils.general
import
non_max_suppression
from
utils.datasets
import
letterbox
from
utils.general
import
non_max_suppression
,
make_divisible
,
scale_coords
def
autopad
(
k
,
p
=
None
):
# kernel, padding
...
...
@@ -101,17 +104,68 @@ class Concat(nn.Module):
class
NMS
(
nn
.
Module
):
# Non-Maximum Suppression (NMS) module
conf
=
0.
3
# confidence threshold
iou
=
0.
6
# IoU threshold
conf
=
0.
25
# confidence threshold
iou
=
0.
45
# IoU threshold
classes
=
None
# (optional list) filter by class
def
__init__
(
self
,
dimension
=
1
):
def
__init__
(
self
):
super
(
NMS
,
self
)
.
__init__
()
def
forward
(
self
,
x
):
return
non_max_suppression
(
x
[
0
],
conf_thres
=
self
.
conf
,
iou_thres
=
self
.
iou
,
classes
=
self
.
classes
)
class
autoShape
(
nn
.
Module
):
# input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
img_size
=
640
# inference size (pixels)
conf
=
0.25
# NMS confidence threshold
iou
=
0.45
# NMS IoU threshold
classes
=
None
# (optional list) filter by class
def
__init__
(
self
,
model
):
super
(
autoShape
,
self
)
.
__init__
()
self
.
model
=
model
def
forward
(
self
,
x
,
size
=
640
,
augment
=
False
,
profile
=
False
):
# supports inference from various sources. For height=720, width=1280, RGB images example inputs are:
# opencv: x = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
# PIL: x = Image.open('image.jpg') # HWC x(720,1280,3)
# numpy: x = np.zeros((720,1280,3)) # HWC
# torch: x = torch.zeros(16,3,720,1280) # BCHW
# multiple: x = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
p
=
next
(
self
.
model
.
parameters
())
# for device and type
if
isinstance
(
x
,
torch
.
Tensor
):
# torch
return
self
.
model
(
x
.
to
(
p
.
device
)
.
type_as
(
p
),
augment
,
profile
)
# inference
# Pre-process
if
not
isinstance
(
x
,
list
):
x
=
[
x
]
shape0
,
shape1
=
[],
[]
# image and inference shapes
batch
=
range
(
len
(
x
))
# batch size
for
i
in
batch
:
x
[
i
]
=
np
.
array
(
x
[
i
])[:,
:,
:
3
]
# up to 3 channels if png
s
=
x
[
i
]
.
shape
[:
2
]
# HWC
shape0
.
append
(
s
)
# image shape
g
=
(
size
/
max
(
s
))
# gain
shape1
.
append
([
y
*
g
for
y
in
s
])
shape1
=
[
make_divisible
(
x
,
int
(
self
.
stride
.
max
()))
for
x
in
np
.
stack
(
shape1
,
0
)
.
max
(
0
)]
# inference shape
x
=
[
letterbox
(
x
[
i
],
new_shape
=
shape1
,
auto
=
False
)[
0
]
for
i
in
batch
]
# pad
x
=
np
.
stack
(
x
,
0
)
if
batch
[
-
1
]
else
x
[
0
][
None
]
# stack
x
=
np
.
ascontiguousarray
(
x
.
transpose
((
0
,
3
,
1
,
2
)))
# BHWC to BCHW
x
=
torch
.
from_numpy
(
x
)
.
to
(
p
.
device
)
.
type_as
(
p
)
/
255.
# uint8 to fp16/32
# Inference
x
=
self
.
model
(
x
,
augment
,
profile
)
# forward
x
=
non_max_suppression
(
x
[
0
],
conf_thres
=
self
.
conf
,
iou_thres
=
self
.
iou
,
classes
=
self
.
classes
)
# NMS
# Post-process
for
i
in
batch
:
if
x
[
i
]
is
not
None
:
x
[
i
][:,
:
4
]
=
scale_coords
(
shape1
,
x
[
i
][:,
:
4
],
shape0
[
i
])
return
x
class
Flatten
(
nn
.
Module
):
# Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
@staticmethod
...
...
models/yolo.py
浏览文件 @
3b57cb56
import
argparse
import
logging
import
math
import
sys
from
copy
import
deepcopy
from
pathlib
import
Path
import
math
sys
.
path
.
append
(
'./'
)
# to run '$ python *.py' files in subdirectories
logger
=
logging
.
getLogger
(
__name__
)
import
torch
import
torch.nn
as
nn
from
models.common
import
Conv
,
Bottleneck
,
SPP
,
DWConv
,
Focus
,
BottleneckCSP
,
Concat
,
NMS
from
models.common
import
Conv
,
Bottleneck
,
SPP
,
DWConv
,
Focus
,
BottleneckCSP
,
Concat
,
NMS
,
autoShape
from
models.experimental
import
MixConv2d
,
CrossConv
,
C3
from
utils.general
import
check_anchor_order
,
make_divisible
,
check_file
,
set_logging
from
utils.torch_utils
import
(
time_synchronized
,
fuse_conv_and_bn
,
model_info
,
scale_img
,
initialize_weights
,
select_device
)
from
utils.torch_utils
import
time_synchronized
,
fuse_conv_and_bn
,
model_info
,
scale_img
,
initialize_weights
,
\
select_device
,
copy_attr
class
Detect
(
nn
.
Module
):
...
...
@@ -140,6 +141,7 @@ class Model(nn.Module):
return
x
def
_initialize_biases
(
self
,
cf
=
None
):
# initialize biases into Detect(), cf is class frequency
# https://arxiv.org/abs/1708.02002 section 3.3
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
m
=
self
.
model
[
-
1
]
# Detect() module
for
mi
,
s
in
zip
(
m
.
m
,
m
.
stride
):
# from
...
...
@@ -170,15 +172,26 @@ class Model(nn.Module):
self
.
info
()
return
self
def
add_nms
(
self
):
# fuse model Conv2d() + BatchNorm2d() layers
if
type
(
self
.
model
[
-
1
])
is
not
NMS
:
# if missing NMS
print
(
'Adding NMS module... '
)
def
nms
(
self
,
mode
=
True
):
# add or remove NMS module
present
=
type
(
self
.
model
[
-
1
])
is
NMS
# last layer is NMS
if
mode
and
not
present
:
print
(
'Adding NMS... '
)
m
=
NMS
()
# module
m
.
f
=
-
1
# from
m
.
i
=
self
.
model
[
-
1
]
.
i
+
1
# index
self
.
model
.
add_module
(
name
=
'
%
s'
%
m
.
i
,
module
=
m
)
# add
self
.
eval
()
elif
not
mode
and
present
:
print
(
'Removing NMS... '
)
self
.
model
=
self
.
model
[:
-
1
]
# remove
return
self
def
autoshape
(
self
):
# add autoShape module
print
(
'Adding autoShape... '
)
m
=
autoShape
(
self
)
# wrap model
copy_attr
(
m
,
self
,
include
=
(
'yaml'
,
'nc'
,
'hyp'
,
'names'
,
'stride'
),
exclude
=
())
# copy attributes
return
m
def
info
(
self
,
verbose
=
False
):
# print model information
model_info
(
self
,
verbose
)
...
...
@@ -263,10 +276,6 @@ if __name__ == '__main__':
# img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
# y = model(img, profile=True)
# ONNX export
# model.model[-1].export = True
# torch.onnx.export(model, img, opt.cfg.replace('.yaml', '.onnx'), verbose=True, opset_version=11)
# Tensorboard
# from torch.utils.tensorboard import SummaryWriter
# tb_writer = SummaryWriter()
...
...
sotabench.py
浏览文件 @
3b57cb56
import
argparse
import
glob
import
json
import
os
import
shutil
from
pathlib
import
Path
...
...
@@ -8,19 +7,17 @@ from pathlib import Path
import
numpy
as
np
import
torch
import
yaml
from
sotabencheval.object_detection
import
COCOEvaluator
from
sotabencheval.utils
import
is_server
from
tqdm
import
tqdm
from
models.experimental
import
attempt_load
from
utils.datasets
import
create_dataloader
from
utils.general
import
(
coco80_to_coco91_class
,
check_dataset
,
check_file
,
check_img_size
,
compute_loss
,
non_max_suppression
,
scale_coords
,
xyxy2xywh
,
clip_coords
,
plot_images
,
xywh2xyxy
,
box_iou
,
output_to_target
,
ap_per_class
,
set_logging
)
xyxy2xywh
,
clip_coords
,
set_logging
)
from
utils.torch_utils
import
select_device
,
time_synchronized
from
sotabencheval.object_detection
import
COCOEvaluator
from
sotabencheval.utils
import
is_server
DATA_ROOT
=
'./.data/vision/coco'
if
is_server
()
else
'../coco'
# sotabench data dir
...
...
utils/datasets.py
浏览文件 @
3b57cb56
import
glob
import
math
import
os
import
random
import
shutil
...
...
@@ -8,6 +7,7 @@ from pathlib import Path
from
threading
import
Thread
import
cv2
import
math
import
numpy
as
np
import
torch
from
PIL
import
Image
,
ExifTags
...
...
utils/torch_utils.py
浏览文件 @
3b57cb56
import
logging
import
math
import
os
import
time
from
copy
import
deepcopy
import
math
import
torch
import
torch.backends.cudnn
as
cudnn
import
torch.nn
as
nn
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论