Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
Y
yolov5
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
Administrator
yolov5
Commits
dd03b20b
Unverified
提交
dd03b20b
authored
1月 12, 2021
作者:
Glenn Jocher
提交者:
GitHub
1月 12, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
colorstr() updates (#1909)
* W&B ImportError message fix * colorstr() updates * colorstr() updates * colorstr() default to 'blue', 'bold' * train: magenta * train: blue
上级
1d1c0567
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
46 行增加
和
45 行删除
+46
-45
test.py
test.py
+3
-2
train.py
train.py
+11
-10
autoanchor.py
utils/autoanchor.py
+2
-2
datasets.py
utils/datasets.py
+27
-27
general.py
utils/general.py
+2
-3
plots.py
utils/plots.py
+1
-1
没有找到文件。
test.py
浏览文件 @
dd03b20b
...
@@ -12,7 +12,7 @@ from tqdm import tqdm
...
@@ -12,7 +12,7 @@ from tqdm import tqdm
from
models.experimental
import
attempt_load
from
models.experimental
import
attempt_load
from
utils.datasets
import
create_dataloader
from
utils.datasets
import
create_dataloader
from
utils.general
import
coco80_to_coco91_class
,
check_dataset
,
check_file
,
check_img_size
,
check_requirements
,
\
from
utils.general
import
coco80_to_coco91_class
,
check_dataset
,
check_file
,
check_img_size
,
check_requirements
,
\
box_iou
,
non_max_suppression
,
scale_coords
,
xyxy2xywh
,
xywh2xyxy
,
set_logging
,
increment_path
box_iou
,
non_max_suppression
,
scale_coords
,
xyxy2xywh
,
xywh2xyxy
,
set_logging
,
increment_path
,
colorstr
from
utils.loss
import
compute_loss
from
utils.loss
import
compute_loss
from
utils.metrics
import
ap_per_class
,
ConfusionMatrix
from
utils.metrics
import
ap_per_class
,
ConfusionMatrix
from
utils.plots
import
plot_images
,
output_to_target
,
plot_study_txt
from
utils.plots
import
plot_images
,
output_to_target
,
plot_study_txt
...
@@ -86,7 +86,8 @@ def test(data,
...
@@ -86,7 +86,8 @@ def test(data,
img
=
torch
.
zeros
((
1
,
3
,
imgsz
,
imgsz
),
device
=
device
)
# init img
img
=
torch
.
zeros
((
1
,
3
,
imgsz
,
imgsz
),
device
=
device
)
# init img
_
=
model
(
img
.
half
()
if
half
else
img
)
if
device
.
type
!=
'cpu'
else
None
# run once
_
=
model
(
img
.
half
()
if
half
else
img
)
if
device
.
type
!=
'cpu'
else
None
# run once
path
=
data
[
'test'
]
if
opt
.
task
==
'test'
else
data
[
'val'
]
# path to val/test images
path
=
data
[
'test'
]
if
opt
.
task
==
'test'
else
data
[
'val'
]
# path to val/test images
dataloader
=
create_dataloader
(
path
,
imgsz
,
batch_size
,
model
.
stride
.
max
(),
opt
,
pad
=
0.5
,
rect
=
True
)[
0
]
dataloader
=
create_dataloader
(
path
,
imgsz
,
batch_size
,
model
.
stride
.
max
(),
opt
,
pad
=
0.5
,
rect
=
True
,
prefix
=
colorstr
(
'test: '
if
opt
.
task
==
'test'
else
'val: '
))[
0
]
seen
=
0
seen
=
0
confusion_matrix
=
ConfusionMatrix
(
nc
=
nc
)
confusion_matrix
=
ConfusionMatrix
(
nc
=
nc
)
...
...
train.py
浏览文件 @
dd03b20b
...
@@ -36,15 +36,9 @@ from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_di
...
@@ -36,15 +36,9 @@ from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_di
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
try
:
import
wandb
except
ImportError
:
wandb
=
None
logger
.
info
(
"Install Weights & Biases for experiment logging via 'pip install wandb' (recommended)"
)
def
train
(
hyp
,
opt
,
device
,
tb_writer
=
None
,
wandb
=
None
):
def
train
(
hyp
,
opt
,
device
,
tb_writer
=
None
,
wandb
=
None
):
logger
.
info
(
colorstr
(
'
blue'
,
'bold'
,
'
Hyperparameters: '
)
+
', '
.
join
(
f
'{k}={v}'
for
k
,
v
in
hyp
.
items
()))
logger
.
info
(
colorstr
(
'Hyperparameters: '
)
+
', '
.
join
(
f
'{k}={v}'
for
k
,
v
in
hyp
.
items
()))
save_dir
,
epochs
,
batch_size
,
total_batch_size
,
weights
,
rank
=
\
save_dir
,
epochs
,
batch_size
,
total_batch_size
,
weights
,
rank
=
\
Path
(
opt
.
save_dir
),
opt
.
epochs
,
opt
.
batch_size
,
opt
.
total_batch_size
,
opt
.
weights
,
opt
.
global_rank
Path
(
opt
.
save_dir
),
opt
.
epochs
,
opt
.
batch_size
,
opt
.
total_batch_size
,
opt
.
weights
,
opt
.
global_rank
...
@@ -189,7 +183,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
...
@@ -189,7 +183,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
dataloader
,
dataset
=
create_dataloader
(
train_path
,
imgsz
,
batch_size
,
gs
,
opt
,
dataloader
,
dataset
=
create_dataloader
(
train_path
,
imgsz
,
batch_size
,
gs
,
opt
,
hyp
=
hyp
,
augment
=
True
,
cache
=
opt
.
cache_images
,
rect
=
opt
.
rect
,
rank
=
rank
,
hyp
=
hyp
,
augment
=
True
,
cache
=
opt
.
cache_images
,
rect
=
opt
.
rect
,
rank
=
rank
,
world_size
=
opt
.
world_size
,
workers
=
opt
.
workers
,
world_size
=
opt
.
world_size
,
workers
=
opt
.
workers
,
image_weights
=
opt
.
image_weights
,
quad
=
opt
.
quad
)
image_weights
=
opt
.
image_weights
,
quad
=
opt
.
quad
,
prefix
=
colorstr
(
'train: '
)
)
mlc
=
np
.
concatenate
(
dataset
.
labels
,
0
)[:,
0
]
.
max
()
# max label class
mlc
=
np
.
concatenate
(
dataset
.
labels
,
0
)[:,
0
]
.
max
()
# max label class
nb
=
len
(
dataloader
)
# number of batches
nb
=
len
(
dataloader
)
# number of batches
assert
mlc
<
nc
,
'Label class
%
g exceeds nc=
%
g in
%
s. Possible class labels are 0-
%
g'
%
(
mlc
,
nc
,
opt
.
data
,
nc
-
1
)
assert
mlc
<
nc
,
'Label class
%
g exceeds nc=
%
g in
%
s. Possible class labels are 0-
%
g'
%
(
mlc
,
nc
,
opt
.
data
,
nc
-
1
)
...
@@ -198,8 +192,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
...
@@ -198,8 +192,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
if
rank
in
[
-
1
,
0
]:
if
rank
in
[
-
1
,
0
]:
ema
.
updates
=
start_epoch
*
nb
//
accumulate
# set EMA updates
ema
.
updates
=
start_epoch
*
nb
//
accumulate
# set EMA updates
testloader
=
create_dataloader
(
test_path
,
imgsz_test
,
total_batch_size
,
gs
,
opt
,
# testloader
testloader
=
create_dataloader
(
test_path
,
imgsz_test
,
total_batch_size
,
gs
,
opt
,
# testloader
hyp
=
hyp
,
cache
=
opt
.
cache_images
and
not
opt
.
notest
,
rect
=
True
,
hyp
=
hyp
,
cache
=
opt
.
cache_images
and
not
opt
.
notest
,
rect
=
True
,
rank
=-
1
,
rank
=-
1
,
world_size
=
opt
.
world_size
,
workers
=
opt
.
workers
,
pad
=
0.5
)[
0
]
world_size
=
opt
.
world_size
,
workers
=
opt
.
workers
,
pad
=
0.5
,
prefix
=
colorstr
(
'val: '
))[
0
]
if
not
opt
.
resume
:
if
not
opt
.
resume
:
labels
=
np
.
concatenate
(
dataset
.
labels
,
0
)
labels
=
np
.
concatenate
(
dataset
.
labels
,
0
)
...
@@ -514,6 +509,12 @@ if __name__ == '__main__':
...
@@ -514,6 +509,12 @@ if __name__ == '__main__':
# Train
# Train
logger
.
info
(
opt
)
logger
.
info
(
opt
)
try
:
import
wandb
except
ImportError
:
wandb
=
None
prefix
=
colorstr
(
'wandb: '
)
logger
.
info
(
f
"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)"
)
if
not
opt
.
evolve
:
if
not
opt
.
evolve
:
tb_writer
=
None
# init loggers
tb_writer
=
None
# init loggers
if
opt
.
global_rank
in
[
-
1
,
0
]:
if
opt
.
global_rank
in
[
-
1
,
0
]:
...
...
utils/autoanchor.py
浏览文件 @
dd03b20b
...
@@ -22,7 +22,7 @@ def check_anchor_order(m):
...
@@ -22,7 +22,7 @@ def check_anchor_order(m):
def
check_anchors
(
dataset
,
model
,
thr
=
4.0
,
imgsz
=
640
):
def
check_anchors
(
dataset
,
model
,
thr
=
4.0
,
imgsz
=
640
):
# Check anchor fit to data, recompute if necessary
# Check anchor fit to data, recompute if necessary
prefix
=
colorstr
(
'
blue'
,
'bold'
,
'autoanchor'
)
+
': '
prefix
=
colorstr
(
'
autoanchor: '
)
print
(
f
'
\n
{prefix}Analyzing anchors... '
,
end
=
''
)
print
(
f
'
\n
{prefix}Analyzing anchors... '
,
end
=
''
)
m
=
model
.
module
.
model
[
-
1
]
if
hasattr
(
model
,
'module'
)
else
model
.
model
[
-
1
]
# Detect()
m
=
model
.
module
.
model
[
-
1
]
if
hasattr
(
model
,
'module'
)
else
model
.
model
[
-
1
]
# Detect()
shapes
=
imgsz
*
dataset
.
shapes
/
dataset
.
shapes
.
max
(
1
,
keepdims
=
True
)
shapes
=
imgsz
*
dataset
.
shapes
/
dataset
.
shapes
.
max
(
1
,
keepdims
=
True
)
...
@@ -73,7 +73,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
...
@@ -73,7 +73,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
from utils.autoanchor import *; _ = kmean_anchors()
from utils.autoanchor import *; _ = kmean_anchors()
"""
"""
thr
=
1.
/
thr
thr
=
1.
/
thr
prefix
=
colorstr
(
'
blue'
,
'bold'
,
'autoanchor'
)
+
': '
prefix
=
colorstr
(
'
autoanchor: '
)
def
metric
(
k
,
wh
):
# compute metrics
def
metric
(
k
,
wh
):
# compute metrics
r
=
wh
[:,
None
]
/
k
[
None
]
r
=
wh
[:,
None
]
/
k
[
None
]
...
...
utils/datasets.py
浏览文件 @
dd03b20b
...
@@ -56,7 +56,7 @@ def exif_size(img):
...
@@ -56,7 +56,7 @@ def exif_size(img):
def
create_dataloader
(
path
,
imgsz
,
batch_size
,
stride
,
opt
,
hyp
=
None
,
augment
=
False
,
cache
=
False
,
pad
=
0.0
,
rect
=
False
,
def
create_dataloader
(
path
,
imgsz
,
batch_size
,
stride
,
opt
,
hyp
=
None
,
augment
=
False
,
cache
=
False
,
pad
=
0.0
,
rect
=
False
,
rank
=-
1
,
world_size
=
1
,
workers
=
8
,
image_weights
=
False
,
quad
=
False
):
rank
=-
1
,
world_size
=
1
,
workers
=
8
,
image_weights
=
False
,
quad
=
False
,
prefix
=
''
):
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
with
torch_distributed_zero_first
(
rank
):
with
torch_distributed_zero_first
(
rank
):
dataset
=
LoadImagesAndLabels
(
path
,
imgsz
,
batch_size
,
dataset
=
LoadImagesAndLabels
(
path
,
imgsz
,
batch_size
,
...
@@ -67,8 +67,8 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
...
@@ -67,8 +67,8 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
single_cls
=
opt
.
single_cls
,
single_cls
=
opt
.
single_cls
,
stride
=
int
(
stride
),
stride
=
int
(
stride
),
pad
=
pad
,
pad
=
pad
,
rank
=
rank
,
image_weights
=
image_weights
,
image_weights
=
image_weights
)
prefix
=
prefix
)
batch_size
=
min
(
batch_size
,
len
(
dataset
))
batch_size
=
min
(
batch_size
,
len
(
dataset
))
nw
=
min
([
os
.
cpu_count
()
//
world_size
,
batch_size
if
batch_size
>
1
else
0
,
workers
])
# number of workers
nw
=
min
([
os
.
cpu_count
()
//
world_size
,
batch_size
if
batch_size
>
1
else
0
,
workers
])
# number of workers
...
@@ -129,7 +129,7 @@ class LoadImages: # for inference
...
@@ -129,7 +129,7 @@ class LoadImages: # for inference
elif
os
.
path
.
isfile
(
p
):
elif
os
.
path
.
isfile
(
p
):
files
=
[
p
]
# files
files
=
[
p
]
# files
else
:
else
:
raise
Exception
(
'ERROR:
%
s does not exist'
%
p
)
raise
Exception
(
f
'ERROR: {p} does not exist'
)
images
=
[
x
for
x
in
files
if
x
.
split
(
'.'
)[
-
1
]
.
lower
()
in
img_formats
]
images
=
[
x
for
x
in
files
if
x
.
split
(
'.'
)[
-
1
]
.
lower
()
in
img_formats
]
videos
=
[
x
for
x
in
files
if
x
.
split
(
'.'
)[
-
1
]
.
lower
()
in
vid_formats
]
videos
=
[
x
for
x
in
files
if
x
.
split
(
'.'
)[
-
1
]
.
lower
()
in
vid_formats
]
...
@@ -144,8 +144,8 @@ class LoadImages: # for inference
...
@@ -144,8 +144,8 @@ class LoadImages: # for inference
self
.
new_video
(
videos
[
0
])
# new video
self
.
new_video
(
videos
[
0
])
# new video
else
:
else
:
self
.
cap
=
None
self
.
cap
=
None
assert
self
.
nf
>
0
,
'No images or videos found in
%
s. Supported formats are:
\n
images:
%
s
\n
videos:
%
s'
%
\
assert
self
.
nf
>
0
,
f
'No images or videos found in {p}. '
\
(
p
,
img_formats
,
vid_formats
)
f
'Supported formats are:
\n
images: {img_formats}
\n
videos: {vid_formats}'
def
__iter__
(
self
):
def
__iter__
(
self
):
self
.
count
=
0
self
.
count
=
0
...
@@ -171,14 +171,14 @@ class LoadImages: # for inference
...
@@ -171,14 +171,14 @@ class LoadImages: # for inference
ret_val
,
img0
=
self
.
cap
.
read
()
ret_val
,
img0
=
self
.
cap
.
read
()
self
.
frame
+=
1
self
.
frame
+=
1
print
(
'video
%
g/
%
g (
%
g/
%
g)
%
s: '
%
(
self
.
count
+
1
,
self
.
nf
,
self
.
frame
,
self
.
nframes
,
path
)
,
end
=
''
)
print
(
f
'video {self.count + 1}/{self.nf} ({self.frame}/{self.nframes}) {path}: '
,
end
=
''
)
else
:
else
:
# Read image
# Read image
self
.
count
+=
1
self
.
count
+=
1
img0
=
cv2
.
imread
(
path
)
# BGR
img0
=
cv2
.
imread
(
path
)
# BGR
assert
img0
is
not
None
,
'Image Not Found '
+
path
assert
img0
is
not
None
,
'Image Not Found '
+
path
print
(
'image
%
g/
%
g
%
s: '
%
(
self
.
count
,
self
.
nf
,
path
)
,
end
=
''
)
print
(
f
'image {self.count}/{self.nf} {path}: '
,
end
=
''
)
# Padded resize
# Padded resize
img
=
letterbox
(
img0
,
new_shape
=
self
.
img_size
)[
0
]
img
=
letterbox
(
img0
,
new_shape
=
self
.
img_size
)[
0
]
...
@@ -238,9 +238,9 @@ class LoadWebcam: # for inference
...
@@ -238,9 +238,9 @@ class LoadWebcam: # for inference
break
break
# Print
# Print
assert
ret_val
,
'Camera Error
%
s'
%
self
.
pipe
assert
ret_val
,
f
'Camera Error {self.pipe}'
img_path
=
'webcam.jpg'
img_path
=
'webcam.jpg'
print
(
'webcam
%
g: '
%
self
.
count
,
end
=
''
)
print
(
f
'webcam {self.count}: '
,
end
=
''
)
# Padded resize
# Padded resize
img
=
letterbox
(
img0
,
new_shape
=
self
.
img_size
)[
0
]
img
=
letterbox
(
img0
,
new_shape
=
self
.
img_size
)[
0
]
...
@@ -271,15 +271,15 @@ class LoadStreams: # multiple IP or RTSP cameras
...
@@ -271,15 +271,15 @@ class LoadStreams: # multiple IP or RTSP cameras
self
.
sources
=
[
clean_str
(
x
)
for
x
in
sources
]
# clean source names for later
self
.
sources
=
[
clean_str
(
x
)
for
x
in
sources
]
# clean source names for later
for
i
,
s
in
enumerate
(
sources
):
for
i
,
s
in
enumerate
(
sources
):
# Start the thread to read frames from the video stream
# Start the thread to read frames from the video stream
print
(
'
%
g/
%
g:
%
s... '
%
(
i
+
1
,
n
,
s
)
,
end
=
''
)
print
(
f
'{i + 1}/{n}: {s}... '
,
end
=
''
)
cap
=
cv2
.
VideoCapture
(
eval
(
s
)
if
s
.
isnumeric
()
else
s
)
cap
=
cv2
.
VideoCapture
(
eval
(
s
)
if
s
.
isnumeric
()
else
s
)
assert
cap
.
isOpened
(),
'Failed to open
%
s'
%
s
assert
cap
.
isOpened
(),
f
'Failed to open {s}'
w
=
int
(
cap
.
get
(
cv2
.
CAP_PROP_FRAME_WIDTH
))
w
=
int
(
cap
.
get
(
cv2
.
CAP_PROP_FRAME_WIDTH
))
h
=
int
(
cap
.
get
(
cv2
.
CAP_PROP_FRAME_HEIGHT
))
h
=
int
(
cap
.
get
(
cv2
.
CAP_PROP_FRAME_HEIGHT
))
fps
=
cap
.
get
(
cv2
.
CAP_PROP_FPS
)
%
100
fps
=
cap
.
get
(
cv2
.
CAP_PROP_FPS
)
%
100
_
,
self
.
imgs
[
i
]
=
cap
.
read
()
# guarantee first frame
_
,
self
.
imgs
[
i
]
=
cap
.
read
()
# guarantee first frame
thread
=
Thread
(
target
=
self
.
update
,
args
=
([
i
,
cap
]),
daemon
=
True
)
thread
=
Thread
(
target
=
self
.
update
,
args
=
([
i
,
cap
]),
daemon
=
True
)
print
(
' success (
%
gx
%
g at
%.2
f FPS).'
%
(
w
,
h
,
fps
)
)
print
(
f
' success ({w}x{h} at {fps:.2f} FPS).'
)
thread
.
start
()
thread
.
start
()
print
(
''
)
# newline
print
(
''
)
# newline
...
@@ -336,7 +336,7 @@ def img2label_paths(img_paths):
...
@@ -336,7 +336,7 @@ def img2label_paths(img_paths):
class
LoadImagesAndLabels
(
Dataset
):
# for training/testing
class
LoadImagesAndLabels
(
Dataset
):
# for training/testing
def
__init__
(
self
,
path
,
img_size
=
640
,
batch_size
=
16
,
augment
=
False
,
hyp
=
None
,
rect
=
False
,
image_weights
=
False
,
def
__init__
(
self
,
path
,
img_size
=
640
,
batch_size
=
16
,
augment
=
False
,
hyp
=
None
,
rect
=
False
,
image_weights
=
False
,
cache_images
=
False
,
single_cls
=
False
,
stride
=
32
,
pad
=
0.0
,
rank
=-
1
):
cache_images
=
False
,
single_cls
=
False
,
stride
=
32
,
pad
=
0.0
,
prefix
=
''
):
self
.
img_size
=
img_size
self
.
img_size
=
img_size
self
.
augment
=
augment
self
.
augment
=
augment
self
.
hyp
=
hyp
self
.
hyp
=
hyp
...
@@ -358,11 +358,11 @@ class LoadImagesAndLabels(Dataset): # for training/testing
...
@@ -358,11 +358,11 @@ class LoadImagesAndLabels(Dataset): # for training/testing
parent
=
str
(
p
.
parent
)
+
os
.
sep
parent
=
str
(
p
.
parent
)
+
os
.
sep
f
+=
[
x
.
replace
(
'./'
,
parent
)
if
x
.
startswith
(
'./'
)
else
x
for
x
in
t
]
# local to global path
f
+=
[
x
.
replace
(
'./'
,
parent
)
if
x
.
startswith
(
'./'
)
else
x
for
x
in
t
]
# local to global path
else
:
else
:
raise
Exception
(
'
%
s does not exist'
%
p
)
raise
Exception
(
f
'{prefix}{p} does not exist'
)
self
.
img_files
=
sorted
([
x
.
replace
(
'/'
,
os
.
sep
)
for
x
in
f
if
x
.
split
(
'.'
)[
-
1
]
.
lower
()
in
img_formats
])
self
.
img_files
=
sorted
([
x
.
replace
(
'/'
,
os
.
sep
)
for
x
in
f
if
x
.
split
(
'.'
)[
-
1
]
.
lower
()
in
img_formats
])
assert
self
.
img_files
,
'
No images found'
assert
self
.
img_files
,
f
'{prefix}
No images found'
except
Exception
as
e
:
except
Exception
as
e
:
raise
Exception
(
'Error loading data from
%
s:
%
s
\n
See
%
s'
%
(
path
,
e
,
help_url
)
)
raise
Exception
(
f
'{prefix}Error loading data from {path}: {e}
\n
See {help_url}'
)
# Check cache
# Check cache
self
.
label_files
=
img2label_paths
(
self
.
img_files
)
# labels
self
.
label_files
=
img2label_paths
(
self
.
img_files
)
# labels
...
@@ -370,15 +370,15 @@ class LoadImagesAndLabels(Dataset): # for training/testing
...
@@ -370,15 +370,15 @@ class LoadImagesAndLabels(Dataset): # for training/testing
if
cache_path
.
is_file
():
if
cache_path
.
is_file
():
cache
=
torch
.
load
(
cache_path
)
# load
cache
=
torch
.
load
(
cache_path
)
# load
if
cache
[
'hash'
]
!=
get_hash
(
self
.
label_files
+
self
.
img_files
)
or
'results'
not
in
cache
:
# changed
if
cache
[
'hash'
]
!=
get_hash
(
self
.
label_files
+
self
.
img_files
)
or
'results'
not
in
cache
:
# changed
cache
=
self
.
cache_labels
(
cache_path
)
# re-cache
cache
=
self
.
cache_labels
(
cache_path
,
prefix
)
# re-cache
else
:
else
:
cache
=
self
.
cache_labels
(
cache_path
)
# cache
cache
=
self
.
cache_labels
(
cache_path
,
prefix
)
# cache
# Display cache
# Display cache
[
nf
,
nm
,
ne
,
nc
,
n
]
=
cache
.
pop
(
'results'
)
# found, missing, empty, corrupted, total
[
nf
,
nm
,
ne
,
nc
,
n
]
=
cache
.
pop
(
'results'
)
# found, missing, empty, corrupted, total
desc
=
f
"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
desc
=
f
"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
tqdm
(
None
,
desc
=
desc
,
total
=
n
,
initial
=
n
)
tqdm
(
None
,
desc
=
prefix
+
desc
,
total
=
n
,
initial
=
n
)
assert
nf
>
0
or
not
augment
,
f
'
No labels found
in {cache_path}. Can not train without labels. See {help_url}'
assert
nf
>
0
or
not
augment
,
f
'
{prefix}No labels
in {cache_path}. Can not train without labels. See {help_url}'
# Read cache
# Read cache
cache
.
pop
(
'hash'
)
# remove hash
cache
.
pop
(
'hash'
)
# remove hash
...
@@ -432,9 +432,9 @@ class LoadImagesAndLabels(Dataset): # for training/testing
...
@@ -432,9 +432,9 @@ class LoadImagesAndLabels(Dataset): # for training/testing
for
i
,
x
in
pbar
:
for
i
,
x
in
pbar
:
self
.
imgs
[
i
],
self
.
img_hw0
[
i
],
self
.
img_hw
[
i
]
=
x
# img, hw_original, hw_resized = load_image(self, i)
self
.
imgs
[
i
],
self
.
img_hw0
[
i
],
self
.
img_hw
[
i
]
=
x
# img, hw_original, hw_resized = load_image(self, i)
gb
+=
self
.
imgs
[
i
]
.
nbytes
gb
+=
self
.
imgs
[
i
]
.
nbytes
pbar
.
desc
=
'Caching images (
%.1
fGB)'
%
(
gb
/
1E9
)
pbar
.
desc
=
f
'{prefix}Caching images ({gb / 1E9:.1f}GB)'
def
cache_labels
(
self
,
path
=
Path
(
'./labels.cache'
)):
def
cache_labels
(
self
,
path
=
Path
(
'./labels.cache'
)
,
prefix
=
''
):
# Cache dataset labels, check images and read shapes
# Cache dataset labels, check images and read shapes
x
=
{}
# dict
x
=
{}
# dict
nm
,
nf
,
ne
,
nc
=
0
,
0
,
0
,
0
# number missing, found, empty, duplicate
nm
,
nf
,
ne
,
nc
=
0
,
0
,
0
,
0
# number missing, found, empty, duplicate
...
@@ -466,18 +466,18 @@ class LoadImagesAndLabels(Dataset): # for training/testing
...
@@ -466,18 +466,18 @@ class LoadImagesAndLabels(Dataset): # for training/testing
x
[
im_file
]
=
[
l
,
shape
]
x
[
im_file
]
=
[
l
,
shape
]
except
Exception
as
e
:
except
Exception
as
e
:
nc
+=
1
nc
+=
1
print
(
'WARNING: Ignoring corrupted image and/or label
%
s:
%
s'
%
(
im_file
,
e
)
)
print
(
f
'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}'
)
pbar
.
desc
=
f
"Scanning '{path.parent / path.stem}' for images and labels... "
\
pbar
.
desc
=
f
"
{prefix}
Scanning '{path.parent / path.stem}' for images and labels... "
\
f
"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
f
"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
if
nf
==
0
:
if
nf
==
0
:
print
(
f
'WARNING: No labels found in {path}. See {help_url}'
)
print
(
f
'
{prefix}
WARNING: No labels found in {path}. See {help_url}'
)
x
[
'hash'
]
=
get_hash
(
self
.
label_files
+
self
.
img_files
)
x
[
'hash'
]
=
get_hash
(
self
.
label_files
+
self
.
img_files
)
x
[
'results'
]
=
[
nf
,
nm
,
ne
,
nc
,
i
+
1
]
x
[
'results'
]
=
[
nf
,
nm
,
ne
,
nc
,
i
+
1
]
torch
.
save
(
x
,
path
)
# save for next time
torch
.
save
(
x
,
path
)
# save for next time
logging
.
info
(
f
"New cache created: {path}"
)
logging
.
info
(
f
'{prefix}New cache created: {path}'
)
return
x
return
x
def
__len__
(
self
):
def
__len__
(
self
):
...
...
utils/general.py
浏览文件 @
dd03b20b
...
@@ -118,7 +118,7 @@ def one_cycle(y1=0.0, y2=1.0, steps=100):
...
@@ -118,7 +118,7 @@ def one_cycle(y1=0.0, y2=1.0, steps=100):
def
colorstr
(
*
input
):
def
colorstr
(
*
input
):
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
*
prefix
,
string
=
input
# color arguments, string
*
args
,
string
=
input
if
len
(
input
)
>
1
else
(
'blue'
,
'bold'
,
input
[
0
])
# color arguments, string
colors
=
{
'black'
:
'
\033
[30m'
,
# basic colors
colors
=
{
'black'
:
'
\033
[30m'
,
# basic colors
'red'
:
'
\033
[31m'
,
'red'
:
'
\033
[31m'
,
'green'
:
'
\033
[32m'
,
'green'
:
'
\033
[32m'
,
...
@@ -138,8 +138,7 @@ def colorstr(*input):
...
@@ -138,8 +138,7 @@ def colorstr(*input):
'end'
:
'
\033
[0m'
,
# misc
'end'
:
'
\033
[0m'
,
# misc
'bold'
:
'
\033
[1m'
,
'bold'
:
'
\033
[1m'
,
'underline'
:
'
\033
[4m'
}
'underline'
:
'
\033
[4m'
}
return
''
.
join
(
colors
[
x
]
for
x
in
args
)
+
f
'{string}'
+
colors
[
'end'
]
return
''
.
join
(
colors
[
x
]
for
x
in
prefix
)
+
f
'{string}'
+
colors
[
'end'
]
def
labels_to_class_weights
(
labels
,
nc
=
80
):
def
labels_to_class_weights
(
labels
,
nc
=
80
):
...
...
utils/plots.py
浏览文件 @
dd03b20b
...
@@ -245,9 +245,9 @@ def plot_study_txt(path='study/', x=None): # from utils.plots import *; plot_st
...
@@ -245,9 +245,9 @@ def plot_study_txt(path='study/', x=None): # from utils.plots import *; plot_st
'k.-'
,
linewidth
=
2
,
markersize
=
8
,
alpha
=.
25
,
label
=
'EfficientDet'
)
'k.-'
,
linewidth
=
2
,
markersize
=
8
,
alpha
=.
25
,
label
=
'EfficientDet'
)
ax2
.
grid
()
ax2
.
grid
()
ax2
.
set_yticks
(
np
.
arange
(
30
,
60
,
5
))
ax2
.
set_xlim
(
0
,
30
)
ax2
.
set_xlim
(
0
,
30
)
ax2
.
set_ylim
(
29
,
51
)
ax2
.
set_ylim
(
29
,
51
)
ax2
.
set_yticks
(
np
.
arange
(
30
,
55
,
5
))
ax2
.
set_xlabel
(
'GPU Speed (ms/img)'
)
ax2
.
set_xlabel
(
'GPU Speed (ms/img)'
)
ax2
.
set_ylabel
(
'COCO AP val'
)
ax2
.
set_ylabel
(
'COCO AP val'
)
ax2
.
legend
(
loc
=
'lower right'
)
ax2
.
legend
(
loc
=
'lower right'
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论