Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
Y
yolov5
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
Administrator
yolov5
Commits
ca9babb8
Unverified
提交
ca9babb8
authored
1月 15, 2021
作者:
Glenn Jocher
提交者:
GitHub
1月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add ComputeLoss() class (#1950)
上级
f4a78e1b
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
48 行增加
和
32 行删除
+48
-32
test.py
test.py
+4
-4
train.py
train.py
+5
-3
loss.py
utils/loss.py
+39
-25
没有找到文件。
test.py
浏览文件 @
ca9babb8
...
...
@@ -13,7 +13,6 @@ from models.experimental import attempt_load
from
utils.datasets
import
create_dataloader
from
utils.general
import
coco80_to_coco91_class
,
check_dataset
,
check_file
,
check_img_size
,
check_requirements
,
\
box_iou
,
non_max_suppression
,
scale_coords
,
xyxy2xywh
,
xywh2xyxy
,
set_logging
,
increment_path
,
colorstr
from
utils.loss
import
compute_loss
from
utils.metrics
import
ap_per_class
,
ConfusionMatrix
from
utils.plots
import
plot_images
,
output_to_target
,
plot_study_txt
from
utils.torch_utils
import
select_device
,
time_synchronized
...
...
@@ -36,7 +35,8 @@ def test(data,
save_hybrid
=
False
,
# for hybrid auto-labelling
save_conf
=
False
,
# save auto-label confidences
plots
=
True
,
log_imgs
=
0
):
# number of logged images
log_imgs
=
0
,
# number of logged images
compute_loss
=
None
):
# Initialize/load model and set device
training
=
model
is
not
None
...
...
@@ -111,8 +111,8 @@ def test(data,
t0
+=
time_synchronized
()
-
t
# Compute loss
if
training
:
loss
+=
compute_loss
([
x
.
float
()
for
x
in
train_out
],
targets
,
model
)[
1
][:
3
]
# box, obj, cls
if
compute_loss
:
loss
+=
compute_loss
([
x
.
float
()
for
x
in
train_out
],
targets
)[
1
][:
3
]
# box, obj, cls
# Run NMS
targets
[:,
2
:]
*=
torch
.
Tensor
([
width
,
height
,
width
,
height
])
.
to
(
device
)
# to pixels
...
...
train.py
浏览文件 @
ca9babb8
...
...
@@ -29,7 +29,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima
fitness
,
strip_optimizer
,
get_latest_run
,
check_dataset
,
check_file
,
check_git_status
,
check_img_size
,
\
check_requirements
,
print_mutation
,
set_logging
,
one_cycle
,
colorstr
from
utils.google_utils
import
attempt_download
from
utils.loss
import
compute_l
oss
from
utils.loss
import
ComputeL
oss
from
utils.plots
import
plot_images
,
plot_labels
,
plot_results
,
plot_evolution
from
utils.torch_utils
import
ModelEMA
,
select_device
,
intersect_dicts
,
torch_distributed_zero_first
...
...
@@ -227,6 +227,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
results
=
(
0
,
0
,
0
,
0
,
0
,
0
,
0
)
# P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
scheduler
.
last_epoch
=
start_epoch
-
1
# do not move
scaler
=
amp
.
GradScaler
(
enabled
=
cuda
)
compute_loss
=
ComputeLoss
(
model
)
# init loss class
logger
.
info
(
f
'Image sizes {imgsz} train, {imgsz_test} test
\n
'
f
'Using {dataloader.num_workers} dataloader workers
\n
'
f
'Logging results to {save_dir}
\n
'
...
...
@@ -286,7 +287,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# Forward
with
amp
.
autocast
(
enabled
=
cuda
):
pred
=
model
(
imgs
)
# forward
loss
,
loss_items
=
compute_loss
(
pred
,
targets
.
to
(
device
)
,
model
)
# loss scaled by batch_size
loss
,
loss_items
=
compute_loss
(
pred
,
targets
.
to
(
device
))
# loss scaled by batch_size
if
rank
!=
-
1
:
loss
*=
opt
.
world_size
# gradient averaged between devices in DDP mode
if
opt
.
quad
:
...
...
@@ -344,7 +345,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
dataloader
=
testloader
,
save_dir
=
save_dir
,
plots
=
plots
and
final_epoch
,
log_imgs
=
opt
.
log_imgs
if
wandb
else
0
)
log_imgs
=
opt
.
log_imgs
if
wandb
else
0
,
compute_loss
=
compute_loss
)
# Write
with
open
(
results_file
,
'a'
)
as
f
:
...
...
utils/loss.py
浏览文件 @
ca9babb8
...
...
@@ -85,34 +85,45 @@ class QFocalLoss(nn.Module):
return
loss
def
compute_loss
(
p
,
targets
,
model
):
# predictions, targets, model
device
=
targets
.
device
lcls
,
lbox
,
lobj
=
torch
.
zeros
(
1
,
device
=
device
),
torch
.
zeros
(
1
,
device
=
device
),
torch
.
zeros
(
1
,
device
=
device
)
tcls
,
tbox
,
indices
,
anchors
=
build_targets
(
p
,
targets
,
model
)
# targets
class
ComputeLoss
:
# Compute losses
def
__init__
(
self
,
model
,
autobalance
=
False
):
super
(
ComputeLoss
,
self
)
.
__init__
()
device
=
next
(
model
.
parameters
())
.
device
# get model device
h
=
model
.
hyp
# hyperparameters
# Define criteria
BCEcls
=
nn
.
BCEWithLogitsLoss
(
pos_weight
=
torch
.
tensor
([
h
[
'cls_pw'
]],
device
=
device
))
# weight=model.class_weights
)
BCEcls
=
nn
.
BCEWithLogitsLoss
(
pos_weight
=
torch
.
tensor
([
h
[
'cls_pw'
]],
device
=
device
)
)
BCEobj
=
nn
.
BCEWithLogitsLoss
(
pos_weight
=
torch
.
tensor
([
h
[
'obj_pw'
]],
device
=
device
))
# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
cp
,
cn
=
smooth_BCE
(
eps
=
0.0
)
self
.
cp
,
self
.
cn
=
smooth_BCE
(
eps
=
0.0
)
# Focal loss
g
=
h
[
'fl_gamma'
]
# focal loss gamma
if
g
>
0
:
BCEcls
,
BCEobj
=
FocalLoss
(
BCEcls
,
g
),
FocalLoss
(
BCEobj
,
g
)
det
=
model
.
module
.
model
[
-
1
]
if
is_parallel
(
model
)
else
model
.
model
[
-
1
]
# Detect() module
self
.
balance
=
{
3
:
[
3.67
,
1.0
,
0.43
],
4
:
[
3.78
,
1.0
,
0.39
,
0.22
],
5
:
[
3.88
,
1.0
,
0.37
,
0.17
,
0.10
]}[
det
.
nl
]
# self.balance = [1.0] * det.nl
self
.
ssi
=
(
det
.
stride
==
16
)
.
nonzero
(
as_tuple
=
False
)
.
item
()
# stride 16 index
self
.
BCEcls
,
self
.
BCEobj
,
self
.
gr
,
self
.
hyp
,
self
.
autobalance
=
BCEcls
,
BCEobj
,
model
.
gr
,
h
,
autobalance
for
k
in
'na'
,
'nc'
,
'nl'
,
'anchors'
:
setattr
(
self
,
k
,
getattr
(
det
,
k
))
def
__call__
(
self
,
p
,
targets
):
# predictions, targets, model
device
=
targets
.
device
lcls
,
lbox
,
lobj
=
torch
.
zeros
(
1
,
device
=
device
),
torch
.
zeros
(
1
,
device
=
device
),
torch
.
zeros
(
1
,
device
=
device
)
tcls
,
tbox
,
indices
,
anchors
=
self
.
build_targets
(
p
,
targets
)
# targets
# Losses
nt
=
0
# number of targets
balance
=
[
4.0
,
1.0
,
0.3
,
0.1
,
0.03
]
# P3-P7
for
i
,
pi
in
enumerate
(
p
):
# layer index, layer predictions
b
,
a
,
gj
,
gi
=
indices
[
i
]
# image, anchor, gridy, gridx
tobj
=
torch
.
zeros_like
(
pi
[
...
,
0
],
device
=
device
)
# target obj
n
=
b
.
shape
[
0
]
# number of targets
if
n
:
nt
+=
n
# cumulative targets
ps
=
pi
[
b
,
a
,
gj
,
gi
]
# prediction subset corresponding to targets
# Regression
...
...
@@ -123,33 +134,36 @@ def compute_loss(p, targets, model): # predictions, targets, model
lbox
+=
(
1.0
-
iou
)
.
mean
()
# iou loss
# Objectness
tobj
[
b
,
a
,
gj
,
gi
]
=
(
1.0
-
model
.
gr
)
+
model
.
gr
*
iou
.
detach
()
.
clamp
(
0
)
.
type
(
tobj
.
dtype
)
# iou ratio
tobj
[
b
,
a
,
gj
,
gi
]
=
(
1.0
-
self
.
gr
)
+
self
.
gr
*
iou
.
detach
()
.
clamp
(
0
)
.
type
(
tobj
.
dtype
)
# iou ratio
# Classification
if
model
.
nc
>
1
:
# cls loss (only if multiple classes)
t
=
torch
.
full_like
(
ps
[:,
5
:],
cn
,
device
=
device
)
# targets
t
[
range
(
n
),
tcls
[
i
]]
=
cp
lcls
+=
BCEcls
(
ps
[:,
5
:],
t
)
# BCE
if
self
.
nc
>
1
:
# cls loss (only if multiple classes)
t
=
torch
.
full_like
(
ps
[:,
5
:],
self
.
cn
,
device
=
device
)
# targets
t
[
range
(
n
),
tcls
[
i
]]
=
self
.
cp
lcls
+=
self
.
BCEcls
(
ps
[:,
5
:],
t
)
# BCE
# Append targets to text file
# with open('targets.txt', 'a') as file:
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
lobj
+=
BCEobj
(
pi
[
...
,
4
],
tobj
)
*
balance
[
i
]
# obj loss
obji
=
self
.
BCEobj
(
pi
[
...
,
4
],
tobj
)
lobj
+=
obji
*
self
.
balance
[
i
]
# obj loss
if
self
.
autobalance
:
self
.
balance
[
i
]
=
self
.
balance
[
i
]
*
0.9999
+
0.0001
/
obji
.
detach
()
.
item
()
lbox
*=
h
[
'box'
]
lobj
*=
h
[
'obj'
]
lcls
*=
h
[
'cls'
]
if
self
.
autobalance
:
self
.
balance
=
[
x
/
self
.
balance
[
self
.
ssi
]
for
x
in
self
.
balance
]
lbox
*=
self
.
hyp
[
'box'
]
lobj
*=
self
.
hyp
[
'obj'
]
lcls
*=
self
.
hyp
[
'cls'
]
bs
=
tobj
.
shape
[
0
]
# batch size
loss
=
lbox
+
lobj
+
lcls
return
loss
*
bs
,
torch
.
cat
((
lbox
,
lobj
,
lcls
,
loss
))
.
detach
()
def
build_targets
(
p
,
targets
,
model
):
def
build_targets
(
self
,
p
,
targets
):
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
det
=
model
.
module
.
model
[
-
1
]
if
is_parallel
(
model
)
else
model
.
model
[
-
1
]
# Detect() module
na
,
nt
=
det
.
na
,
targets
.
shape
[
0
]
# number of anchors, targets
na
,
nt
=
self
.
na
,
targets
.
shape
[
0
]
# number of anchors, targets
tcls
,
tbox
,
indices
,
anch
=
[],
[],
[],
[]
gain
=
torch
.
ones
(
7
,
device
=
targets
.
device
)
# normalized to gridspace gain
ai
=
torch
.
arange
(
na
,
device
=
targets
.
device
)
.
float
()
.
view
(
na
,
1
)
.
repeat
(
1
,
nt
)
# same as .repeat_interleave(nt)
...
...
@@ -161,8 +175,8 @@ def build_targets(p, targets, model):
# [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
],
device
=
targets
.
device
)
.
float
()
*
g
# offsets
for
i
in
range
(
det
.
nl
):
anchors
=
det
.
anchors
[
i
]
for
i
in
range
(
self
.
nl
):
anchors
=
self
.
anchors
[
i
]
gain
[
2
:
6
]
=
torch
.
tensor
(
p
[
i
]
.
shape
)[[
3
,
2
,
3
,
2
]]
# xyxy gain
# Match targets to anchors
...
...
@@ -170,7 +184,7 @@ def build_targets(p, targets, model):
if
nt
:
# Matches
r
=
t
[:,
:,
4
:
6
]
/
anchors
[:,
None
]
# wh ratio
j
=
torch
.
max
(
r
,
1.
/
r
)
.
max
(
2
)[
0
]
<
model
.
hyp
[
'anchor_t'
]
# compare
j
=
torch
.
max
(
r
,
1.
/
r
)
.
max
(
2
)[
0
]
<
self
.
hyp
[
'anchor_t'
]
# compare
# j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
t
=
t
[
j
]
# filter
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论