Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
Y
yolov5
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
Administrator
yolov5
Commits
0c4b4b88
提交
0c4b4b88
authored
6月 16, 2020
作者:
Lornatang
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'upstream/master'
上级
09c1b961
db2c3acd
隐藏空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
38 行增加
和
20 行删除
+38
-20
detect.py
detect.py
+1
-1
hubconf.py
hubconf.py
+2
-2
test.py
test.py
+10
-8
train.py
train.py
+5
-5
activations.py
utils/activations.py
+1
-0
google_utils.py
utils/google_utils.py
+8
-3
utils.py
utils/utils.py
+11
-1
没有找到文件。
detect.py
浏览文件 @
0c4b4b88
...
...
@@ -18,7 +18,7 @@ def detect(save_img=False):
# Load model
google_utils
.
attempt_download
(
weights
)
model
=
torch
.
load
(
weights
,
map_location
=
device
)[
'model'
]
model
=
torch
.
load
(
weights
,
map_location
=
device
)[
'model'
]
.
float
()
# load to FP32
# torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning
# model.fuse()
model
.
to
(
device
)
.
eval
()
...
...
hubconf.py
浏览文件 @
0c4b4b88
...
...
@@ -32,8 +32,8 @@ def create(name, pretrained, channels, classes):
if
pretrained
:
ckpt
=
'
%
s.pt'
%
name
# checkpoint filename
google_utils
.
attempt_download
(
ckpt
)
# download if not found locally
state_dict
=
torch
.
load
(
ckpt
,
map_location
=
torch
.
device
(
'cpu'
))[
'model'
]
.
state_dict
()
state_dict
=
{
k
:
v
for
k
,
v
in
state_dict
.
items
()
if
model
.
state_dict
()[
k
]
.
numel
()
==
v
.
numel
()
}
# filter
state_dict
=
torch
.
load
(
ckpt
,
map_location
=
torch
.
device
(
'cpu'
))[
'model'
]
.
float
()
.
state_dict
()
# to FP32
state_dict
=
{
k
:
v
for
k
,
v
in
state_dict
.
items
()
if
model
.
state_dict
()[
k
]
.
shape
==
v
.
shape
}
# filter
model
.
load_state_dict
(
state_dict
,
strict
=
False
)
# load
return
model
...
...
test.py
浏览文件 @
0c4b4b88
...
...
@@ -23,6 +23,7 @@ def test(data,
verbose
=
False
):
# Initialize/load model and set device
if
model
is
None
:
training
=
False
device
=
torch_utils
.
select_device
(
opt
.
device
,
batch_size
=
batch_size
)
half
=
device
.
type
!=
'cpu'
# half precision only supported on CUDA
...
...
@@ -32,9 +33,9 @@ def test(data,
# Load model
google_utils
.
attempt_download
(
weights
)
model
=
torch
.
load
(
weights
,
map_location
=
device
)[
'model'
]
model
=
torch
.
load
(
weights
,
map_location
=
device
)[
'model'
]
.
float
()
# load to FP32
torch_utils
.
model_info
(
model
)
#
model.fuse()
model
.
fuse
()
model
.
to
(
device
)
if
half
:
model
.
half
()
# to FP16
...
...
@@ -42,11 +43,12 @@ def test(data,
if
device
.
type
!=
'cpu'
and
torch
.
cuda
.
device_count
()
>
1
:
model
=
nn
.
DataParallel
(
model
)
training
=
False
else
:
# called by train.py
device
=
next
(
model
.
parameters
())
.
device
# get model device
half
=
False
training
=
True
device
=
next
(
model
.
parameters
())
.
device
# get model device
half
=
device
.
type
!=
'cpu'
# half precision only supported on CUDA
if
half
:
model
.
half
()
# to FP16
# Configure
model
.
eval
()
...
...
@@ -69,7 +71,7 @@ def test(data,
batch_size
,
rect
=
True
,
# rectangular inference
single_cls
=
opt
.
single_cls
,
# single class mode
pad
=
0.
0
if
fast
else
0.
5
)
# padding
pad
=
0.5
)
# padding
batch_size
=
min
(
batch_size
,
len
(
dataset
))
nw
=
min
([
os
.
cpu_count
(),
batch_size
if
batch_size
>
1
else
0
,
8
])
# number of workers
dataloader
=
DataLoader
(
dataset
,
...
...
@@ -102,7 +104,7 @@ def test(data,
# Compute loss
if
training
:
# if model has loss hyperparameters
loss
+=
compute_loss
(
train_out
,
targets
,
model
)[
1
][:
3
]
# GIoU, obj, cls
loss
+=
compute_loss
(
[
x
.
float
()
for
x
in
train_out
]
,
targets
,
model
)[
1
][:
3
]
# GIoU, obj, cls
# Run NMS
t
=
torch_utils
.
time_synchronized
()
...
...
@@ -255,7 +257,7 @@ if __name__ == '__main__':
opt
=
parser
.
parse_args
()
opt
.
img_size
=
check_img_size
(
opt
.
img_size
)
opt
.
save_json
=
opt
.
save_json
or
opt
.
data
.
endswith
(
'coco.yaml'
)
opt
.
data
=
glob
.
glob
(
'./**/'
+
opt
.
data
,
recursive
=
True
)[
0
]
# find
file
opt
.
data
=
check_file
(
opt
.
data
)
# check
file
print
(
opt
)
# task = 'val', 'test', 'study'
...
...
train.py
浏览文件 @
0c4b4b88
...
...
@@ -112,8 +112,8 @@ def train(hyp):
# load model
try
:
ckpt
[
'model'
]
=
\
{
k
:
v
for
k
,
v
in
ckpt
[
'model'
]
.
state_dict
()
.
items
()
if
model
.
state_dict
()[
k
]
.
numel
()
==
v
.
numel
()}
ckpt
[
'model'
]
=
{
k
:
v
for
k
,
v
in
ckpt
[
'model'
]
.
float
()
.
state_dict
()
.
items
()
if
model
.
state_dict
()[
k
]
.
shape
==
v
.
shape
}
# to FP32, filter
model
.
load_state_dict
(
ckpt
[
'model'
],
strict
=
False
)
except
KeyError
as
e
:
s
=
"
%
s is not compatible with
%
s. Specify --weights '' or specify a --cfg compatible with
%
s."
\
...
...
@@ -363,6 +363,7 @@ def train(hyp):
if
__name__
==
'__main__'
:
check_git_status
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--epochs'
,
type
=
int
,
default
=
300
)
parser
.
add_argument
(
'--batch-size'
,
type
=
int
,
default
=
16
)
...
...
@@ -384,12 +385,11 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--single-cls'
,
action
=
'store_true'
,
help
=
'train as single-class dataset'
)
opt
=
parser
.
parse_args
()
opt
.
weights
=
last
if
opt
.
resume
else
opt
.
weights
opt
.
cfg
=
glob
.
glob
(
'./**/'
+
opt
.
cfg
,
recursive
=
True
)[
0
]
# find
file
opt
.
data
=
glob
.
glob
(
'./**/'
+
opt
.
data
,
recursive
=
True
)[
0
]
# find
file
opt
.
cfg
=
check_file
(
opt
.
cfg
)
# check
file
opt
.
data
=
check_file
(
opt
.
data
)
# check
file
print
(
opt
)
opt
.
img_size
.
extend
([
opt
.
img_size
[
-
1
]]
*
(
2
-
len
(
opt
.
img_size
)))
# extend to 2 sizes (train, test)
device
=
torch_utils
.
select_device
(
opt
.
device
,
apex
=
mixed_precision
,
batch_size
=
opt
.
batch_size
)
# check_git_status()
if
device
.
type
==
'cpu'
:
mixed_precision
=
False
...
...
utils/activations.py
浏览文件 @
0c4b4b88
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn
as
nn
...
...
utils/google_utils.py
浏览文件 @
0c4b4b88
...
...
@@ -25,10 +25,15 @@ def attempt_download(weights):
if
file
in
d
:
r
=
gdrive_download
(
id
=
d
[
file
],
name
=
weights
)
# Error check
if
not
(
r
==
0
and
os
.
path
.
exists
(
weights
)
and
os
.
path
.
getsize
(
weights
)
>
1E6
):
# weights exist and > 1MB
os
.
system
(
'rm '
+
weights
)
# remove partial downloads
raise
Exception
(
msg
)
os
.
remove
(
weights
)
if
os
.
path
.
exists
(
weights
)
else
None
# remove partial downloads
s
=
"curl -L -o
%
s 'https://storage.googleapis.com/ultralytics/yolov5/ckpt/
%
s'"
%
(
weights
,
file
)
r
=
os
.
system
(
s
)
# execute, capture return values
# Error check
if
not
(
r
==
0
and
os
.
path
.
exists
(
weights
)
and
os
.
path
.
getsize
(
weights
)
>
1E6
):
# weights exist and > 1MB
os
.
remove
(
weights
)
if
os
.
path
.
exists
(
weights
)
else
None
# remove partial downloads
raise
Exception
(
msg
)
def
gdrive_download
(
id
=
'1HaXkef9z6y5l4vUnCYgdmEAj61c6bfWO'
,
name
=
'coco.zip'
):
...
...
utils/utils.py
浏览文件 @
0c4b4b88
...
...
@@ -64,6 +64,16 @@ def check_best_possible_recall(dataset, anchors, thr):
'Compute new anchors with utils.utils.kmeans_anchors() and update model before training.'
%
bpr
def
check_file
(
file
):
# Searches for file if not found locally
if
os
.
path
.
isfile
(
file
):
return
file
else
:
files
=
glob
.
glob
(
'./**/'
+
file
,
recursive
=
True
)
# find file
assert
len
(
files
),
'File Not Found:
%
s'
%
file
# assert file was found
return
files
[
0
]
# return first file if multiple found
def
make_divisible
(
x
,
divisor
):
# Returns x evenly divisble by divisor
return
math
.
ceil
(
x
/
divisor
)
*
divisor
...
...
@@ -518,7 +528,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c
fast
|=
conf_thres
>
0.001
# fast mode
if
fast
:
merge
=
False
multi_label
=
False
multi_label
=
nc
>
1
# multiple labels per box (adds 0.5ms/img)
else
:
merge
=
True
# merge for best mAP (adds 0.5ms/img)
multi_label
=
nc
>
1
# multiple labels per box (adds 0.5ms/img)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论