Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
Y
yolov5
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
Administrator
yolov5
Commits
127cbeb3
提交
127cbeb3
authored
8月 01, 2020
作者:
Glenn Jocher
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
hyperparameter expansion to flips, perspective, mixup
上级
6f08e8bc
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
32 行增加
和
32 行删除
+32
-32
train.py
train.py
+13
-10
datasets.py
utils/datasets.py
+19
-22
没有找到文件。
train.py
浏览文件 @
127cbeb3
...
@@ -16,25 +16,29 @@ from utils.datasets import *
...
@@ -16,25 +16,29 @@ from utils.datasets import *
from
utils.utils
import
*
from
utils.utils
import
*
# Hyperparameters
# Hyperparameters
hyp
=
{
'optimizer'
:
'SGD'
,
# ['
adam', 'SGD', None] if none, default is SGD
hyp
=
{
'optimizer'
:
'SGD'
,
# ['
Adam', 'SGD', ...] from torch.optim
'lr0'
:
0.01
,
# initial learning rate (SGD=1E-2, Adam=1E-3)
'lr0'
:
0.01
,
# initial learning rate (SGD=1E-2, Adam=1E-3)
'momentum'
:
0.937
,
# SGD momentum/Adam beta1
'momentum'
:
0.937
,
# SGD momentum/Adam beta1
'weight_decay'
:
5e-4
,
# optimizer weight decay
'weight_decay'
:
5e-4
,
# optimizer weight decay
'giou'
:
0.05
,
#
giou
loss gain
'giou'
:
0.05
,
#
GIoU
loss gain
'cls'
:
0.5
,
# cls loss gain
'cls'
:
0.5
,
# cls loss gain
'cls_pw'
:
1.0
,
# cls BCELoss positive_weight
'cls_pw'
:
1.0
,
# cls BCELoss positive_weight
'obj'
:
1.0
,
# obj loss gain (
*=img_size/320 if img_size != 320
)
'obj'
:
1.0
,
# obj loss gain (
scale with pixels
)
'obj_pw'
:
1.0
,
# obj BCELoss positive_weight
'obj_pw'
:
1.0
,
# obj BCELoss positive_weight
'iou_t'
:
0.20
,
#
iou
training threshold
'iou_t'
:
0.20
,
#
IoU
training threshold
'anchor_t'
:
4.0
,
# anchor-multiple threshold
'anchor_t'
:
4.0
,
# anchor-multiple threshold
'fl_gamma'
:
0.0
,
# focal loss gamma (efficientDet default
is
gamma=1.5)
'fl_gamma'
:
0.0
,
# focal loss gamma (efficientDet default gamma=1.5)
'hsv_h'
:
0.015
,
# image HSV-Hue augmentation (fraction)
'hsv_h'
:
0.015
,
# image HSV-Hue augmentation (fraction)
'hsv_s'
:
0.7
,
# image HSV-Saturation augmentation (fraction)
'hsv_s'
:
0.7
,
# image HSV-Saturation augmentation (fraction)
'hsv_v'
:
0.4
,
# image HSV-Value augmentation (fraction)
'hsv_v'
:
0.4
,
# image HSV-Value augmentation (fraction)
'degrees'
:
0.0
,
# image rotation (+/- deg)
'degrees'
:
0.0
,
# image rotation (+/- deg)
'translate'
:
0.5
,
# image translation (+/- fraction)
'translate'
:
0.5
,
# image translation (+/- fraction)
'scale'
:
0.5
,
# image scale (+/- gain)
'scale'
:
0.5
,
# image scale (+/- gain)
'shear'
:
0.0
}
# image shear (+/- deg)
'shear'
:
0.0
,
# image shear (+/- deg)
'perspective'
:
0.0
,
# image perspective (+/- fraction), range 0-0.001
'flipud'
:
0.0
,
# image flip up-down (probability)
'fliplr'
:
0.5
,
# image flip left-right (probability)
'mixup'
:
0.0
}
# image mixup (probability)
def
train
(
hyp
,
tb_writer
,
opt
,
device
):
def
train
(
hyp
,
tb_writer
,
opt
,
device
):
...
@@ -47,8 +51,7 @@ def train(hyp, tb_writer, opt, device):
...
@@ -47,8 +51,7 @@ def train(hyp, tb_writer, opt, device):
results_file
=
log_dir
+
os
.
sep
+
'results.txt'
results_file
=
log_dir
+
os
.
sep
+
'results.txt'
epochs
,
batch_size
,
total_batch_size
,
weights
,
rank
=
\
epochs
,
batch_size
,
total_batch_size
,
weights
,
rank
=
\
opt
.
epochs
,
opt
.
batch_size
,
opt
.
total_batch_size
,
opt
.
weights
,
opt
.
local_rank
opt
.
epochs
,
opt
.
batch_size
,
opt
.
total_batch_size
,
opt
.
weights
,
opt
.
local_rank
# TODO: Init DDP logging. Only the first process is allowed to log.
# TODO: Use DDP logging. Only the first process is allowed to log.
# Since I see lots of print here, the logging configuration is skipped here. We may see repeated outputs.
# Save run settings
# Save run settings
with
open
(
Path
(
log_dir
)
/
'hyp.yaml'
,
'w'
)
as
f
:
with
open
(
Path
(
log_dir
)
/
'hyp.yaml'
,
'w'
)
as
f
:
...
@@ -99,7 +102,7 @@ def train(hyp, tb_writer, opt, device):
...
@@ -99,7 +102,7 @@ def train(hyp, tb_writer, opt, device):
else
:
else
:
pg0
.
append
(
v
)
# all else
pg0
.
append
(
v
)
# all else
if
hyp
[
'optimizer'
]
==
'
adam'
:
# https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
if
hyp
[
'optimizer'
]
==
'
Adam'
:
optimizer
=
optim
.
Adam
(
pg0
,
lr
=
hyp
[
'lr0'
],
betas
=
(
hyp
[
'momentum'
],
0.999
))
# adjust beta1 to momentum
optimizer
=
optim
.
Adam
(
pg0
,
lr
=
hyp
[
'lr0'
],
betas
=
(
hyp
[
'momentum'
],
0.999
))
# adjust beta1 to momentum
else
:
else
:
optimizer
=
optim
.
SGD
(
pg0
,
lr
=
hyp
[
'lr0'
],
momentum
=
hyp
[
'momentum'
],
nesterov
=
True
)
optimizer
=
optim
.
SGD
(
pg0
,
lr
=
hyp
[
'lr0'
],
momentum
=
hyp
[
'momentum'
],
nesterov
=
True
)
...
@@ -110,9 +113,9 @@ def train(hyp, tb_writer, opt, device):
...
@@ -110,9 +113,9 @@ def train(hyp, tb_writer, opt, device):
del
pg0
,
pg1
,
pg2
del
pg0
,
pg1
,
pg2
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
# https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
lf
=
lambda
x
:
(((
1
+
math
.
cos
(
x
*
math
.
pi
/
epochs
))
/
2
)
**
1.0
)
*
0.8
+
0.2
# cosine
lf
=
lambda
x
:
(((
1
+
math
.
cos
(
x
*
math
.
pi
/
epochs
))
/
2
)
**
1.0
)
*
0.8
+
0.2
# cosine
scheduler
=
lr_scheduler
.
LambdaLR
(
optimizer
,
lr_lambda
=
lf
)
scheduler
=
lr_scheduler
.
LambdaLR
(
optimizer
,
lr_lambda
=
lf
)
# https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
# plot_lr_scheduler(optimizer, scheduler, epochs)
# plot_lr_scheduler(optimizer, scheduler, epochs)
# Load Model
# Load Model
...
...
utils/datasets.py
浏览文件 @
127cbeb3
...
@@ -484,11 +484,11 @@ class LoadImagesAndLabels(Dataset): # for training/testing
...
@@ -484,11 +484,11 @@ class LoadImagesAndLabels(Dataset): # for training/testing
shapes
=
None
shapes
=
None
# MixUp https://arxiv.org/pdf/1710.09412.pdf
# MixUp https://arxiv.org/pdf/1710.09412.pdf
# if random.random() < 0.5
:
if
random
.
random
()
<
hyp
[
'mixup'
]
:
#
img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
img2
,
labels2
=
load_mosaic
(
self
,
random
.
randint
(
0
,
len
(
self
.
labels
)
-
1
))
# r = np.random.beta(0.3, 0.3) # mixup ratio, alpha=beta=0.3
r
=
np
.
random
.
beta
(
8.0
,
8.0
)
# mixup ratio, alpha=beta=8.0
#
img = (img * r + img2 * (1 - r)).astype(np.uint8)
img
=
(
img
*
r
+
img2
*
(
1
-
r
))
.
astype
(
np
.
uint8
)
#
labels = np.concatenate((labels, labels2), 0)
labels
=
np
.
concatenate
((
labels
,
labels2
),
0
)
else
:
else
:
# Load image
# Load image
...
@@ -517,7 +517,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
...
@@ -517,7 +517,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
degrees
=
hyp
[
'degrees'
],
degrees
=
hyp
[
'degrees'
],
translate
=
hyp
[
'translate'
],
translate
=
hyp
[
'translate'
],
scale
=
hyp
[
'scale'
],
scale
=
hyp
[
'scale'
],
shear
=
hyp
[
'shear'
])
shear
=
hyp
[
'shear'
],
perspective
=
hyp
[
'perspective'
])
# Augment colorspace
# Augment colorspace
augment_hsv
(
img
,
hgain
=
hyp
[
'hsv_h'
],
sgain
=
hyp
[
'hsv_s'
],
vgain
=
hyp
[
'hsv_v'
])
augment_hsv
(
img
,
hgain
=
hyp
[
'hsv_h'
],
sgain
=
hyp
[
'hsv_s'
],
vgain
=
hyp
[
'hsv_v'
])
...
@@ -528,28 +529,23 @@ class LoadImagesAndLabels(Dataset): # for training/testing
...
@@ -528,28 +529,23 @@ class LoadImagesAndLabels(Dataset): # for training/testing
nL
=
len
(
labels
)
# number of labels
nL
=
len
(
labels
)
# number of labels
if
nL
:
if
nL
:
# convert xyxy to xywh
labels
[:,
1
:
5
]
=
xyxy2xywh
(
labels
[:,
1
:
5
])
# convert xyxy to xywh
labels
[:,
1
:
5
]
=
xyxy2xywh
(
labels
[:,
1
:
5
])
labels
[:,
[
2
,
4
]]
/=
img
.
shape
[
0
]
# normalized height 0-1
labels
[:,
[
1
,
3
]]
/=
img
.
shape
[
1
]
# normalized width 0-1
# Normalize coordinates 0 - 1
labels
[:,
[
2
,
4
]]
/=
img
.
shape
[
0
]
# height
labels
[:,
[
1
,
3
]]
/=
img
.
shape
[
1
]
# width
if
self
.
augment
:
if
self
.
augment
:
# random left-right flip
# flip up-down
lr_flip
=
True
if
random
.
random
()
<
hyp
[
'flipud'
]:
if
lr_flip
and
random
.
random
()
<
0.5
:
img
=
np
.
fliplr
(
img
)
if
nL
:
labels
[:,
1
]
=
1
-
labels
[:,
1
]
# random up-down flip
ud_flip
=
False
if
ud_flip
and
random
.
random
()
<
0.5
:
img
=
np
.
flipud
(
img
)
img
=
np
.
flipud
(
img
)
if
nL
:
if
nL
:
labels
[:,
2
]
=
1
-
labels
[:,
2
]
labels
[:,
2
]
=
1
-
labels
[:,
2
]
# flip left-right
if
random
.
random
()
<
hyp
[
'fliplr'
]:
img
=
np
.
fliplr
(
img
)
if
nL
:
labels
[:,
1
]
=
1
-
labels
[:,
1
]
labels_out
=
torch
.
zeros
((
nL
,
6
))
labels_out
=
torch
.
zeros
((
nL
,
6
))
if
nL
:
if
nL
:
labels_out
[:,
1
:]
=
torch
.
from_numpy
(
labels
)
labels_out
[:,
1
:]
=
torch
.
from_numpy
(
labels
)
...
@@ -661,6 +657,7 @@ def load_mosaic(self, index):
...
@@ -661,6 +657,7 @@ def load_mosaic(self, index):
translate
=
self
.
hyp
[
'translate'
],
translate
=
self
.
hyp
[
'translate'
],
scale
=
self
.
hyp
[
'scale'
],
scale
=
self
.
hyp
[
'scale'
],
shear
=
self
.
hyp
[
'shear'
],
shear
=
self
.
hyp
[
'shear'
],
perspective
=
self
.
hyp
[
'perspective'
],
border
=
self
.
mosaic_border
)
# border to remove
border
=
self
.
mosaic_border
)
# border to remove
return
img4
,
labels4
return
img4
,
labels4
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论