Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
Y
yolov5
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
Administrator
yolov5
Commits
f5429260
Unverified
提交
f5429260
authored
11月 16, 2020
作者:
Glenn Jocher
提交者:
GitHub
11月 16, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
PyTorch Hub and autoShape update (#1415)
* PyTorch Hub and autoShape update * comment x for imgs * reduce comment
上级
92c9b728
隐藏空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
84 行增加
和
35 行删除
+84
-35
detect.py
detect.py
+1
-1
hubconf.py
hubconf.py
+8
-8
common.py
models/common.py
+71
-22
test.py
test.py
+1
-1
general.py
utils/general.py
+3
-3
没有找到文件。
detect.py
浏览文件 @
f5429260
...
@@ -89,7 +89,7 @@ def detect(save_img=False):
...
@@ -89,7 +89,7 @@ def detect(save_img=False):
txt_path
=
str
(
save_dir
/
'labels'
/
p
.
stem
)
+
(
'_
%
g'
%
dataset
.
frame
if
dataset
.
mode
==
'video'
else
''
)
txt_path
=
str
(
save_dir
/
'labels'
/
p
.
stem
)
+
(
'_
%
g'
%
dataset
.
frame
if
dataset
.
mode
==
'video'
else
''
)
s
+=
'
%
gx
%
g '
%
img
.
shape
[
2
:]
# print string
s
+=
'
%
gx
%
g '
%
img
.
shape
[
2
:]
# print string
gn
=
torch
.
tensor
(
im0
.
shape
)[[
1
,
0
,
1
,
0
]]
# normalization gain whwh
gn
=
torch
.
tensor
(
im0
.
shape
)[[
1
,
0
,
1
,
0
]]
# normalization gain whwh
if
det
is
not
None
and
len
(
det
):
if
len
(
det
):
# Rescale boxes from img_size to im0 size
# Rescale boxes from img_size to im0 size
det
[:,
:
4
]
=
scale_coords
(
img
.
shape
[
2
:],
det
[:,
:
4
],
im0
.
shape
)
.
round
()
det
[:,
:
4
]
=
scale_coords
(
img
.
shape
[
2
:],
det
[:,
:
4
],
im0
.
shape
)
.
round
()
...
...
hubconf.py
浏览文件 @
f5429260
...
@@ -5,15 +5,16 @@ Usage:
...
@@ -5,15 +5,16 @@ Usage:
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80)
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80)
"""
"""
dependencies
=
[
'torch'
,
'yaml'
]
from
pathlib
import
Path
from
pathlib
import
Path
import
torch
import
torch
from
PIL
import
Image
from
models.yolo
import
Model
from
models.yolo
import
Model
from
utils.general
import
set_logging
from
utils.general
import
set_logging
from
utils.google_utils
import
attempt_download
from
utils.google_utils
import
attempt_download
dependencies
=
[
'torch'
,
'yaml'
,
'pillow'
]
set_logging
()
set_logging
()
...
@@ -41,7 +42,7 @@ def create(name, pretrained, channels, classes):
...
@@ -41,7 +42,7 @@ def create(name, pretrained, channels, classes):
model
.
load_state_dict
(
state_dict
,
strict
=
False
)
# load
model
.
load_state_dict
(
state_dict
,
strict
=
False
)
# load
if
len
(
ckpt
[
'model'
]
.
names
)
==
classes
:
if
len
(
ckpt
[
'model'
]
.
names
)
==
classes
:
model
.
names
=
ckpt
[
'model'
]
.
names
# set class names attribute
model
.
names
=
ckpt
[
'model'
]
.
names
# set class names attribute
# model = model.autoshape() # for
autoshaping of
PIL/cv2/np inputs and NMS
# model = model.autoshape() # for PIL/cv2/np inputs and NMS
return
model
return
model
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -108,11 +109,10 @@ def yolov5x(pretrained=False, channels=3, classes=80):
...
@@ -108,11 +109,10 @@ def yolov5x(pretrained=False, channels=3, classes=80):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
model
=
create
(
name
=
'yolov5s'
,
pretrained
=
True
,
channels
=
3
,
classes
=
80
)
# example
model
=
create
(
name
=
'yolov5s'
,
pretrained
=
True
,
channels
=
3
,
classes
=
80
)
# example
model
=
model
.
fuse
()
.
eval
()
.
autoshape
()
# for autoshaping of
PIL/cv2/np inputs and NMS
model
=
model
.
fuse
()
.
autoshape
()
# for
PIL/cv2/np inputs and NMS
# Verify inference
# Verify inference
from
PIL
import
Image
imgs
=
[
Image
.
open
(
x
)
for
x
in
Path
(
'data/images'
)
.
glob
(
'*.jpg'
)]
results
=
model
(
imgs
)
img
=
Image
.
open
(
'data/images/zidane.jpg'
)
results
.
show
()
y
=
model
(
img
)
results
.
print
()
print
(
y
[
0
]
.
shape
)
models/common.py
浏览文件 @
f5429260
...
@@ -5,9 +5,11 @@ import math
...
@@ -5,9 +5,11 @@ import math
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
PIL
import
Image
,
ImageDraw
from
utils.datasets
import
letterbox
from
utils.datasets
import
letterbox
from
utils.general
import
non_max_suppression
,
make_divisible
,
scale_coords
from
utils.general
import
non_max_suppression
,
make_divisible
,
scale_coords
,
xyxy2xywh
from
utils.plots
import
color_list
def
autopad
(
k
,
p
=
None
):
# kernel, padding
def
autopad
(
k
,
p
=
None
):
# kernel, padding
...
@@ -125,47 +127,94 @@ class autoShape(nn.Module):
...
@@ -125,47 +127,94 @@ class autoShape(nn.Module):
def
__init__
(
self
,
model
):
def
__init__
(
self
,
model
):
super
(
autoShape
,
self
)
.
__init__
()
super
(
autoShape
,
self
)
.
__init__
()
self
.
model
=
model
self
.
model
=
model
.
eval
()
def
forward
(
self
,
x
,
size
=
640
,
augment
=
False
,
profile
=
False
):
def
forward
(
self
,
imgs
,
size
=
640
,
augment
=
False
,
profile
=
False
):
# supports inference from various sources. For height=720, width=1280, RGB images example inputs are:
# supports inference from various sources. For height=720, width=1280, RGB images example inputs are:
# opencv:
x
= cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
# opencv:
imgs
= cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
# PIL:
x
= Image.open('image.jpg') # HWC x(720,1280,3)
# PIL:
imgs
= Image.open('image.jpg') # HWC x(720,1280,3)
# numpy:
x
= np.zeros((720,1280,3)) # HWC
# numpy:
imgs
= np.zeros((720,1280,3)) # HWC
# torch:
x
= torch.zeros(16,3,720,1280) # BCHW
# torch:
imgs
= torch.zeros(16,3,720,1280) # BCHW
# multiple:
x
= [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
# multiple:
imgs
= [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
p
=
next
(
self
.
model
.
parameters
())
# for device and type
p
=
next
(
self
.
model
.
parameters
())
# for device and type
if
isinstance
(
x
,
torch
.
Tensor
):
# torch
if
isinstance
(
imgs
,
torch
.
Tensor
):
# torch
return
self
.
model
(
x
.
to
(
p
.
device
)
.
type_as
(
p
),
augment
,
profile
)
# inference
return
self
.
model
(
imgs
.
to
(
p
.
device
)
.
type_as
(
p
),
augment
,
profile
)
# inference
# Pre-process
# Pre-process
if
not
isinstance
(
x
,
list
):
if
not
isinstance
(
imgs
,
list
):
x
=
[
x
]
imgs
=
[
imgs
]
shape0
,
shape1
=
[],
[]
# image and inference shapes
shape0
,
shape1
=
[],
[]
# image and inference shapes
batch
=
range
(
len
(
x
))
# batch size
batch
=
range
(
len
(
imgs
))
# batch size
for
i
in
batch
:
for
i
in
batch
:
x
[
i
]
=
np
.
array
(
x
[
i
])
# to numpy
imgs
[
i
]
=
np
.
array
(
imgs
[
i
])
# to numpy
x
[
i
]
=
x
[
i
][:,
:,
:
3
]
if
x
[
i
]
.
ndim
==
3
else
np
.
tile
(
x
[
i
][:,
:,
None
],
3
)
# enforce 3ch input
imgs
[
i
]
=
imgs
[
i
][:,
:,
:
3
]
if
imgs
[
i
]
.
ndim
==
3
else
np
.
tile
(
imgs
[
i
][:,
:,
None
],
3
)
# enforce 3ch input
s
=
x
[
i
]
.
shape
[:
2
]
# HWC
s
=
imgs
[
i
]
.
shape
[:
2
]
# HWC
shape0
.
append
(
s
)
# image shape
shape0
.
append
(
s
)
# image shape
g
=
(
size
/
max
(
s
))
# gain
g
=
(
size
/
max
(
s
))
# gain
shape1
.
append
([
y
*
g
for
y
in
s
])
shape1
.
append
([
y
*
g
for
y
in
s
])
shape1
=
[
make_divisible
(
x
,
int
(
self
.
stride
.
max
()))
for
x
in
np
.
stack
(
shape1
,
0
)
.
max
(
0
)]
# inference shape
shape1
=
[
make_divisible
(
x
,
int
(
self
.
stride
.
max
()))
for
x
in
np
.
stack
(
shape1
,
0
)
.
max
(
0
)]
# inference shape
x
=
[
letterbox
(
x
[
i
],
new_shape
=
shape1
,
auto
=
False
)[
0
]
for
i
in
batch
]
# pad
x
=
[
letterbox
(
imgs
[
i
],
new_shape
=
shape1
,
auto
=
False
)[
0
]
for
i
in
batch
]
# pad
x
=
np
.
stack
(
x
,
0
)
if
batch
[
-
1
]
else
x
[
0
][
None
]
# stack
x
=
np
.
stack
(
x
,
0
)
if
batch
[
-
1
]
else
x
[
0
][
None
]
# stack
x
=
np
.
ascontiguousarray
(
x
.
transpose
((
0
,
3
,
1
,
2
)))
# BHWC to BCHW
x
=
np
.
ascontiguousarray
(
x
.
transpose
((
0
,
3
,
1
,
2
)))
# BHWC to BCHW
x
=
torch
.
from_numpy
(
x
)
.
to
(
p
.
device
)
.
type_as
(
p
)
/
255.
# uint8 to fp16/32
x
=
torch
.
from_numpy
(
x
)
.
to
(
p
.
device
)
.
type_as
(
p
)
/
255.
# uint8 to fp16/32
# Inference
# Inference
x
=
self
.
model
(
x
,
augment
,
profile
)
# forward
with
torch
.
no_grad
():
x
=
non_max_suppression
(
x
[
0
],
conf_thres
=
self
.
conf
,
iou_thres
=
self
.
iou
,
classes
=
self
.
classes
)
# NMS
y
=
self
.
model
(
x
,
augment
,
profile
)[
0
]
# forward
y
=
non_max_suppression
(
y
,
conf_thres
=
self
.
conf
,
iou_thres
=
self
.
iou
,
classes
=
self
.
classes
)
# NMS
# Post-process
# Post-process
for
i
in
batch
:
for
i
in
batch
:
if
x
[
i
]
is
not
None
:
if
y
[
i
]
is
not
None
:
x
[
i
][:,
:
4
]
=
scale_coords
(
shape1
,
x
[
i
][:,
:
4
],
shape0
[
i
])
y
[
i
][:,
:
4
]
=
scale_coords
(
shape1
,
y
[
i
][:,
:
4
],
shape0
[
i
])
return
x
return
Detections
(
imgs
,
y
,
self
.
names
)
class
Detections
:
# detections class for YOLOv5 inference results
def
__init__
(
self
,
imgs
,
pred
,
names
=
None
):
super
(
Detections
,
self
)
.
__init__
()
self
.
imgs
=
imgs
# list of images as numpy arrays
self
.
pred
=
pred
# list of tensors pred[0] = (xyxy, conf, cls)
self
.
names
=
names
# class names
self
.
xyxy
=
pred
# xyxy pixels
self
.
xywh
=
[
xyxy2xywh
(
x
)
for
x
in
pred
]
# xywh pixels
gn
=
[
torch
.
Tensor
([
*
[
im
.
shape
[
i
]
for
i
in
[
1
,
0
,
1
,
0
]],
1.
,
1.
])
for
im
in
imgs
]
# normalization gains
self
.
xyxyn
=
[
x
/
g
for
x
,
g
in
zip
(
self
.
xyxy
,
gn
)]
# xyxy normalized
self
.
xywhn
=
[
x
/
g
for
x
,
g
in
zip
(
self
.
xywh
,
gn
)]
# xywh normalized
def
display
(
self
,
pprint
=
False
,
show
=
False
,
save
=
False
):
colors
=
color_list
()
for
i
,
(
img
,
pred
)
in
enumerate
(
zip
(
self
.
imgs
,
self
.
pred
)):
str
=
f
'Image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
if
pred
is
not
None
:
for
c
in
pred
[:,
-
1
]
.
unique
():
n
=
(
pred
[:,
-
1
]
==
c
)
.
sum
()
# detections per class
str
+=
f
'{n} {self.names[int(c)]}s, '
# add to string
if
show
or
save
:
img
=
Image
.
fromarray
(
img
.
astype
(
np
.
uint8
))
if
isinstance
(
img
,
np
.
ndarray
)
else
img
# from np
for
*
box
,
conf
,
cls
in
pred
:
# xyxy, confidence, class
# str += '%s %.2f, ' % (names[int(cls)], conf) # label
ImageDraw
.
Draw
(
img
)
.
rectangle
(
box
,
width
=
4
,
outline
=
colors
[
int
(
cls
)
%
10
])
# plot
if
save
:
f
=
f
'results{i}.jpg'
str
+=
f
"saved to '{f}'"
img
.
save
(
f
)
# save
if
show
:
img
.
show
(
f
'Image {i}'
)
# show
if
pprint
:
print
(
str
)
def
print
(
self
):
self
.
display
(
pprint
=
True
)
# print results
def
show
(
self
):
self
.
display
(
show
=
True
)
# show results
def
save
(
self
):
self
.
display
(
save
=
True
)
# save results
class
Flatten
(
nn
.
Module
):
class
Flatten
(
nn
.
Module
):
...
...
test.py
浏览文件 @
f5429260
...
@@ -126,7 +126,7 @@ def test(data,
...
@@ -126,7 +126,7 @@ def test(data,
tcls
=
labels
[:,
0
]
.
tolist
()
if
nl
else
[]
# target class
tcls
=
labels
[:,
0
]
.
tolist
()
if
nl
else
[]
# target class
seen
+=
1
seen
+=
1
if
pred
is
None
:
if
len
(
pred
)
==
0
:
if
nl
:
if
nl
:
stats
.
append
((
torch
.
zeros
(
0
,
niou
,
dtype
=
torch
.
bool
),
torch
.
Tensor
(),
torch
.
Tensor
(),
tcls
))
stats
.
append
((
torch
.
zeros
(
0
,
niou
,
dtype
=
torch
.
bool
),
torch
.
Tensor
(),
torch
.
Tensor
(),
tcls
))
continue
continue
...
...
utils/general.py
浏览文件 @
f5429260
...
@@ -142,7 +142,7 @@ def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
...
@@ -142,7 +142,7 @@ def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
def
xyxy2xywh
(
x
):
def
xyxy2xywh
(
x
):
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
y
=
torch
.
zeros_like
(
x
)
if
isinstance
(
x
,
torch
.
Tensor
)
else
np
.
zeros_like
(
x
)
y
=
x
.
clone
()
if
isinstance
(
x
,
torch
.
Tensor
)
else
np
.
copy
(
x
)
y
[:,
0
]
=
(
x
[:,
0
]
+
x
[:,
2
])
/
2
# x center
y
[:,
0
]
=
(
x
[:,
0
]
+
x
[:,
2
])
/
2
# x center
y
[:,
1
]
=
(
x
[:,
1
]
+
x
[:,
3
])
/
2
# y center
y
[:,
1
]
=
(
x
[:,
1
]
+
x
[:,
3
])
/
2
# y center
y
[:,
2
]
=
x
[:,
2
]
-
x
[:,
0
]
# width
y
[:,
2
]
=
x
[:,
2
]
-
x
[:,
0
]
# width
...
@@ -152,7 +152,7 @@ def xyxy2xywh(x):
...
@@ -152,7 +152,7 @@ def xyxy2xywh(x):
def
xywh2xyxy
(
x
):
def
xywh2xyxy
(
x
):
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
y
=
torch
.
zeros_like
(
x
)
if
isinstance
(
x
,
torch
.
Tensor
)
else
np
.
zeros_like
(
x
)
y
=
x
.
clone
()
if
isinstance
(
x
,
torch
.
Tensor
)
else
np
.
copy
(
x
)
y
[:,
0
]
=
x
[:,
0
]
-
x
[:,
2
]
/
2
# top left x
y
[:,
0
]
=
x
[:,
0
]
-
x
[:,
2
]
/
2
# top left x
y
[:,
1
]
=
x
[:,
1
]
-
x
[:,
3
]
/
2
# top left y
y
[:,
1
]
=
x
[:,
1
]
-
x
[:,
3
]
/
2
# top left y
y
[:,
2
]
=
x
[:,
0
]
+
x
[:,
2
]
/
2
# bottom right x
y
[:,
2
]
=
x
[:,
0
]
+
x
[:,
2
]
/
2
# bottom right x
...
@@ -280,7 +280,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False,
...
@@ -280,7 +280,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False,
multi_label
=
nc
>
1
# multiple labels per box (adds 0.5ms/img)
multi_label
=
nc
>
1
# multiple labels per box (adds 0.5ms/img)
t
=
time
.
time
()
t
=
time
.
time
()
output
=
[
None
]
*
prediction
.
shape
[
0
]
output
=
[
torch
.
zeros
(
0
,
6
)
]
*
prediction
.
shape
[
0
]
for
xi
,
x
in
enumerate
(
prediction
):
# image index, image inference
for
xi
,
x
in
enumerate
(
prediction
):
# image index, image inference
# Apply constraints
# Apply constraints
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论