Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
Y
yolov5
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
Administrator
yolov5
Commits
f1c63e27
提交
f1c63e27
authored
9月 13, 2020
作者:
Glenn Jocher
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add mosaic and warmup to hyperparameters (#931)
上级
806e75f2
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
22 行增加
和
9 行删除
+22
-9
hyp.finetune.yaml
data/hyp.finetune.yaml
+4
-0
hyp.scratch.yaml
data/hyp.scratch.yaml
+4
-0
train.py
train.py
+11
-7
datasets.py
utils/datasets.py
+3
-2
没有找到文件。
data/hyp.finetune.yaml
浏览文件 @
f1c63e27
...
...
@@ -12,6 +12,9 @@ lr0: 0.0032
lrf
:
0.12
momentum
:
0.843
weight_decay
:
0.00036
warmup_epochs
:
2.0
warmup_momentum
:
0.5
warmup_bias_lr
:
0.05
giou
:
0.0296
cls
:
0.243
cls_pw
:
0.631
...
...
@@ -31,4 +34,5 @@ shear: 0.602
perspective
:
0.0
flipud
:
0.00856
fliplr
:
0.5
mosaic
:
1.0
mixup
:
0.243
data/hyp.scratch.yaml
浏览文件 @
f1c63e27
...
...
@@ -7,6 +7,9 @@ lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
lrf
:
0.2
# final OneCycleLR learning rate (lr0 * lrf)
momentum
:
0.937
# SGD momentum/Adam beta1
weight_decay
:
0.0005
# optimizer weight decay 5e-4
warmup_epochs
:
3.0
# warmup epochs (fractions ok)
warmup_momentum
:
0.8
# warmup initial momentum
warmup_bias_lr
:
0.1
# warmup initial bias lr
giou
:
0.05
# box loss gain
cls
:
0.5
# cls loss gain
cls_pw
:
1.0
# cls BCELoss positive_weight
...
...
@@ -26,4 +29,5 @@ shear: 0.0 # image shear (+/- deg)
perspective
:
0.0
# image perspective (+/- fraction), range 0-0.001
flipud
:
0.0
# image flip up-down (probability)
fliplr
:
0.5
# image flip left-right (probability)
mosaic
:
1.0
# image mosaic (probability)
mixup
:
0.0
# image mixup (probability)
train.py
浏览文件 @
f1c63e27
...
...
@@ -202,7 +202,7 @@ def train(hyp, opt, device, tb_writer=None):
# Start training
t0
=
time
.
time
()
nw
=
max
(
3
*
nb
,
1e3
)
# number of warmup iterations, max(3 epochs, 1k iterations)
nw
=
max
(
round
(
hyp
[
'warmup_epochs'
]
*
nb
)
,
1e3
)
# number of warmup iterations, max(3 epochs, 1k iterations)
# nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
maps
=
np
.
zeros
(
nc
)
# mAP per class
results
=
(
0
,
0
,
0
,
0
,
0
,
0
,
0
)
# 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
...
...
@@ -250,9 +250,9 @@ def train(hyp, opt, device, tb_writer=None):
accumulate
=
max
(
1
,
np
.
interp
(
ni
,
xi
,
[
1
,
nbs
/
total_batch_size
])
.
round
())
for
j
,
x
in
enumerate
(
optimizer
.
param_groups
):
# bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
x
[
'lr'
]
=
np
.
interp
(
ni
,
xi
,
[
0.1
if
j
==
2
else
0.0
,
x
[
'initial_lr'
]
*
lf
(
epoch
)])
x
[
'lr'
]
=
np
.
interp
(
ni
,
xi
,
[
hyp
[
'warmup_bias_lr'
]
if
j
==
2
else
0.0
,
x
[
'initial_lr'
]
*
lf
(
epoch
)])
if
'momentum'
in
x
:
x
[
'momentum'
]
=
np
.
interp
(
ni
,
xi
,
[
0.9
,
hyp
[
'momentum'
]])
x
[
'momentum'
]
=
np
.
interp
(
ni
,
xi
,
[
hyp
[
'warmup_momentum'
]
,
hyp
[
'momentum'
]])
# Multi-scale
if
opt
.
multi_scale
:
...
...
@@ -460,8 +460,11 @@ if __name__ == '__main__':
# Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
meta
=
{
'lr0'
:
(
1
,
1e-5
,
1e-1
),
# initial learning rate (SGD=1E-2, Adam=1E-3)
'lrf'
:
(
1
,
0.01
,
1.0
),
# final OneCycleLR learning rate (lr0 * lrf)
'momentum'
:
(
0.
1
,
0.6
,
0.98
),
# SGD momentum/Adam beta1
'momentum'
:
(
0.
3
,
0.6
,
0.98
),
# SGD momentum/Adam beta1
'weight_decay'
:
(
1
,
0.0
,
0.001
),
# optimizer weight decay
'warmup_epochs'
:
(
1
,
0.0
,
5.0
),
# warmup epochs (fractions ok)
'warmup_momentum'
:
(
1
,
0.0
,
0.95
),
# warmup initial momentum
'warmup_bias_lr'
:
(
1
,
0.0
,
0.2
),
# warmup initial bias lr
'giou'
:
(
1
,
0.02
,
0.2
),
# GIoU loss gain
'cls'
:
(
1
,
0.2
,
4.0
),
# cls loss gain
'cls_pw'
:
(
1
,
0.5
,
2.0
),
# cls BCELoss positive_weight
...
...
@@ -469,7 +472,7 @@ if __name__ == '__main__':
'obj_pw'
:
(
1
,
0.5
,
2.0
),
# obj BCELoss positive_weight
'iou_t'
:
(
0
,
0.1
,
0.7
),
# IoU training threshold
'anchor_t'
:
(
1
,
2.0
,
8.0
),
# anchor-multiple threshold
'anchors'
:
(
1
,
2.0
,
10.0
),
# anchors per output grid (0 to ignore)
'anchors'
:
(
2
,
2.0
,
10.0
),
# anchors per output grid (0 to ignore)
'fl_gamma'
:
(
0
,
0.0
,
2.0
),
# focal loss gamma (efficientDet default gamma=1.5)
'hsv_h'
:
(
1
,
0.0
,
0.1
),
# image HSV-Hue augmentation (fraction)
'hsv_s'
:
(
1
,
0.0
,
0.9
),
# image HSV-Saturation augmentation (fraction)
...
...
@@ -481,6 +484,7 @@ if __name__ == '__main__':
'perspective'
:
(
0
,
0.0
,
0.001
),
# image perspective (+/- fraction), range 0-0.001
'flipud'
:
(
1
,
0.0
,
1.0
),
# image flip up-down (probability)
'fliplr'
:
(
0
,
0.0
,
1.0
),
# image flip left-right (probability)
'mosaic'
:
(
1
,
0.0
,
1.0
),
# image mixup (probability)
'mixup'
:
(
1
,
0.0
,
1.0
)}
# image mixup (probability)
assert
opt
.
local_rank
==
-
1
,
'DDP mode not implemented for --evolve'
...
...
@@ -490,7 +494,7 @@ if __name__ == '__main__':
if
opt
.
bucket
:
os
.
system
(
'gsutil cp gs://
%
s/evolve.txt .'
%
opt
.
bucket
)
# download evolve.txt if exists
for
_
in
range
(
1
):
# generations to evolve
for
_
in
range
(
300
):
# generations to evolve
if
os
.
path
.
exists
(
'evolve.txt'
):
# if evolve.txt exists: select best hyps and mutate
# Select parent(s)
parent
=
'single'
# parent selection method: 'single' or 'weighted'
...
...
@@ -505,7 +509,7 @@ if __name__ == '__main__':
x
=
(
x
*
w
.
reshape
(
n
,
1
))
.
sum
(
0
)
/
w
.
sum
()
# weighted combination
# Mutate
mp
,
s
=
0.
9
,
0.2
# mutation probability, sigma
mp
,
s
=
0.
8
,
0.2
# mutation probability, sigma
npr
=
np
.
random
npr
.
seed
(
int
(
time
.
time
()))
g
=
np
.
array
([
x
[
0
]
for
x
in
meta
.
values
()])
# gains 0-1
...
...
utils/datasets.py
浏览文件 @
f1c63e27
...
...
@@ -516,7 +516,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
index
=
self
.
indices
[
index
]
hyp
=
self
.
hyp
if
self
.
mosaic
:
mosaic
=
self
.
mosaic
and
random
.
random
()
<
hyp
[
'mosaic'
]
if
mosaic
:
# Load mosaic
img
,
labels
=
load_mosaic
(
self
,
index
)
shapes
=
None
...
...
@@ -550,7 +551,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
if
self
.
augment
:
# Augment imagespace
if
not
self
.
mosaic
:
if
not
mosaic
:
img
,
labels
=
random_perspective
(
img
,
labels
,
degrees
=
hyp
[
'degrees'
],
translate
=
hyp
[
'translate'
],
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论