Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
Y
yolov5
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
Administrator
yolov5
Commits
ebafd1ea
Unverified
提交
ebafd1ea
authored
8月 17, 2020
作者:
Glenn Jocher
提交者:
GitHub
8月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
single command --resume (#756)
* single command --resume * else check files, remove TODO * argparse.Namespace() * tensorboard lr * bug fix in get_latest_run()
上级
26c3b11f
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
31 行增加
和
26 行删除
+31
-26
train.py
train.py
+30
-25
general.py
utils/general.py
+1
-1
没有找到文件。
train.py
浏览文件 @
ebafd1ea
...
...
@@ -42,7 +42,6 @@ def train(hyp, opt, device, tb_writer=None):
epochs
,
batch_size
,
total_batch_size
,
weights
,
rank
=
\
opt
.
epochs
,
opt
.
batch_size
,
opt
.
total_batch_size
,
opt
.
weights
,
opt
.
global_rank
# TODO: Use DDP logging. Only the first process is allowed to log.
# Save run settings
with
open
(
log_dir
/
'hyp.yaml'
,
'w'
)
as
f
:
yaml
.
dump
(
hyp
,
f
,
sort_keys
=
False
)
...
...
@@ -130,6 +129,8 @@ def train(hyp, opt, device, tb_writer=None):
# Epochs
start_epoch
=
ckpt
[
'epoch'
]
+
1
if
opt
.
resume
:
assert
start_epoch
>
0
,
'
%
s training to
%
g epochs is finished, nothing to resume.'
%
(
weights
,
epochs
)
if
epochs
<
start_epoch
:
logger
.
info
(
'
%
s has been trained for
%
g epochs. Fine-tuning for
%
g additional epochs.'
%
(
weights
,
ckpt
[
'epoch'
],
epochs
))
...
...
@@ -158,19 +159,19 @@ def train(hyp, opt, device, tb_writer=None):
model
=
DDP
(
model
,
device_ids
=
[
opt
.
local_rank
],
output_device
=
(
opt
.
local_rank
))
# Trainloader
dataloader
,
dataset
=
create_dataloader
(
train_path
,
imgsz
,
batch_size
,
gs
,
opt
,
hyp
=
hyp
,
augment
=
True
,
cache
=
opt
.
cache_images
,
rect
=
opt
.
rect
,
rank
=
rank
,
dataloader
,
dataset
=
create_dataloader
(
train_path
,
imgsz
,
batch_size
,
gs
,
opt
,
hyp
=
hyp
,
augment
=
True
,
cache
=
opt
.
cache_images
,
rect
=
opt
.
rect
,
rank
=
rank
,
world_size
=
opt
.
world_size
,
workers
=
opt
.
workers
)
mlc
=
np
.
concatenate
(
dataset
.
labels
,
0
)[:,
0
]
.
max
()
# max label class
nb
=
len
(
dataloader
)
# number of batches
ema
.
updates
=
start_epoch
*
nb
//
accumulate
# set EMA updates
assert
mlc
<
nc
,
'Label class
%
g exceeds nc=
%
g in
%
s. Possible class labels are 0-
%
g'
%
(
mlc
,
nc
,
opt
.
data
,
nc
-
1
)
# Testloader
if
rank
in
[
-
1
,
0
]:
# local_rank is set to -1. Because only the first process is expected to do evaluation.
testloader
=
create_dataloader
(
test_path
,
imgsz_test
,
total_batch_size
,
gs
,
opt
,
hyp
=
hyp
,
augment
=
False
,
cache
=
opt
.
cache_images
,
rect
=
True
,
rank
=-
1
,
world_size
=
opt
.
world_size
,
workers
=
opt
.
workers
)[
0
]
testloader
=
create_dataloader
(
test_path
,
imgsz_test
,
total_batch_size
,
gs
,
opt
,
hyp
=
hyp
,
augment
=
False
,
cache
=
opt
.
cache_images
,
rect
=
True
,
rank
=-
1
,
world_size
=
opt
.
world_size
,
workers
=
opt
.
workers
)[
0
]
# only runs on process 0
# Model parameters
hyp
[
'cls'
]
*=
nc
/
80.
# scale coco-tuned hyp['cls'] to current dataset
...
...
@@ -283,7 +284,7 @@ def train(hyp, opt, device, tb_writer=None):
scaler
.
step
(
optimizer
)
# optimizer.step
scaler
.
update
()
optimizer
.
zero_grad
()
if
ema
is
not
None
:
if
ema
:
ema
.
update
(
model
)
# Print
...
...
@@ -305,12 +306,13 @@ def train(hyp, opt, device, tb_writer=None):
# end batch ------------------------------------------------------------------------------------------------
# Scheduler
lr
=
[
x
[
'lr'
]
for
x
in
optimizer
.
param_groups
]
# for tensorboard
scheduler
.
step
()
# DDP process 0 or single-GPU
if
rank
in
[
-
1
,
0
]:
# mAP
if
ema
is
not
None
:
if
ema
:
ema
.
update_attr
(
model
,
include
=
[
'yaml'
,
'nc'
,
'hyp'
,
'gr'
,
'names'
,
'stride'
])
final_epoch
=
epoch
+
1
==
epochs
if
not
opt
.
notest
or
final_epoch
:
# Calculate mAP
...
...
@@ -330,10 +332,11 @@ def train(hyp, opt, device, tb_writer=None):
# Tensorboard
if
tb_writer
:
tags
=
[
'train/giou_loss'
,
'train/obj_loss'
,
'train/cls_loss'
,
tags
=
[
'train/giou_loss'
,
'train/obj_loss'
,
'train/cls_loss'
,
# train loss
'metrics/precision'
,
'metrics/recall'
,
'metrics/mAP_0.5'
,
'metrics/mAP_0.5:0.95'
,
'val/giou_loss'
,
'val/obj_loss'
,
'val/cls_loss'
]
for
x
,
tag
in
zip
(
list
(
mloss
[:
-
1
])
+
list
(
results
),
tags
):
'val/giou_loss'
,
'val/obj_loss'
,
'val/cls_loss'
,
# val loss
'x/lr0'
,
'x/lr1'
,
'x/lr2'
]
# params
for
x
,
tag
in
zip
(
list
(
mloss
[:
-
1
])
+
list
(
results
)
+
lr
,
tags
):
tb_writer
.
add_scalar
(
tag
,
x
,
epoch
)
# Update best mAP
...
...
@@ -389,8 +392,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--batch-size'
,
type
=
int
,
default
=
16
,
help
=
'total batch size for all GPUs'
)
parser
.
add_argument
(
'--img-size'
,
nargs
=
'+'
,
type
=
int
,
default
=
[
640
,
640
],
help
=
'train,test sizes'
)
parser
.
add_argument
(
'--rect'
,
action
=
'store_true'
,
help
=
'rectangular training'
)
parser
.
add_argument
(
'--resume'
,
nargs
=
'?'
,
const
=
'get_last'
,
default
=
False
,
help
=
'resume from given path/last.pt, or most recent run if blank'
)
parser
.
add_argument
(
'--resume'
,
nargs
=
'?'
,
const
=
True
,
default
=
False
,
help
=
'resume most recent training'
)
parser
.
add_argument
(
'--nosave'
,
action
=
'store_true'
,
help
=
'only save final checkpoint'
)
parser
.
add_argument
(
'--notest'
,
action
=
'store_true'
,
help
=
'only test final epoch'
)
parser
.
add_argument
(
'--noautoanchor'
,
action
=
'store_true'
,
help
=
'disable autoanchor check'
)
...
...
@@ -413,21 +415,24 @@ if __name__ == '__main__':
opt
.
world_size
=
int
(
os
.
environ
[
'WORLD_SIZE'
])
if
'WORLD_SIZE'
in
os
.
environ
else
1
opt
.
global_rank
=
int
(
os
.
environ
[
'RANK'
])
if
'RANK'
in
os
.
environ
else
-
1
set_logging
(
opt
.
global_rank
)
# Resume
if
opt
.
resume
:
last
=
get_latest_run
()
if
opt
.
resume
==
'get_last'
else
opt
.
resume
# resume from most recent run
if
last
and
not
opt
.
weights
:
logger
.
info
(
f
'Resuming training from {last}'
)
opt
.
weights
=
last
if
opt
.
resume
and
not
opt
.
weights
else
opt
.
weights
if
opt
.
global_rank
in
[
-
1
,
0
]:
check_git_status
()
opt
.
hyp
=
opt
.
hyp
or
(
'data/hyp.finetune.yaml'
if
opt
.
weights
else
'data/hyp.scratch.yaml'
)
opt
.
data
,
opt
.
cfg
,
opt
.
hyp
=
check_file
(
opt
.
data
),
check_file
(
opt
.
cfg
),
check_file
(
opt
.
hyp
)
# check files
assert
len
(
opt
.
cfg
)
or
len
(
opt
.
weights
),
'either --cfg or --weights must be specified'
# Resume
if
opt
.
resume
:
# resume an interrupted run
ckpt
=
opt
.
resume
if
isinstance
(
opt
.
resume
,
str
)
else
get_latest_run
()
# specified or most recent path
assert
os
.
path
.
isfile
(
ckpt
),
'ERROR: --resume checkpoint does not exist'
with
open
(
Path
(
ckpt
)
.
parent
.
parent
/
'opt.yaml'
)
as
f
:
opt
=
argparse
.
Namespace
(
**
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
))
# replace
opt
.
cfg
,
opt
.
weights
,
opt
.
resume
=
''
,
ckpt
,
True
logger
.
info
(
'Resuming training from
%
s'
%
ckpt
)
else
:
opt
.
hyp
=
opt
.
hyp
or
(
'data/hyp.finetune.yaml'
if
opt
.
weights
else
'data/hyp.scratch.yaml'
)
opt
.
data
,
opt
.
cfg
,
opt
.
hyp
=
check_file
(
opt
.
data
),
check_file
(
opt
.
cfg
),
check_file
(
opt
.
hyp
)
# check files
assert
len
(
opt
.
cfg
)
or
len
(
opt
.
weights
),
'either --cfg or --weights must be specified'
opt
.
img_size
.
extend
([
opt
.
img_size
[
-
1
]]
*
(
2
-
len
(
opt
.
img_size
)))
# extend to 2 sizes (train, test)
opt
.
img_size
.
extend
([
opt
.
img_size
[
-
1
]]
*
(
2
-
len
(
opt
.
img_size
)))
# extend to 2 sizes (train, test)
device
=
select_device
(
opt
.
device
,
batch_size
=
opt
.
batch_size
)
# DDP mode
...
...
utils/general.py
浏览文件 @
ebafd1ea
...
...
@@ -61,7 +61,7 @@ def init_seeds(seed=0):
def
get_latest_run
(
search_dir
=
'./runs'
):
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
last_list
=
glob
.
glob
(
f
'{search_dir}/**/last*.pt'
,
recursive
=
True
)
return
max
(
last_list
,
key
=
os
.
path
.
getctime
)
return
max
(
last_list
,
key
=
os
.
path
.
getctime
)
if
last_list
else
''
def
check_git_status
():
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论