Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
Y
yolov5
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
Administrator
yolov5
Commits
ca9babb8
Unverified
提交
ca9babb8
authored
1月 15, 2021
作者:
Glenn Jocher
提交者:
GitHub
1月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add ComputeLoss() class (#1950)
上级
f4a78e1b
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
138 行增加
和
122 行删除
+138
-122
test.py
test.py
+4
-4
train.py
train.py
+5
-3
loss.py
utils/loss.py
+129
-115
没有找到文件。
test.py
浏览文件 @
ca9babb8
...
@@ -13,7 +13,6 @@ from models.experimental import attempt_load
...
@@ -13,7 +13,6 @@ from models.experimental import attempt_load
from
utils.datasets
import
create_dataloader
from
utils.datasets
import
create_dataloader
from
utils.general
import
coco80_to_coco91_class
,
check_dataset
,
check_file
,
check_img_size
,
check_requirements
,
\
from
utils.general
import
coco80_to_coco91_class
,
check_dataset
,
check_file
,
check_img_size
,
check_requirements
,
\
box_iou
,
non_max_suppression
,
scale_coords
,
xyxy2xywh
,
xywh2xyxy
,
set_logging
,
increment_path
,
colorstr
box_iou
,
non_max_suppression
,
scale_coords
,
xyxy2xywh
,
xywh2xyxy
,
set_logging
,
increment_path
,
colorstr
from
utils.loss
import
compute_loss
from
utils.metrics
import
ap_per_class
,
ConfusionMatrix
from
utils.metrics
import
ap_per_class
,
ConfusionMatrix
from
utils.plots
import
plot_images
,
output_to_target
,
plot_study_txt
from
utils.plots
import
plot_images
,
output_to_target
,
plot_study_txt
from
utils.torch_utils
import
select_device
,
time_synchronized
from
utils.torch_utils
import
select_device
,
time_synchronized
...
@@ -36,7 +35,8 @@ def test(data,
...
@@ -36,7 +35,8 @@ def test(data,
save_hybrid
=
False
,
# for hybrid auto-labelling
save_hybrid
=
False
,
# for hybrid auto-labelling
save_conf
=
False
,
# save auto-label confidences
save_conf
=
False
,
# save auto-label confidences
plots
=
True
,
plots
=
True
,
log_imgs
=
0
):
# number of logged images
log_imgs
=
0
,
# number of logged images
compute_loss
=
None
):
# Initialize/load model and set device
# Initialize/load model and set device
training
=
model
is
not
None
training
=
model
is
not
None
...
@@ -111,8 +111,8 @@ def test(data,
...
@@ -111,8 +111,8 @@ def test(data,
t0
+=
time_synchronized
()
-
t
t0
+=
time_synchronized
()
-
t
# Compute loss
# Compute loss
if
training
:
if
compute_loss
:
loss
+=
compute_loss
([
x
.
float
()
for
x
in
train_out
],
targets
,
model
)[
1
][:
3
]
# box, obj, cls
loss
+=
compute_loss
([
x
.
float
()
for
x
in
train_out
],
targets
)[
1
][:
3
]
# box, obj, cls
# Run NMS
# Run NMS
targets
[:,
2
:]
*=
torch
.
Tensor
([
width
,
height
,
width
,
height
])
.
to
(
device
)
# to pixels
targets
[:,
2
:]
*=
torch
.
Tensor
([
width
,
height
,
width
,
height
])
.
to
(
device
)
# to pixels
...
...
train.py
浏览文件 @
ca9babb8
...
@@ -29,7 +29,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima
...
@@ -29,7 +29,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima
fitness
,
strip_optimizer
,
get_latest_run
,
check_dataset
,
check_file
,
check_git_status
,
check_img_size
,
\
fitness
,
strip_optimizer
,
get_latest_run
,
check_dataset
,
check_file
,
check_git_status
,
check_img_size
,
\
check_requirements
,
print_mutation
,
set_logging
,
one_cycle
,
colorstr
check_requirements
,
print_mutation
,
set_logging
,
one_cycle
,
colorstr
from
utils.google_utils
import
attempt_download
from
utils.google_utils
import
attempt_download
from
utils.loss
import
compute_l
oss
from
utils.loss
import
ComputeL
oss
from
utils.plots
import
plot_images
,
plot_labels
,
plot_results
,
plot_evolution
from
utils.plots
import
plot_images
,
plot_labels
,
plot_results
,
plot_evolution
from
utils.torch_utils
import
ModelEMA
,
select_device
,
intersect_dicts
,
torch_distributed_zero_first
from
utils.torch_utils
import
ModelEMA
,
select_device
,
intersect_dicts
,
torch_distributed_zero_first
...
@@ -227,6 +227,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
...
@@ -227,6 +227,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
results
=
(
0
,
0
,
0
,
0
,
0
,
0
,
0
)
# P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
results
=
(
0
,
0
,
0
,
0
,
0
,
0
,
0
)
# P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
scheduler
.
last_epoch
=
start_epoch
-
1
# do not move
scheduler
.
last_epoch
=
start_epoch
-
1
# do not move
scaler
=
amp
.
GradScaler
(
enabled
=
cuda
)
scaler
=
amp
.
GradScaler
(
enabled
=
cuda
)
compute_loss
=
ComputeLoss
(
model
)
# init loss class
logger
.
info
(
f
'Image sizes {imgsz} train, {imgsz_test} test
\n
'
logger
.
info
(
f
'Image sizes {imgsz} train, {imgsz_test} test
\n
'
f
'Using {dataloader.num_workers} dataloader workers
\n
'
f
'Using {dataloader.num_workers} dataloader workers
\n
'
f
'Logging results to {save_dir}
\n
'
f
'Logging results to {save_dir}
\n
'
...
@@ -286,7 +287,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
...
@@ -286,7 +287,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# Forward
# Forward
with
amp
.
autocast
(
enabled
=
cuda
):
with
amp
.
autocast
(
enabled
=
cuda
):
pred
=
model
(
imgs
)
# forward
pred
=
model
(
imgs
)
# forward
loss
,
loss_items
=
compute_loss
(
pred
,
targets
.
to
(
device
)
,
model
)
# loss scaled by batch_size
loss
,
loss_items
=
compute_loss
(
pred
,
targets
.
to
(
device
))
# loss scaled by batch_size
if
rank
!=
-
1
:
if
rank
!=
-
1
:
loss
*=
opt
.
world_size
# gradient averaged between devices in DDP mode
loss
*=
opt
.
world_size
# gradient averaged between devices in DDP mode
if
opt
.
quad
:
if
opt
.
quad
:
...
@@ -344,7 +345,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
...
@@ -344,7 +345,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
dataloader
=
testloader
,
dataloader
=
testloader
,
save_dir
=
save_dir
,
save_dir
=
save_dir
,
plots
=
plots
and
final_epoch
,
plots
=
plots
and
final_epoch
,
log_imgs
=
opt
.
log_imgs
if
wandb
else
0
)
log_imgs
=
opt
.
log_imgs
if
wandb
else
0
,
compute_loss
=
compute_loss
)
# Write
# Write
with
open
(
results_file
,
'a'
)
as
f
:
with
open
(
results_file
,
'a'
)
as
f
:
...
...
utils/loss.py
浏览文件 @
ca9babb8
...
@@ -85,119 +85,133 @@ class QFocalLoss(nn.Module):
...
@@ -85,119 +85,133 @@ class QFocalLoss(nn.Module):
return
loss
return
loss
def
compute_loss
(
p
,
targets
,
model
):
# predictions, targets, model
class
ComputeLoss
:
device
=
targets
.
device
# Compute losses
lcls
,
lbox
,
lobj
=
torch
.
zeros
(
1
,
device
=
device
),
torch
.
zeros
(
1
,
device
=
device
),
torch
.
zeros
(
1
,
device
=
device
)
def
__init__
(
self
,
model
,
autobalance
=
False
):
tcls
,
tbox
,
indices
,
anchors
=
build_targets
(
p
,
targets
,
model
)
# targets
super
(
ComputeLoss
,
self
)
.
__init__
()
h
=
model
.
hyp
# hyperparameters
device
=
next
(
model
.
parameters
())
.
device
# get model device
h
=
model
.
hyp
# hyperparameters
# Define criteria
BCEcls
=
nn
.
BCEWithLogitsLoss
(
pos_weight
=
torch
.
tensor
([
h
[
'cls_pw'
]],
device
=
device
))
# weight=model.class_weights)
# Define criteria
BCEobj
=
nn
.
BCEWithLogitsLoss
(
pos_weight
=
torch
.
tensor
([
h
[
'obj_pw'
]],
device
=
device
))
BCEcls
=
nn
.
BCEWithLogitsLoss
(
pos_weight
=
torch
.
tensor
([
h
[
'cls_pw'
]],
device
=
device
))
BCEobj
=
nn
.
BCEWithLogitsLoss
(
pos_weight
=
torch
.
tensor
([
h
[
'obj_pw'
]],
device
=
device
))
# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
cp
,
cn
=
smooth_BCE
(
eps
=
0.0
)
# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
self
.
cp
,
self
.
cn
=
smooth_BCE
(
eps
=
0.0
)
# Focal loss
g
=
h
[
'fl_gamma'
]
# focal loss gamma
# Focal loss
if
g
>
0
:
g
=
h
[
'fl_gamma'
]
# focal loss gamma
BCEcls
,
BCEobj
=
FocalLoss
(
BCEcls
,
g
),
FocalLoss
(
BCEobj
,
g
)
if
g
>
0
:
BCEcls
,
BCEobj
=
FocalLoss
(
BCEcls
,
g
),
FocalLoss
(
BCEobj
,
g
)
# Losses
nt
=
0
# number of targets
det
=
model
.
module
.
model
[
-
1
]
if
is_parallel
(
model
)
else
model
.
model
[
-
1
]
# Detect() module
balance
=
[
4.0
,
1.0
,
0.3
,
0.1
,
0.03
]
# P3-P7
self
.
balance
=
{
3
:
[
3.67
,
1.0
,
0.43
],
4
:
[
3.78
,
1.0
,
0.39
,
0.22
],
5
:
[
3.88
,
1.0
,
0.37
,
0.17
,
0.10
]}[
det
.
nl
]
for
i
,
pi
in
enumerate
(
p
):
# layer index, layer predictions
# self.balance = [1.0] * det.nl
b
,
a
,
gj
,
gi
=
indices
[
i
]
# image, anchor, gridy, gridx
self
.
ssi
=
(
det
.
stride
==
16
)
.
nonzero
(
as_tuple
=
False
)
.
item
()
# stride 16 index
tobj
=
torch
.
zeros_like
(
pi
[
...
,
0
],
device
=
device
)
# target obj
self
.
BCEcls
,
self
.
BCEobj
,
self
.
gr
,
self
.
hyp
,
self
.
autobalance
=
BCEcls
,
BCEobj
,
model
.
gr
,
h
,
autobalance
for
k
in
'na'
,
'nc'
,
'nl'
,
'anchors'
:
n
=
b
.
shape
[
0
]
# number of targets
setattr
(
self
,
k
,
getattr
(
det
,
k
))
if
n
:
nt
+=
n
# cumulative targets
def
__call__
(
self
,
p
,
targets
):
# predictions, targets, model
ps
=
pi
[
b
,
a
,
gj
,
gi
]
# prediction subset corresponding to targets
device
=
targets
.
device
lcls
,
lbox
,
lobj
=
torch
.
zeros
(
1
,
device
=
device
),
torch
.
zeros
(
1
,
device
=
device
),
torch
.
zeros
(
1
,
device
=
device
)
# Regression
tcls
,
tbox
,
indices
,
anchors
=
self
.
build_targets
(
p
,
targets
)
# targets
pxy
=
ps
[:,
:
2
]
.
sigmoid
()
*
2.
-
0.5
pwh
=
(
ps
[:,
2
:
4
]
.
sigmoid
()
*
2
)
**
2
*
anchors
[
i
]
# Losses
pbox
=
torch
.
cat
((
pxy
,
pwh
),
1
)
# predicted box
for
i
,
pi
in
enumerate
(
p
):
# layer index, layer predictions
iou
=
bbox_iou
(
pbox
.
T
,
tbox
[
i
],
x1y1x2y2
=
False
,
CIoU
=
True
)
# iou(prediction, target)
b
,
a
,
gj
,
gi
=
indices
[
i
]
# image, anchor, gridy, gridx
lbox
+=
(
1.0
-
iou
)
.
mean
()
# iou loss
tobj
=
torch
.
zeros_like
(
pi
[
...
,
0
],
device
=
device
)
# target obj
# Objectness
n
=
b
.
shape
[
0
]
# number of targets
tobj
[
b
,
a
,
gj
,
gi
]
=
(
1.0
-
model
.
gr
)
+
model
.
gr
*
iou
.
detach
()
.
clamp
(
0
)
.
type
(
tobj
.
dtype
)
# iou ratio
if
n
:
ps
=
pi
[
b
,
a
,
gj
,
gi
]
# prediction subset corresponding to targets
# Classification
if
model
.
nc
>
1
:
# cls loss (only if multiple classes)
# Regression
t
=
torch
.
full_like
(
ps
[:,
5
:],
cn
,
device
=
device
)
# targets
pxy
=
ps
[:,
:
2
]
.
sigmoid
()
*
2.
-
0.5
t
[
range
(
n
),
tcls
[
i
]]
=
cp
pwh
=
(
ps
[:,
2
:
4
]
.
sigmoid
()
*
2
)
**
2
*
anchors
[
i
]
lcls
+=
BCEcls
(
ps
[:,
5
:],
t
)
# BCE
pbox
=
torch
.
cat
((
pxy
,
pwh
),
1
)
# predicted box
iou
=
bbox_iou
(
pbox
.
T
,
tbox
[
i
],
x1y1x2y2
=
False
,
CIoU
=
True
)
# iou(prediction, target)
# Append targets to text file
lbox
+=
(
1.0
-
iou
)
.
mean
()
# iou loss
# with open('targets.txt', 'a') as file:
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
# Objectness
tobj
[
b
,
a
,
gj
,
gi
]
=
(
1.0
-
self
.
gr
)
+
self
.
gr
*
iou
.
detach
()
.
clamp
(
0
)
.
type
(
tobj
.
dtype
)
# iou ratio
lobj
+=
BCEobj
(
pi
[
...
,
4
],
tobj
)
*
balance
[
i
]
# obj loss
# Classification
lbox
*=
h
[
'box'
]
if
self
.
nc
>
1
:
# cls loss (only if multiple classes)
lobj
*=
h
[
'obj'
]
t
=
torch
.
full_like
(
ps
[:,
5
:],
self
.
cn
,
device
=
device
)
# targets
lcls
*=
h
[
'cls'
]
t
[
range
(
n
),
tcls
[
i
]]
=
self
.
cp
bs
=
tobj
.
shape
[
0
]
# batch size
lcls
+=
self
.
BCEcls
(
ps
[:,
5
:],
t
)
# BCE
loss
=
lbox
+
lobj
+
lcls
# Append targets to text file
return
loss
*
bs
,
torch
.
cat
((
lbox
,
lobj
,
lcls
,
loss
))
.
detach
()
# with open('targets.txt', 'a') as file:
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
def
build_targets
(
p
,
targets
,
model
):
obji
=
self
.
BCEobj
(
pi
[
...
,
4
],
tobj
)
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
lobj
+=
obji
*
self
.
balance
[
i
]
# obj loss
det
=
model
.
module
.
model
[
-
1
]
if
is_parallel
(
model
)
else
model
.
model
[
-
1
]
# Detect() module
if
self
.
autobalance
:
na
,
nt
=
det
.
na
,
targets
.
shape
[
0
]
# number of anchors, targets
self
.
balance
[
i
]
=
self
.
balance
[
i
]
*
0.9999
+
0.0001
/
obji
.
detach
()
.
item
()
tcls
,
tbox
,
indices
,
anch
=
[],
[],
[],
[]
gain
=
torch
.
ones
(
7
,
device
=
targets
.
device
)
# normalized to gridspace gain
if
self
.
autobalance
:
ai
=
torch
.
arange
(
na
,
device
=
targets
.
device
)
.
float
()
.
view
(
na
,
1
)
.
repeat
(
1
,
nt
)
# same as .repeat_interleave(nt)
self
.
balance
=
[
x
/
self
.
balance
[
self
.
ssi
]
for
x
in
self
.
balance
]
targets
=
torch
.
cat
((
targets
.
repeat
(
na
,
1
,
1
),
ai
[:,
:,
None
]),
2
)
# append anchor indices
lbox
*=
self
.
hyp
[
'box'
]
lobj
*=
self
.
hyp
[
'obj'
]
g
=
0.5
# bias
lcls
*=
self
.
hyp
[
'cls'
]
off
=
torch
.
tensor
([[
0
,
0
],
bs
=
tobj
.
shape
[
0
]
# batch size
[
1
,
0
],
[
0
,
1
],
[
-
1
,
0
],
[
0
,
-
1
],
# j,k,l,m
# [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
loss
=
lbox
+
lobj
+
lcls
],
device
=
targets
.
device
)
.
float
()
*
g
# offsets
return
loss
*
bs
,
torch
.
cat
((
lbox
,
lobj
,
lcls
,
loss
))
.
detach
()
for
i
in
range
(
det
.
nl
):
def
build_targets
(
self
,
p
,
targets
):
anchors
=
det
.
anchors
[
i
]
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
gain
[
2
:
6
]
=
torch
.
tensor
(
p
[
i
]
.
shape
)[[
3
,
2
,
3
,
2
]]
# xyxy gain
na
,
nt
=
self
.
na
,
targets
.
shape
[
0
]
# number of anchors, targets
tcls
,
tbox
,
indices
,
anch
=
[],
[],
[],
[]
# Match targets to anchors
gain
=
torch
.
ones
(
7
,
device
=
targets
.
device
)
# normalized to gridspace gain
t
=
targets
*
gain
ai
=
torch
.
arange
(
na
,
device
=
targets
.
device
)
.
float
()
.
view
(
na
,
1
)
.
repeat
(
1
,
nt
)
# same as .repeat_interleave(nt)
if
nt
:
targets
=
torch
.
cat
((
targets
.
repeat
(
na
,
1
,
1
),
ai
[:,
:,
None
]),
2
)
# append anchor indices
# Matches
r
=
t
[:,
:,
4
:
6
]
/
anchors
[:,
None
]
# wh ratio
g
=
0.5
# bias
j
=
torch
.
max
(
r
,
1.
/
r
)
.
max
(
2
)[
0
]
<
model
.
hyp
[
'anchor_t'
]
# compare
off
=
torch
.
tensor
([[
0
,
0
],
# j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
[
1
,
0
],
[
0
,
1
],
[
-
1
,
0
],
[
0
,
-
1
],
# j,k,l,m
t
=
t
[
j
]
# filter
# [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
],
device
=
targets
.
device
)
.
float
()
*
g
# offsets
# Offsets
for
i
in
range
(
self
.
nl
):
anchors
=
self
.
anchors
[
i
]
gain
[
2
:
6
]
=
torch
.
tensor
(
p
[
i
]
.
shape
)[[
3
,
2
,
3
,
2
]]
# xyxy gain
# Match targets to anchors
t
=
targets
*
gain
if
nt
:
# Matches
r
=
t
[:,
:,
4
:
6
]
/
anchors
[:,
None
]
# wh ratio
j
=
torch
.
max
(
r
,
1.
/
r
)
.
max
(
2
)[
0
]
<
self
.
hyp
[
'anchor_t'
]
# compare
# j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
t
=
t
[
j
]
# filter
# Offsets
gxy
=
t
[:,
2
:
4
]
# grid xy
gxi
=
gain
[[
2
,
3
]]
-
gxy
# inverse
j
,
k
=
((
gxy
%
1.
<
g
)
&
(
gxy
>
1.
))
.
T
l
,
m
=
((
gxi
%
1.
<
g
)
&
(
gxi
>
1.
))
.
T
j
=
torch
.
stack
((
torch
.
ones_like
(
j
),
j
,
k
,
l
,
m
))
t
=
t
.
repeat
((
5
,
1
,
1
))[
j
]
offsets
=
(
torch
.
zeros_like
(
gxy
)[
None
]
+
off
[:,
None
])[
j
]
else
:
t
=
targets
[
0
]
offsets
=
0
# Define
b
,
c
=
t
[:,
:
2
]
.
long
()
.
T
# image, class
gxy
=
t
[:,
2
:
4
]
# grid xy
gxy
=
t
[:,
2
:
4
]
# grid xy
gxi
=
gain
[[
2
,
3
]]
-
gxy
# inverse
gwh
=
t
[:,
4
:
6
]
# grid wh
j
,
k
=
((
gxy
%
1.
<
g
)
&
(
gxy
>
1.
))
.
T
gij
=
(
gxy
-
offsets
)
.
long
()
l
,
m
=
((
gxi
%
1.
<
g
)
&
(
gxi
>
1.
))
.
T
gi
,
gj
=
gij
.
T
# grid xy indices
j
=
torch
.
stack
((
torch
.
ones_like
(
j
),
j
,
k
,
l
,
m
))
t
=
t
.
repeat
((
5
,
1
,
1
))[
j
]
# Append
offsets
=
(
torch
.
zeros_like
(
gxy
)[
None
]
+
off
[:,
None
])[
j
]
a
=
t
[:,
6
]
.
long
()
# anchor indices
else
:
indices
.
append
((
b
,
a
,
gj
.
clamp_
(
0
,
gain
[
3
]
-
1
),
gi
.
clamp_
(
0
,
gain
[
2
]
-
1
)))
# image, anchor, grid indices
t
=
targets
[
0
]
tbox
.
append
(
torch
.
cat
((
gxy
-
gij
,
gwh
),
1
))
# box
offsets
=
0
anch
.
append
(
anchors
[
a
])
# anchors
tcls
.
append
(
c
)
# class
# Define
b
,
c
=
t
[:,
:
2
]
.
long
()
.
T
# image, class
return
tcls
,
tbox
,
indices
,
anch
gxy
=
t
[:,
2
:
4
]
# grid xy
gwh
=
t
[:,
4
:
6
]
# grid wh
gij
=
(
gxy
-
offsets
)
.
long
()
gi
,
gj
=
gij
.
T
# grid xy indices
# Append
a
=
t
[:,
6
]
.
long
()
# anchor indices
indices
.
append
((
b
,
a
,
gj
.
clamp_
(
0
,
gain
[
3
]
-
1
),
gi
.
clamp_
(
0
,
gain
[
2
]
-
1
)))
# image, anchor, grid indices
tbox
.
append
(
torch
.
cat
((
gxy
-
gij
,
gwh
),
1
))
# box
anch
.
append
(
anchors
[
a
])
# anchors
tcls
.
append
(
c
)
# class
return
tcls
,
tbox
,
indices
,
anch
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论