Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
Y
yolov5
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
Administrator
yolov5
Commits
77655577
提交
77655577
authored
7月 22, 2020
作者:
Glenn Jocher
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update train.py gsutil bucket fix (#463)
上级
4ffd9779
显示空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
40 行增加
和
52 行删除
+40
-52
train.py
train.py
+40
-52
没有找到文件。
train.py
浏览文件 @
77655577
...
@@ -47,11 +47,13 @@ def train(hyp, tb_writer, opt, device):
...
@@ -47,11 +47,13 @@ def train(hyp, tb_writer, opt, device):
print
(
f
'Hyperparameters {hyp}'
)
print
(
f
'Hyperparameters {hyp}'
)
log_dir
=
tb_writer
.
log_dir
if
tb_writer
else
'runs/evolution'
# run directory
log_dir
=
tb_writer
.
log_dir
if
tb_writer
else
'runs/evolution'
# run directory
wdir
=
str
(
Path
(
log_dir
)
/
'weights'
)
+
os
.
sep
# weights directory
wdir
=
str
(
Path
(
log_dir
)
/
'weights'
)
+
os
.
sep
# weights directory
os
.
makedirs
(
wdir
,
exist_ok
=
True
)
os
.
makedirs
(
wdir
,
exist_ok
=
True
)
last
=
wdir
+
'last.pt'
last
=
wdir
+
'last.pt'
best
=
wdir
+
'best.pt'
best
=
wdir
+
'best.pt'
results_file
=
log_dir
+
os
.
sep
+
'results.txt'
results_file
=
log_dir
+
os
.
sep
+
'results.txt'
epochs
,
batch_size
,
total_batch_size
,
weights
,
rank
=
opt
.
epochs
,
opt
.
batch_size
,
opt
.
total_batch_size
,
opt
.
weights
,
opt
.
local_rank
# TODO: Init DDP logging. Only the first process is allowed to log.
# Since I see lots of print here, the logging configuration is skipped here. We may see repeated outputs.
# Save run settings
# Save run settings
with
open
(
Path
(
log_dir
)
/
'hyp.yaml'
,
'w'
)
as
f
:
with
open
(
Path
(
log_dir
)
/
'hyp.yaml'
,
'w'
)
as
f
:
...
@@ -59,17 +61,8 @@ def train(hyp, tb_writer, opt, device):
...
@@ -59,17 +61,8 @@ def train(hyp, tb_writer, opt, device):
with
open
(
Path
(
log_dir
)
/
'opt.yaml'
,
'w'
)
as
f
:
with
open
(
Path
(
log_dir
)
/
'opt.yaml'
,
'w'
)
as
f
:
yaml
.
dump
(
vars
(
opt
),
f
,
sort_keys
=
False
)
yaml
.
dump
(
vars
(
opt
),
f
,
sort_keys
=
False
)
epochs
=
opt
.
epochs
# 300
batch_size
=
opt
.
batch_size
# batch size per process.
total_batch_size
=
opt
.
total_batch_size
weights
=
opt
.
weights
# initial training weights
local_rank
=
opt
.
local_rank
# TODO: Init DDP logging. Only the first process is allowed to log.
# Since I see lots of print here, the logging configuration is skipped here. We may see repeated outputs.
# Configure
# Configure
init_seeds
(
2
+
local_
rank
)
init_seeds
(
2
+
rank
)
with
open
(
opt
.
data
)
as
f
:
with
open
(
opt
.
data
)
as
f
:
data_dict
=
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
)
# model dict
data_dict
=
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
)
# model dict
train_path
=
data_dict
[
'train'
]
train_path
=
data_dict
[
'train'
]
...
@@ -78,7 +71,7 @@ def train(hyp, tb_writer, opt, device):
...
@@ -78,7 +71,7 @@ def train(hyp, tb_writer, opt, device):
assert
len
(
names
)
==
nc
,
'
%
g names found for nc=
%
g dataset in
%
s'
%
(
len
(
names
),
nc
,
opt
.
data
)
# check
assert
len
(
names
)
==
nc
,
'
%
g names found for nc=
%
g dataset in
%
s'
%
(
len
(
names
),
nc
,
opt
.
data
)
# check
# Remove previous results
# Remove previous results
if
local_
rank
in
[
-
1
,
0
]:
if
rank
in
[
-
1
,
0
]:
for
f
in
glob
.
glob
(
'*_batch*.jpg'
)
+
glob
.
glob
(
results_file
):
for
f
in
glob
.
glob
(
'*_batch*.jpg'
)
+
glob
.
glob
(
results_file
):
os
.
remove
(
f
)
os
.
remove
(
f
)
...
@@ -91,7 +84,7 @@ def train(hyp, tb_writer, opt, device):
...
@@ -91,7 +84,7 @@ def train(hyp, tb_writer, opt, device):
# Optimizer
# Optimizer
nbs
=
64
# nominal batch size
nbs
=
64
# nominal batch size
#
the
default DDP implementation is slow for accumulation according to: https://pytorch.org/docs/stable/notes/ddp.html
# default DDP implementation is slow for accumulation according to: https://pytorch.org/docs/stable/notes/ddp.html
# all-reduce operation is carried out during loss.backward().
# all-reduce operation is carried out during loss.backward().
# Thus, there would be redundant all-reduce communications in a accumulation procedure,
# Thus, there would be redundant all-reduce communications in a accumulation procedure,
# which means, the result is still right but the training speed gets slower.
# which means, the result is still right but the training speed gets slower.
...
@@ -121,8 +114,7 @@ def train(hyp, tb_writer, opt, device):
...
@@ -121,8 +114,7 @@ def train(hyp, tb_writer, opt, device):
del
pg0
,
pg1
,
pg2
del
pg0
,
pg1
,
pg2
# Load Model
# Load Model
# Avoid multiple downloads.
with
torch_distributed_zero_first
(
rank
):
with
torch_distributed_zero_first
(
local_rank
):
google_utils
.
attempt_download
(
weights
)
google_utils
.
attempt_download
(
weights
)
start_epoch
,
best_fitness
=
0
,
0.0
start_epoch
,
best_fitness
=
0
,
0.0
if
weights
.
endswith
(
'.pt'
):
# pytorch format
if
weights
.
endswith
(
'.pt'
):
# pytorch format
...
@@ -169,32 +161,31 @@ def train(hyp, tb_writer, opt, device):
...
@@ -169,32 +161,31 @@ def train(hyp, tb_writer, opt, device):
# plot_lr_scheduler(optimizer, scheduler, epochs)
# plot_lr_scheduler(optimizer, scheduler, epochs)
# DP mode
# DP mode
if
device
.
type
!=
'cpu'
and
local_
rank
==
-
1
and
torch
.
cuda
.
device_count
()
>
1
:
if
device
.
type
!=
'cpu'
and
rank
==
-
1
and
torch
.
cuda
.
device_count
()
>
1
:
model
=
torch
.
nn
.
DataParallel
(
model
)
model
=
torch
.
nn
.
DataParallel
(
model
)
# Exponential moving average
# SyncBatchNorm
# From https://github.com/rwightman/pytorch-image-models/blob/master/train.py:
if
opt
.
sync_bn
and
device
.
type
!=
'cpu'
and
rank
!=
-
1
:
# "Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper"
# chenyzsjtu: ema should be placed before after SyncBN. As SyncBN introduces new modules.
if
opt
.
sync_bn
and
device
.
type
!=
'cpu'
and
local_rank
!=
-
1
:
print
(
"SyncBN activated!"
)
model
=
torch
.
nn
.
SyncBatchNorm
.
convert_sync_batchnorm
(
model
)
.
to
(
device
)
model
=
torch
.
nn
.
SyncBatchNorm
.
convert_sync_batchnorm
(
model
)
.
to
(
device
)
ema
=
torch_utils
.
ModelEMA
(
model
)
if
local_rank
in
[
-
1
,
0
]
else
None
print
(
'Using SyncBatchNorm()'
)
# Exponential moving average
ema
=
torch_utils
.
ModelEMA
(
model
)
if
rank
in
[
-
1
,
0
]
else
None
# DDP mode
# DDP mode
if
device
.
type
!=
'cpu'
and
local_
rank
!=
-
1
:
if
device
.
type
!=
'cpu'
and
rank
!=
-
1
:
model
=
DDP
(
model
,
device_ids
=
[
local_rank
],
output_device
=
local_
rank
)
model
=
DDP
(
model
,
device_ids
=
[
rank
],
output_device
=
rank
)
# Trainloader
# Trainloader
dataloader
,
dataset
=
create_dataloader
(
train_path
,
imgsz
,
batch_size
,
gs
,
opt
,
hyp
=
hyp
,
augment
=
True
,
dataloader
,
dataset
=
create_dataloader
(
train_path
,
imgsz
,
batch_size
,
gs
,
opt
,
hyp
=
hyp
,
augment
=
True
,
cache
=
opt
.
cache_images
,
rect
=
opt
.
rect
,
local_rank
=
local_
rank
,
cache
=
opt
.
cache_images
,
rect
=
opt
.
rect
,
local_rank
=
rank
,
world_size
=
opt
.
world_size
)
world_size
=
opt
.
world_size
)
mlc
=
np
.
concatenate
(
dataset
.
labels
,
0
)[:,
0
]
.
max
()
# max label class
mlc
=
np
.
concatenate
(
dataset
.
labels
,
0
)[:,
0
]
.
max
()
# max label class
nb
=
len
(
dataloader
)
# number of batches
nb
=
len
(
dataloader
)
# number of batches
assert
mlc
<
nc
,
'Label class
%
g exceeds nc=
%
g in
%
s. Possible class labels are 0-
%
g'
%
(
mlc
,
nc
,
opt
.
data
,
nc
-
1
)
assert
mlc
<
nc
,
'Label class
%
g exceeds nc=
%
g in
%
s. Possible class labels are 0-
%
g'
%
(
mlc
,
nc
,
opt
.
data
,
nc
-
1
)
# Testloader
# Testloader
if
local_
rank
in
[
-
1
,
0
]:
if
rank
in
[
-
1
,
0
]:
# local_rank is set to -1. Because only the first process is expected to do evaluation.
# local_rank is set to -1. Because only the first process is expected to do evaluation.
testloader
=
create_dataloader
(
test_path
,
imgsz_test
,
total_batch_size
,
gs
,
opt
,
hyp
=
hyp
,
augment
=
False
,
testloader
=
create_dataloader
(
test_path
,
imgsz_test
,
total_batch_size
,
gs
,
opt
,
hyp
=
hyp
,
augment
=
False
,
cache
=
opt
.
cache_images
,
rect
=
True
,
local_rank
=-
1
,
world_size
=
opt
.
world_size
)[
0
]
cache
=
opt
.
cache_images
,
rect
=
True
,
local_rank
=-
1
,
world_size
=
opt
.
world_size
)[
0
]
...
@@ -208,8 +199,7 @@ def train(hyp, tb_writer, opt, device):
...
@@ -208,8 +199,7 @@ def train(hyp, tb_writer, opt, device):
model
.
names
=
names
model
.
names
=
names
# Class frequency
# Class frequency
# Only one check and log is needed.
if
rank
in
[
-
1
,
0
]:
if
local_rank
in
[
-
1
,
0
]:
labels
=
np
.
concatenate
(
dataset
.
labels
,
0
)
labels
=
np
.
concatenate
(
dataset
.
labels
,
0
)
c
=
torch
.
tensor
(
labels
[:,
0
])
# classes
c
=
torch
.
tensor
(
labels
[:,
0
])
# classes
# cf = torch.bincount(c.long(), minlength=nc) + 1.
# cf = torch.bincount(c.long(), minlength=nc) + 1.
...
@@ -222,13 +212,14 @@ def train(hyp, tb_writer, opt, device):
...
@@ -222,13 +212,14 @@ def train(hyp, tb_writer, opt, device):
# Check anchors
# Check anchors
if
not
opt
.
noautoanchor
:
if
not
opt
.
noautoanchor
:
check_anchors
(
dataset
,
model
=
model
,
thr
=
hyp
[
'anchor_t'
],
imgsz
=
imgsz
)
check_anchors
(
dataset
,
model
=
model
,
thr
=
hyp
[
'anchor_t'
],
imgsz
=
imgsz
)
# Start training
# Start training
t0
=
time
.
time
()
t0
=
time
.
time
()
nw
=
max
(
3
*
nb
,
1e3
)
# number of warmup iterations, max(3 epochs, 1k iterations)
nw
=
max
(
3
*
nb
,
1e3
)
# number of warmup iterations, max(3 epochs, 1k iterations)
maps
=
np
.
zeros
(
nc
)
# mAP per class
maps
=
np
.
zeros
(
nc
)
# mAP per class
results
=
(
0
,
0
,
0
,
0
,
0
,
0
,
0
)
# 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
results
=
(
0
,
0
,
0
,
0
,
0
,
0
,
0
)
# 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
scheduler
.
last_epoch
=
start_epoch
-
1
# do not move
scheduler
.
last_epoch
=
start_epoch
-
1
# do not move
if
local_
rank
in
[
0
,
-
1
]:
if
rank
in
[
0
,
-
1
]:
print
(
'Image sizes
%
g train,
%
g test'
%
(
imgsz
,
imgsz_test
))
print
(
'Image sizes
%
g train,
%
g test'
%
(
imgsz
,
imgsz_test
))
print
(
'Using
%
g dataloader workers'
%
dataloader
.
num_workers
)
print
(
'Using
%
g dataloader workers'
%
dataloader
.
num_workers
)
print
(
'Starting training for
%
g epochs...'
%
epochs
)
print
(
'Starting training for
%
g epochs...'
%
epochs
)
...
@@ -240,18 +231,18 @@ def train(hyp, tb_writer, opt, device):
...
@@ -240,18 +231,18 @@ def train(hyp, tb_writer, opt, device):
# When in DDP mode, the generated indices will be broadcasted to synchronize dataset.
# When in DDP mode, the generated indices will be broadcasted to synchronize dataset.
if
dataset
.
image_weights
:
if
dataset
.
image_weights
:
# Generate indices.
# Generate indices.
if
local_
rank
in
[
-
1
,
0
]:
if
rank
in
[
-
1
,
0
]:
w
=
model
.
class_weights
.
cpu
()
.
numpy
()
*
(
1
-
maps
)
**
2
# class weights
w
=
model
.
class_weights
.
cpu
()
.
numpy
()
*
(
1
-
maps
)
**
2
# class weights
image_weights
=
labels_to_image_weights
(
dataset
.
labels
,
nc
=
nc
,
class_weights
=
w
)
image_weights
=
labels_to_image_weights
(
dataset
.
labels
,
nc
=
nc
,
class_weights
=
w
)
dataset
.
indices
=
random
.
choices
(
range
(
dataset
.
n
),
weights
=
image_weights
,
dataset
.
indices
=
random
.
choices
(
range
(
dataset
.
n
),
weights
=
image_weights
,
k
=
dataset
.
n
)
# rand weighted idx
k
=
dataset
.
n
)
# rand weighted idx
# Broadcast.
# Broadcast.
if
local_
rank
!=
-
1
:
if
rank
!=
-
1
:
indices
=
torch
.
zeros
([
dataset
.
n
],
dtype
=
torch
.
int
)
indices
=
torch
.
zeros
([
dataset
.
n
],
dtype
=
torch
.
int
)
if
local_
rank
==
0
:
if
rank
==
0
:
indices
[:]
=
torch
.
from_tensor
(
dataset
.
indices
,
dtype
=
torch
.
int
)
indices
[:]
=
torch
.
from_tensor
(
dataset
.
indices
,
dtype
=
torch
.
int
)
dist
.
broadcast
(
indices
,
0
)
dist
.
broadcast
(
indices
,
0
)
if
local_
rank
!=
0
:
if
rank
!=
0
:
dataset
.
indices
=
indices
.
cpu
()
.
numpy
()
dataset
.
indices
=
indices
.
cpu
()
.
numpy
()
# Update mosaic border
# Update mosaic border
...
@@ -259,10 +250,10 @@ def train(hyp, tb_writer, opt, device):
...
@@ -259,10 +250,10 @@ def train(hyp, tb_writer, opt, device):
# dataset.mosaic_border = [b - imgsz, -b] # height, width borders
# dataset.mosaic_border = [b - imgsz, -b] # height, width borders
mloss
=
torch
.
zeros
(
4
,
device
=
device
)
# mean losses
mloss
=
torch
.
zeros
(
4
,
device
=
device
)
# mean losses
if
local_
rank
!=
-
1
:
if
rank
!=
-
1
:
dataloader
.
sampler
.
set_epoch
(
epoch
)
dataloader
.
sampler
.
set_epoch
(
epoch
)
pbar
=
enumerate
(
dataloader
)
pbar
=
enumerate
(
dataloader
)
if
local_
rank
in
[
-
1
,
0
]:
if
rank
in
[
-
1
,
0
]:
print
((
'
\n
'
+
'
%10
s'
*
8
)
%
(
'Epoch'
,
'gpu_mem'
,
'GIoU'
,
'obj'
,
'cls'
,
'total'
,
'targets'
,
'img_size'
))
print
((
'
\n
'
+
'
%10
s'
*
8
)
%
(
'Epoch'
,
'gpu_mem'
,
'GIoU'
,
'obj'
,
'cls'
,
'total'
,
'targets'
,
'img_size'
))
pbar
=
tqdm
(
pbar
,
total
=
nb
)
# progress bar
pbar
=
tqdm
(
pbar
,
total
=
nb
)
# progress bar
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
...
@@ -293,10 +284,9 @@ def train(hyp, tb_writer, opt, device):
...
@@ -293,10 +284,9 @@ def train(hyp, tb_writer, opt, device):
pred
=
model
(
imgs
)
pred
=
model
(
imgs
)
# Loss
# Loss
loss
,
loss_items
=
compute_loss
(
pred
,
targets
.
to
(
device
),
model
)
loss
,
loss_items
=
compute_loss
(
pred
,
targets
.
to
(
device
),
model
)
# scaled by batch_size
# loss is scaled with batch size in func compute_loss. But in DDP mode, gradient is averaged between devices.
if
rank
!=
-
1
:
if
local_rank
!=
-
1
:
loss
*=
opt
.
world_size
# gradient averaged between devices in DDP mode
loss
*=
opt
.
world_size
if
not
torch
.
isfinite
(
loss
):
if
not
torch
.
isfinite
(
loss
):
print
(
'WARNING: non-finite loss, ending training '
,
loss_items
)
print
(
'WARNING: non-finite loss, ending training '
,
loss_items
)
return
results
return
results
...
@@ -316,7 +306,7 @@ def train(hyp, tb_writer, opt, device):
...
@@ -316,7 +306,7 @@ def train(hyp, tb_writer, opt, device):
ema
.
update
(
model
)
ema
.
update
(
model
)
# Print
# Print
if
local_
rank
in
[
-
1
,
0
]:
if
rank
in
[
-
1
,
0
]:
mloss
=
(
mloss
*
i
+
loss_items
)
/
(
i
+
1
)
# update mean losses
mloss
=
(
mloss
*
i
+
loss_items
)
/
(
i
+
1
)
# update mean losses
mem
=
'
%.3
gG'
%
(
torch
.
cuda
.
memory_cached
()
/
1E9
if
torch
.
cuda
.
is_available
()
else
0
)
# (GB)
mem
=
'
%.3
gG'
%
(
torch
.
cuda
.
memory_cached
()
/
1E9
if
torch
.
cuda
.
is_available
()
else
0
)
# (GB)
s
=
(
'
%10
s'
*
2
+
'
%10.4
g'
*
6
)
%
(
s
=
(
'
%10
s'
*
2
+
'
%10.4
g'
*
6
)
%
(
...
@@ -337,7 +327,7 @@ def train(hyp, tb_writer, opt, device):
...
@@ -337,7 +327,7 @@ def train(hyp, tb_writer, opt, device):
scheduler
.
step
()
scheduler
.
step
()
# Only the first process in DDP mode is allowed to log or save checkpoints.
# Only the first process in DDP mode is allowed to log or save checkpoints.
if
local_
rank
in
[
-
1
,
0
]:
if
rank
in
[
-
1
,
0
]:
# mAP
# mAP
if
ema
is
not
None
:
if
ema
is
not
None
:
ema
.
update_attr
(
model
,
include
=
[
'md'
,
'nc'
,
'hyp'
,
'gr'
,
'names'
,
'stride'
])
ema
.
update_attr
(
model
,
include
=
[
'md'
,
'nc'
,
'hyp'
,
'gr'
,
'names'
,
'stride'
])
...
@@ -351,17 +341,17 @@ def train(hyp, tb_writer, opt, device):
...
@@ -351,17 +341,17 @@ def train(hyp, tb_writer, opt, device):
single_cls
=
opt
.
single_cls
,
single_cls
=
opt
.
single_cls
,
dataloader
=
testloader
,
dataloader
=
testloader
,
save_dir
=
log_dir
)
save_dir
=
log_dir
)
# Explicitly keep the shape.
# Write
# Write
with
open
(
results_file
,
'a'
)
as
f
:
with
open
(
results_file
,
'a'
)
as
f
:
f
.
write
(
s
+
'
%10.4
g'
*
7
%
results
+
'
\n
'
)
# P, R, mAP, F1, test_losses=(GIoU, obj, cls)
f
.
write
(
s
+
'
%10.4
g'
*
7
%
results
+
'
\n
'
)
# P, R, mAP, F1, test_losses=(GIoU, obj, cls)
if
len
(
opt
.
name
)
and
opt
.
bucket
:
if
len
(
opt
.
name
)
and
opt
.
bucket
:
os
.
system
(
'gsutil cp
results.txt gs://
%
s/results/results
%
s.txt'
%
(
opt
.
bucket
,
opt
.
name
))
os
.
system
(
'gsutil cp
%
s gs://
%
s/results/results
%
s.txt'
%
(
results_file
,
opt
.
bucket
,
opt
.
name
))
# Tensorboard
# Tensorboard
if
tb_writer
:
if
tb_writer
:
tags
=
[
'train/giou_loss'
,
'train/obj_loss'
,
'train/cls_loss'
,
tags
=
[
'train/giou_loss'
,
'train/obj_loss'
,
'train/cls_loss'
,
'metrics/precision'
,
'metrics/recall'
,
'metrics/mAP_0.5'
,
'metrics/
F1
'
,
'metrics/precision'
,
'metrics/recall'
,
'metrics/mAP_0.5'
,
'metrics/
mAP_0.5:0.95
'
,
'val/giou_loss'
,
'val/obj_loss'
,
'val/cls_loss'
]
'val/giou_loss'
,
'val/obj_loss'
,
'val/cls_loss'
]
for
x
,
tag
in
zip
(
list
(
mloss
[:
-
1
])
+
list
(
results
),
tags
):
for
x
,
tag
in
zip
(
list
(
mloss
[:
-
1
])
+
list
(
results
),
tags
):
tb_writer
.
add_scalar
(
tag
,
x
,
epoch
)
tb_writer
.
add_scalar
(
tag
,
x
,
epoch
)
...
@@ -389,7 +379,7 @@ def train(hyp, tb_writer, opt, device):
...
@@ -389,7 +379,7 @@ def train(hyp, tb_writer, opt, device):
# end epoch ----------------------------------------------------------------------------------------------------
# end epoch ----------------------------------------------------------------------------------------------------
# end training
# end training
if
local_
rank
in
[
-
1
,
0
]:
if
rank
in
[
-
1
,
0
]:
# Strip optimizers
# Strip optimizers
n
=
(
'_'
if
len
(
opt
.
name
)
and
not
opt
.
name
.
isnumeric
()
else
''
)
+
opt
.
name
n
=
(
'_'
if
len
(
opt
.
name
)
and
not
opt
.
name
.
isnumeric
()
else
''
)
+
opt
.
name
fresults
,
flast
,
fbest
=
'results
%
s.txt'
%
n
,
wdir
+
'last
%
s.pt'
%
n
,
wdir
+
'best
%
s.pt'
%
n
fresults
,
flast
,
fbest
=
'results
%
s.txt'
%
n
,
wdir
+
'last
%
s.pt'
%
n
,
wdir
+
'best
%
s.pt'
%
n
...
@@ -401,10 +391,10 @@ def train(hyp, tb_writer, opt, device):
...
@@ -401,10 +391,10 @@ def train(hyp, tb_writer, opt, device):
os
.
system
(
'gsutil cp
%
s gs://
%
s/weights'
%
(
f2
,
opt
.
bucket
))
if
opt
.
bucket
and
ispt
else
None
# upload
os
.
system
(
'gsutil cp
%
s gs://
%
s/weights'
%
(
f2
,
opt
.
bucket
))
if
opt
.
bucket
and
ispt
else
None
# upload
# Finish
# Finish
if
not
opt
.
evolve
:
if
not
opt
.
evolve
:
plot_results
()
# save as results.png
plot_results
(
save_dir
=
log_dir
)
# save as results.png
print
(
'
%
g epochs completed in
%.3
f hours.
\n
'
%
(
epoch
-
start_epoch
+
1
,
(
time
.
time
()
-
t0
)
/
3600
))
print
(
'
%
g epochs completed in
%.3
f hours.
\n
'
%
(
epoch
-
start_epoch
+
1
,
(
time
.
time
()
-
t0
)
/
3600
))
dist
.
destroy_process_group
()
if
local_
rank
not
in
[
-
1
,
0
]
else
None
dist
.
destroy_process_group
()
if
rank
not
in
[
-
1
,
0
]
else
None
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
return
results
return
results
...
@@ -431,10 +421,8 @@ if __name__ == '__main__':
...
@@ -431,10 +421,8 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--device'
,
default
=
''
,
help
=
'cuda device, i.e. 0 or 0,1,2,3 or cpu'
)
parser
.
add_argument
(
'--device'
,
default
=
''
,
help
=
'cuda device, i.e. 0 or 0,1,2,3 or cpu'
)
parser
.
add_argument
(
'--multi-scale'
,
action
=
'store_true'
,
help
=
'vary img-size +/- 50
%%
'
)
parser
.
add_argument
(
'--multi-scale'
,
action
=
'store_true'
,
help
=
'vary img-size +/- 50
%%
'
)
parser
.
add_argument
(
'--single-cls'
,
action
=
'store_true'
,
help
=
'train as single-class dataset'
)
parser
.
add_argument
(
'--single-cls'
,
action
=
'store_true'
,
help
=
'train as single-class dataset'
)
parser
.
add_argument
(
"--sync-bn"
,
action
=
"store_true"
,
help
=
"Use sync-bn, only avaible in DDP mode."
)
parser
.
add_argument
(
'--sync-bn'
,
action
=
"store_true"
,
help
=
'use SyncBatchNorm, only available in DDP mode'
)
# Parameter For DDP.
parser
.
add_argument
(
'--local_rank'
,
type
=
int
,
default
=-
1
,
help
=
'DDP parameter, do not modify'
)
parser
.
add_argument
(
'--local_rank'
,
type
=
int
,
default
=-
1
,
help
=
"Extra parameter for DDP implementation. Don't use it manually."
)
opt
=
parser
.
parse_args
()
opt
=
parser
.
parse_args
()
last
=
get_latest_run
()
if
opt
.
resume
==
'get_last'
else
opt
.
resume
# resume from most recent run
last
=
get_latest_run
()
if
opt
.
resume
==
'get_last'
else
opt
.
resume
# resume from most recent run
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论