Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
Y
yolov5
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
Administrator
yolov5
Commits
08c8c3e0
Unverified
提交
08c8c3e0
authored
8月 02, 2022
作者:
Glenn Jocher
提交者:
GitHub
8月 02, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
New `smart_resume()` (#8838)
* New `smart_resume()` * Update torch_utils.py * Update torch_utils.py * Update torch_utils.py * fix
上级
2e109099
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
25 行增加
和
27 行删除
+25
-27
train.py
train.py
+6
-27
torch_utils.py
utils/torch_utils.py
+19
-0
没有找到文件。
train.py
浏览文件 @
08c8c3e0
...
...
@@ -54,7 +54,7 @@ from utils.loss import ComputeLoss
from
utils.metrics
import
fitness
from
utils.plots
import
plot_evolve
,
plot_labels
from
utils.torch_utils
import
(
EarlyStopping
,
ModelEMA
,
de_parallel
,
select_device
,
smart_DDP
,
smart_optimizer
,
torch_distributed_zero_first
)
smart_resume
,
torch_distributed_zero_first
)
LOCAL_RANK
=
int
(
os
.
getenv
(
'LOCAL_RANK'
,
-
1
))
# https://pytorch.org/docs/stable/elastic/run.html
RANK
=
int
(
os
.
getenv
(
'RANK'
,
-
1
))
...
...
@@ -163,26 +163,9 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
ema
=
ModelEMA
(
model
)
if
RANK
in
{
-
1
,
0
}
else
None
# Resume
start_epoch
,
best_fitness
=
0
,
0.
0
best_fitness
,
start_epoch
=
0.0
,
0
if
pretrained
:
# Optimizer
if
ckpt
[
'optimizer'
]
is
not
None
:
optimizer
.
load_state_dict
(
ckpt
[
'optimizer'
])
best_fitness
=
ckpt
[
'best_fitness'
]
# EMA
if
ema
and
ckpt
.
get
(
'ema'
):
ema
.
ema
.
load_state_dict
(
ckpt
[
'ema'
]
.
float
()
.
state_dict
())
ema
.
updates
=
ckpt
[
'updates'
]
# Epochs
start_epoch
=
ckpt
[
'epoch'
]
+
1
if
resume
:
assert
start_epoch
>
0
,
f
'{weights} training to {epochs} epochs is finished, nothing to resume.'
if
epochs
<
start_epoch
:
LOGGER
.
info
(
f
"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs."
)
epochs
+=
ckpt
[
'epoch'
]
# finetune additional epochs
best_fitness
,
start_epoch
,
epochs
=
smart_resume
(
ckpt
,
optimizer
,
ema
,
weights
,
epochs
,
resume
)
del
ckpt
,
csd
# DP mode
...
...
@@ -212,8 +195,8 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
quad
=
opt
.
quad
,
prefix
=
colorstr
(
'train: '
),
shuffle
=
True
)
mlc
=
int
(
np
.
concatenate
(
dataset
.
labels
,
0
)[:,
0
]
.
max
())
# max label class
nb
=
len
(
train_loader
)
# number of batche
s
labels
=
np
.
concatenate
(
dataset
.
labels
,
0
)
mlc
=
int
(
labels
[:,
0
]
.
max
())
# max label clas
s
assert
mlc
<
nc
,
f
'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
# Process 0
...
...
@@ -232,10 +215,6 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
prefix
=
colorstr
(
'val: '
))[
0
]
if
not
resume
:
labels
=
np
.
concatenate
(
dataset
.
labels
,
0
)
# c = torch.tensor(labels[:, 0]) # classes
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
# model._initialize_biases(cf.to(device))
if
plots
:
plot_labels
(
labels
,
names
,
save_dir
)
...
...
@@ -263,6 +242,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
# Start training
t0
=
time
.
time
()
nb
=
len
(
train_loader
)
# number of batches
nw
=
max
(
round
(
hyp
[
'warmup_epochs'
]
*
nb
),
100
)
# number of warmup iterations, max(3 epochs, 100 iterations)
# nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
last_opt_step
=
-
1
...
...
@@ -510,7 +490,6 @@ def main(opt, callbacks=Callbacks()):
with
open
(
Path
(
ckpt
)
.
parent
.
parent
/
'opt.yaml'
,
errors
=
'ignore'
)
as
f
:
opt
=
argparse
.
Namespace
(
**
yaml
.
safe_load
(
f
))
# replace
opt
.
cfg
,
opt
.
weights
,
opt
.
resume
=
''
,
ckpt
,
True
# reinstate
LOGGER
.
info
(
f
'Resuming training from {ckpt}'
)
else
:
opt
.
data
,
opt
.
cfg
,
opt
.
hyp
,
opt
.
weights
,
opt
.
project
=
\
check_file
(
opt
.
data
),
check_yaml
(
opt
.
cfg
),
check_yaml
(
opt
.
hyp
),
str
(
opt
.
weights
),
str
(
opt
.
project
)
# checks
...
...
utils/torch_utils.py
浏览文件 @
08c8c3e0
...
...
@@ -306,6 +306,25 @@ def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, weight_decay=1e-
return
optimizer
def
smart_resume
(
ckpt
,
optimizer
,
ema
=
None
,
weights
=
'yolov5s.pt'
,
epochs
=
300
,
resume
=
True
):
# Resume training from a partially trained checkpoint
best_fitness
=
0.0
start_epoch
=
ckpt
[
'epoch'
]
+
1
if
ckpt
[
'optimizer'
]
is
not
None
:
optimizer
.
load_state_dict
(
ckpt
[
'optimizer'
])
# optimizer
best_fitness
=
ckpt
[
'best_fitness'
]
if
ema
and
ckpt
.
get
(
'ema'
):
ema
.
ema
.
load_state_dict
(
ckpt
[
'ema'
]
.
float
()
.
state_dict
())
# EMA
ema
.
updates
=
ckpt
[
'updates'
]
if
resume
:
assert
start_epoch
>
0
,
f
'{weights} training to {epochs} epochs is finished, nothing to resume.'
LOGGER
.
info
(
f
'Resuming training from {weights} for {epochs - start_epoch} more epochs to {epochs} total epochs'
)
if
epochs
<
start_epoch
:
LOGGER
.
info
(
f
"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs."
)
epochs
+=
ckpt
[
'epoch'
]
# finetune additional epochs
return
best_fitness
,
start_epoch
,
epochs
class
EarlyStopping
:
# YOLOv5 simple early stopper
def
__init__
(
self
,
patience
=
30
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论