Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
Y
yolov5
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
Administrator
yolov5
Commits
8056fe2d
提交
8056fe2d
authored
8月 01, 2020
作者:
Glenn Jocher
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
hyperparameter evolution bug fix (#566)
上级
61b5733c
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
41 行增加
和
21 行删除
+41
-21
train.py
train.py
+41
-21
没有找到文件。
train.py
浏览文件 @
8056fe2d
...
...
@@ -16,8 +16,7 @@ from utils.datasets import *
from
utils.utils
import
*
# Hyperparameters
hyp
=
{
'optimizer'
:
'SGD'
,
# ['Adam', 'SGD', ...] from torch.optim
'lr0'
:
0.01
,
# initial learning rate (SGD=1E-2, Adam=1E-3)
hyp
=
{
'lr0'
:
0.01
,
# initial learning rate (SGD=1E-2, Adam=1E-3)
'momentum'
:
0.937
,
# SGD momentum/Adam beta1
'weight_decay'
:
5e-4
,
# optimizer weight decay
'giou'
:
0.05
,
# GIoU loss gain
...
...
@@ -41,7 +40,7 @@ hyp = {'optimizer': 'SGD', # ['Adam', 'SGD', ...] from torch.optim
'mixup'
:
0.0
}
# image mixup (probability)
def
train
(
hyp
,
tb_writer
,
opt
,
devic
e
):
def
train
(
hyp
,
opt
,
device
,
tb_writer
=
Non
e
):
print
(
f
'Hyperparameters {hyp}'
)
log_dir
=
tb_writer
.
log_dir
if
tb_writer
else
'runs/evolution'
# run directory
wdir
=
str
(
Path
(
log_dir
)
/
'weights'
)
+
os
.
sep
# weights directory
...
...
@@ -102,7 +101,7 @@ def train(hyp, tb_writer, opt, device):
else
:
pg0
.
append
(
v
)
# all else
if
hyp
[
'optimizer'
]
==
'Adam'
:
if
opt
.
adam
:
optimizer
=
optim
.
Adam
(
pg0
,
lr
=
hyp
[
'lr0'
],
betas
=
(
hyp
[
'momentum'
],
0.999
))
# adjust beta1 to momentum
else
:
optimizer
=
optim
.
SGD
(
pg0
,
lr
=
hyp
[
'lr0'
],
momentum
=
hyp
[
'momentum'
],
nesterov
=
True
)
...
...
@@ -279,7 +278,7 @@ def train(hyp, tb_writer, opt, device):
imgs
=
F
.
interpolate
(
imgs
,
size
=
ns
,
mode
=
'bilinear'
,
align_corners
=
False
)
# Autocast
with
amp
.
autocast
():
with
amp
.
autocast
(
enabled
=
cuda
):
# Forward
pred
=
model
(
imgs
)
...
...
@@ -402,11 +401,11 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--data'
,
type
=
str
,
default
=
'data/coco128.yaml'
,
help
=
'data.yaml path'
)
parser
.
add_argument
(
'--hyp'
,
type
=
str
,
default
=
''
,
help
=
'hyp.yaml path (optional)'
)
parser
.
add_argument
(
'--epochs'
,
type
=
int
,
default
=
300
)
parser
.
add_argument
(
'--batch-size'
,
type
=
int
,
default
=
16
,
help
=
"Total batch size for all gpus."
)
parser
.
add_argument
(
'--batch-size'
,
type
=
int
,
default
=
16
,
help
=
'total batch size for all GPUs'
)
parser
.
add_argument
(
'--img-size'
,
nargs
=
'+'
,
type
=
int
,
default
=
[
640
,
640
],
help
=
'train,test sizes'
)
parser
.
add_argument
(
'--rect'
,
action
=
'store_true'
,
help
=
'rectangular training'
)
parser
.
add_argument
(
'--resume'
,
nargs
=
'?'
,
const
=
'get_last'
,
default
=
False
,
help
=
'resume from given path/
to/last.pt, or most recent run if blank.
'
)
help
=
'resume from given path/
last.pt, or most recent run if blank
'
)
parser
.
add_argument
(
'--nosave'
,
action
=
'store_true'
,
help
=
'only save final checkpoint'
)
parser
.
add_argument
(
'--notest'
,
action
=
'store_true'
,
help
=
'only test final epoch'
)
parser
.
add_argument
(
'--noautoanchor'
,
action
=
'store_true'
,
help
=
'disable autoanchor check'
)
...
...
@@ -418,6 +417,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--device'
,
default
=
''
,
help
=
'cuda device, i.e. 0 or 0,1,2,3 or cpu'
)
parser
.
add_argument
(
'--multi-scale'
,
action
=
'store_true'
,
help
=
'vary img-size +/- 50
%%
'
)
parser
.
add_argument
(
'--single-cls'
,
action
=
'store_true'
,
help
=
'train as single-class dataset'
)
parser
.
add_argument
(
'--adam'
,
action
=
'store_true'
,
help
=
'use torch.optim.Adam() optimizer'
)
parser
.
add_argument
(
'--sync-bn'
,
action
=
'store_true'
,
help
=
'use SyncBatchNorm, only available in DDP mode'
)
parser
.
add_argument
(
'--local_rank'
,
type
=
int
,
default
=-
1
,
help
=
'DDP parameter, do not modify'
)
opt
=
parser
.
parse_args
()
...
...
@@ -445,30 +445,52 @@ if __name__ == '__main__':
if
opt
.
local_rank
!=
-
1
:
assert
torch
.
cuda
.
device_count
()
>
opt
.
local_rank
torch
.
cuda
.
set_device
(
opt
.
local_rank
)
device
=
torch
.
device
(
"cuda"
,
opt
.
local_rank
)
device
=
torch
.
device
(
'cuda'
,
opt
.
local_rank
)
dist
.
init_process_group
(
backend
=
'nccl'
,
init_method
=
'env://'
)
# distributed backend
opt
.
world_size
=
dist
.
get_world_size
()
assert
opt
.
batch_size
%
opt
.
world_size
==
0
,
"Batch size is not a multiple of the number of devices given!"
assert
opt
.
batch_size
%
opt
.
world_size
==
0
,
'--batch-size must be multiple of CUDA device count'
opt
.
batch_size
=
opt
.
total_batch_size
//
opt
.
world_size
print
(
opt
)
# Train
if
not
opt
.
evolve
:
tb_writer
=
None
if
opt
.
local_rank
in
[
-
1
,
0
]:
print
(
'Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/'
)
tb_writer
=
SummaryWriter
(
log_dir
=
increment_dir
(
'runs/exp'
,
opt
.
name
))
else
:
tb_writer
=
None
train
(
hyp
,
tb_writer
,
opt
,
device
)
train
(
hyp
,
opt
,
device
,
tb_writer
)
# Evolve hyperparameters (optional)
else
:
assert
opt
.
local_rank
==
-
1
,
'DDP mode not implemented for --evolve'
# Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
meta
=
{
'lr0'
:
(
1
,
1e-5
,
1e-2
),
# initial learning rate (SGD=1E-2, Adam=1E-3)
'momentum'
:
(
0.1
,
0.6
,
0.98
),
# SGD momentum/Adam beta1
'weight_decay'
:
(
1
,
0.0
,
0.001
),
# optimizer weight decay
'giou'
:
(
1
,
0.02
,
0.2
),
# GIoU loss gain
'cls'
:
(
1
,
0.2
,
4.0
),
# cls loss gain
'cls_pw'
:
(
1
,
0.5
,
2.0
),
# cls BCELoss positive_weight
'obj'
:
(
1
,
0.2
,
4.0
),
# obj loss gain (scale with pixels)
'obj_pw'
:
(
1
,
0.5
,
2.0
),
# obj BCELoss positive_weight
'iou_t'
:
(
0
,
0.1
,
0.7
),
# IoU training threshold
'anchor_t'
:
(
1
,
2.0
,
8.0
),
# anchor-multiple threshold
'fl_gamma'
:
(
0
,
0.0
,
2.0
),
# focal loss gamma (efficientDet default gamma=1.5)
'hsv_h'
:
(
1
,
0.0
,
0.1
),
# image HSV-Hue augmentation (fraction)
'hsv_s'
:
(
1
,
0.0
,
0.8
),
# image HSV-Saturation augmentation (fraction)
'hsv_v'
:
(
1
,
0.0
,
0.8
),
# image HSV-Value augmentation (fraction)
'degrees'
:
(
1
,
0.0
,
45.0
),
# image rotation (+/- deg)
'translate'
:
(
1
,
0.0
,
0.9
),
# image translation (+/- fraction)
'scale'
:
(
1
,
0.0
,
0.9
),
# image scale (+/- gain)
'shear'
:
(
1
,
0.0
,
10.0
),
# image shear (+/- deg)
'perspective'
:
(
1
,
0.0
,
0.001
),
# image perspective (+/- fraction), range 0-0.001
'flipud'
:
(
0
,
0.0
,
1.0
),
# image flip up-down (probability)
'fliplr'
:
(
1
,
0.0
,
1.0
),
# image flip left-right (probability)
'mixup'
:
(
1
,
0.0
,
1.0
)}
# image mixup (probability)
tb_writer
=
None
assert
opt
.
local_rank
==
-
1
,
'DDP mode not implemented for --evolve'
opt
.
notest
,
opt
.
nosave
=
True
,
True
# only test/save final epoch
# ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
if
opt
.
bucket
:
os
.
system
(
'gsutil cp gs://
%
s/evolve.txt .'
%
opt
.
bucket
)
# download evolve.txt if exists
...
...
@@ -490,8 +512,8 @@ if __name__ == '__main__':
mp
,
s
=
0.9
,
0.2
# mutation probability, sigma
npr
=
np
.
random
npr
.
seed
(
int
(
time
.
time
()))
g
=
np
.
array
([
1
,
1
,
1
,
1
,
1
,
1
,
1
,
0
,
.
1
,
1
,
0
,
1
,
1
,
1
,
1
,
1
,
1
,
1
])
# gains
ng
=
len
(
g
)
g
=
np
.
array
([
x
[
0
]
for
x
in
meta
.
values
()])
# gains 0-1
ng
=
len
(
meta
)
v
=
np
.
ones
(
ng
)
while
all
(
v
==
1
):
# mutate until a change occurs (prevent duplicates)
v
=
(
g
*
(
npr
.
random
(
ng
)
<
mp
)
*
npr
.
randn
(
ng
)
*
npr
.
random
()
*
s
+
1
)
.
clip
(
0.3
,
3.0
)
...
...
@@ -499,13 +521,11 @@ if __name__ == '__main__':
hyp
[
k
]
=
x
[
i
+
7
]
*
v
[
i
]
# mutate
# Clip to limits
keys
=
[
'lr0'
,
'iou_t'
,
'momentum'
,
'weight_decay'
,
'hsv_s'
,
'hsv_v'
,
'translate'
,
'scale'
,
'fl_gamma'
]
limits
=
[(
1e-5
,
1e-2
),
(
0.00
,
0.70
),
(
0.60
,
0.98
),
(
0
,
0.001
),
(
0
,
.
9
),
(
0
,
.
9
),
(
0
,
.
9
),
(
0
,
.
9
),
(
0
,
3
)]
for
k
,
v
in
zip
(
keys
,
limits
):
hyp
[
k
]
=
np
.
clip
(
hyp
[
k
],
v
[
0
],
v
[
1
])
for
k
,
v
in
meta
.
items
():
hyp
[
k
]
=
np
.
clip
(
hyp
[
k
],
v
[
1
],
v
[
2
])
# Train mutation
results
=
train
(
hyp
.
copy
(),
tb_writer
,
opt
,
device
)
results
=
train
(
hyp
.
copy
(),
opt
,
device
)
# Write mutation results
print_mutation
(
hyp
,
results
,
opt
.
bucket
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论