Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
Y
yolov5
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
Administrator
yolov5
Commits
b57abb17
Unverified
提交
b57abb17
authored
2月 07, 2022
作者:
Glenn Jocher
提交者:
GitHub
2月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move trainloader functions to class methods (#6559)
* Move trainloader functions to class methods * results = ThreadPool(NUM_THREADS).imap(self.load_image, range(n)) * Cleanup
上级
dc7e0930
显示空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
53 行增加
和
55 行删除
+53
-55
datasets.py
utils/datasets.py
+53
-55
没有找到文件。
utils/datasets.py
浏览文件 @
b57abb17
...
...
@@ -484,7 +484,7 @@ class LoadImagesAndLabels(Dataset):
self
.
batch_shapes
=
np
.
ceil
(
np
.
array
(
shapes
)
*
img_size
/
stride
+
pad
)
.
astype
(
np
.
int
)
*
stride
# Cache images into
memory for faster training (WARNING: large datasets may exceed system RAM
)
# Cache images into
RAM/disk for faster training (WARNING: large datasets may exceed system resources
)
self
.
imgs
,
self
.
img_npy
=
[
None
]
*
n
,
[
None
]
*
n
if
cache_images
:
if
cache_images
==
'disk'
:
...
...
@@ -493,14 +493,14 @@ class LoadImagesAndLabels(Dataset):
self
.
im_cache_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
gb
=
0
# Gigabytes of cached images
self
.
img_hw0
,
self
.
img_hw
=
[
None
]
*
n
,
[
None
]
*
n
results
=
ThreadPool
(
NUM_THREADS
)
.
imap
(
lambda
x
:
load_image
(
*
x
),
zip
(
repeat
(
self
),
range
(
n
)
))
results
=
ThreadPool
(
NUM_THREADS
)
.
imap
(
self
.
load_image
,
range
(
n
))
pbar
=
tqdm
(
enumerate
(
results
),
total
=
n
)
for
i
,
x
in
pbar
:
if
cache_images
==
'disk'
:
if
not
self
.
img_npy
[
i
]
.
exists
():
np
.
save
(
self
.
img_npy
[
i
]
.
as_posix
(),
x
[
0
])
gb
+=
self
.
img_npy
[
i
]
.
stat
()
.
st_size
else
:
else
:
# 'ram'
self
.
imgs
[
i
],
self
.
img_hw0
[
i
],
self
.
img_hw
[
i
]
=
x
# im, hw_orig, hw_resized = load_image(self, i)
gb
+=
self
.
imgs
[
i
]
.
nbytes
pbar
.
desc
=
f
'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})'
...
...
@@ -558,16 +558,16 @@ class LoadImagesAndLabels(Dataset):
mosaic
=
self
.
mosaic
and
random
.
random
()
<
hyp
[
'mosaic'
]
if
mosaic
:
# Load mosaic
img
,
labels
=
load_mosaic
(
self
,
index
)
img
,
labels
=
self
.
load_mosaic
(
index
)
shapes
=
None
# MixUp augmentation
if
random
.
random
()
<
hyp
[
'mixup'
]:
img
,
labels
=
mixup
(
img
,
labels
,
*
load_mosaic
(
self
,
random
.
randint
(
0
,
self
.
n
-
1
)))
img
,
labels
=
mixup
(
img
,
labels
,
*
self
.
load_mosaic
(
random
.
randint
(
0
,
self
.
n
-
1
)))
else
:
# Load image
img
,
(
h0
,
w0
),
(
h
,
w
)
=
load_image
(
self
,
index
)
img
,
(
h0
,
w0
),
(
h
,
w
)
=
self
.
load_image
(
index
)
# Letterbox
shape
=
self
.
batch_shapes
[
self
.
batch
[
index
]]
if
self
.
rect
else
self
.
img_size
# final letterboxed shape
...
...
@@ -624,63 +624,28 @@ class LoadImagesAndLabels(Dataset):
return
torch
.
from_numpy
(
img
),
labels_out
,
self
.
img_files
[
index
],
shapes
@staticmethod
def
collate_fn
(
batch
):
img
,
label
,
path
,
shapes
=
zip
(
*
batch
)
# transposed
for
i
,
lb
in
enumerate
(
label
):
lb
[:,
0
]
=
i
# add target image index for build_targets()
return
torch
.
stack
(
img
,
0
),
torch
.
cat
(
label
,
0
),
path
,
shapes
@staticmethod
def
collate_fn4
(
batch
):
img
,
label
,
path
,
shapes
=
zip
(
*
batch
)
# transposed
n
=
len
(
shapes
)
//
4
img4
,
label4
,
path4
,
shapes4
=
[],
[],
path
[:
n
],
shapes
[:
n
]
ho
=
torch
.
tensor
([[
0.0
,
0
,
0
,
1
,
0
,
0
]])
wo
=
torch
.
tensor
([[
0.0
,
0
,
1
,
0
,
0
,
0
]])
s
=
torch
.
tensor
([[
1
,
1
,
0.5
,
0.5
,
0.5
,
0.5
]])
# scale
for
i
in
range
(
n
):
# zidane torch.zeros(16,3,720,1280) # BCHW
i
*=
4
if
random
.
random
()
<
0.5
:
im
=
F
.
interpolate
(
img
[
i
]
.
unsqueeze
(
0
)
.
float
(),
scale_factor
=
2.0
,
mode
=
'bilinear'
,
align_corners
=
False
)[
0
]
.
type
(
img
[
i
]
.
type
())
lb
=
label
[
i
]
else
:
im
=
torch
.
cat
((
torch
.
cat
((
img
[
i
],
img
[
i
+
1
]),
1
),
torch
.
cat
((
img
[
i
+
2
],
img
[
i
+
3
]),
1
)),
2
)
lb
=
torch
.
cat
((
label
[
i
],
label
[
i
+
1
]
+
ho
,
label
[
i
+
2
]
+
wo
,
label
[
i
+
3
]
+
ho
+
wo
),
0
)
*
s
img4
.
append
(
im
)
label4
.
append
(
lb
)
for
i
,
lb
in
enumerate
(
label4
):
lb
[:,
0
]
=
i
# add target image index for build_targets()
return
torch
.
stack
(
img4
,
0
),
torch
.
cat
(
label4
,
0
),
path4
,
shapes4
# Ancillary functions --------------------------------------------------------------------------------------------------
def
load_image
(
self
,
i
):
# loads 1 image from dataset index 'i', returns im, original hw, resized hw
def
load_image
(
self
,
i
):
# loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
im
=
self
.
imgs
[
i
]
if
im
is
None
:
# not cached in ram
if
im
is
None
:
# not cached in RAM
npy
=
self
.
img_npy
[
i
]
if
npy
and
npy
.
exists
():
# load npy
im
=
np
.
load
(
npy
)
else
:
# read image
path
=
self
.
img_files
[
i
]
im
=
cv2
.
imread
(
path
)
# BGR
assert
im
is
not
None
,
f
'Image Not Found {path
}'
f
=
self
.
img_files
[
i
]
im
=
cv2
.
imread
(
f
)
# BGR
assert
im
is
not
None
,
f
'Image Not Found {f
}'
h0
,
w0
=
im
.
shape
[:
2
]
# orig hw
r
=
self
.
img_size
/
max
(
h0
,
w0
)
# ratio
if
r
!=
1
:
# if sizes are not equal
im
=
cv2
.
resize
(
im
,
(
int
(
w0
*
r
),
int
(
h0
*
r
)),
interpolation
=
cv2
.
INTER_AREA
if
r
<
1
and
not
self
.
augment
else
cv2
.
INTER_LINEAR
)
im
=
cv2
.
resize
(
im
,
(
int
(
w0
*
r
),
int
(
h0
*
r
)),
interpolation
=
cv2
.
INTER_LINEAR
if
(
self
.
augment
or
r
>
1
)
else
cv2
.
INTER_AREA
)
return
im
,
(
h0
,
w0
),
im
.
shape
[:
2
]
# im, hw_original, hw_resized
else
:
return
self
.
imgs
[
i
],
self
.
img_hw0
[
i
],
self
.
img_hw
[
i
]
# im, hw_original, hw_resized
def
load_mosaic
(
self
,
index
):
def
load_mosaic
(
self
,
index
):
# YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
labels4
,
segments4
=
[],
[]
s
=
self
.
img_size
...
...
@@ -689,7 +654,7 @@ def load_mosaic(self, index):
random
.
shuffle
(
indices
)
for
i
,
index
in
enumerate
(
indices
):
# Load image
img
,
_
,
(
h
,
w
)
=
load_image
(
self
,
index
)
img
,
_
,
(
h
,
w
)
=
self
.
load_image
(
index
)
# place img in img4
if
i
==
0
:
# top left
...
...
@@ -736,8 +701,7 @@ def load_mosaic(self, index):
return
img4
,
labels4
def
load_mosaic9
(
self
,
index
):
def
load_mosaic9
(
self
,
index
):
# YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic
labels9
,
segments9
=
[],
[]
s
=
self
.
img_size
...
...
@@ -746,7 +710,7 @@ def load_mosaic9(self, index):
hp
,
wp
=
-
1
,
-
1
# height, width previous
for
i
,
index
in
enumerate
(
indices
):
# Load image
img
,
_
,
(
h
,
w
)
=
load_image
(
self
,
index
)
img
,
_
,
(
h
,
w
)
=
self
.
load_image
(
index
)
# place img in img9
if
i
==
0
:
# center
...
...
@@ -811,7 +775,41 @@ def load_mosaic9(self, index):
return
img9
,
labels9
@staticmethod
def
collate_fn
(
batch
):
img
,
label
,
path
,
shapes
=
zip
(
*
batch
)
# transposed
for
i
,
lb
in
enumerate
(
label
):
lb
[:,
0
]
=
i
# add target image index for build_targets()
return
torch
.
stack
(
img
,
0
),
torch
.
cat
(
label
,
0
),
path
,
shapes
@staticmethod
def
collate_fn4
(
batch
):
img
,
label
,
path
,
shapes
=
zip
(
*
batch
)
# transposed
n
=
len
(
shapes
)
//
4
img4
,
label4
,
path4
,
shapes4
=
[],
[],
path
[:
n
],
shapes
[:
n
]
ho
=
torch
.
tensor
([[
0.0
,
0
,
0
,
1
,
0
,
0
]])
wo
=
torch
.
tensor
([[
0.0
,
0
,
1
,
0
,
0
,
0
]])
s
=
torch
.
tensor
([[
1
,
1
,
0.5
,
0.5
,
0.5
,
0.5
]])
# scale
for
i
in
range
(
n
):
# zidane torch.zeros(16,3,720,1280) # BCHW
i
*=
4
if
random
.
random
()
<
0.5
:
im
=
F
.
interpolate
(
img
[
i
]
.
unsqueeze
(
0
)
.
float
(),
scale_factor
=
2.0
,
mode
=
'bilinear'
,
align_corners
=
False
)[
0
]
.
type
(
img
[
i
]
.
type
())
lb
=
label
[
i
]
else
:
im
=
torch
.
cat
((
torch
.
cat
((
img
[
i
],
img
[
i
+
1
]),
1
),
torch
.
cat
((
img
[
i
+
2
],
img
[
i
+
3
]),
1
)),
2
)
lb
=
torch
.
cat
((
label
[
i
],
label
[
i
+
1
]
+
ho
,
label
[
i
+
2
]
+
wo
,
label
[
i
+
3
]
+
ho
+
wo
),
0
)
*
s
img4
.
append
(
im
)
label4
.
append
(
lb
)
for
i
,
lb
in
enumerate
(
label4
):
lb
[:,
0
]
=
i
# add target image index for build_targets()
return
torch
.
stack
(
img4
,
0
),
torch
.
cat
(
label4
,
0
),
path4
,
shapes4
# Ancillary functions --------------------------------------------------------------------------------------------------
def
create_folder
(
path
=
'./new'
):
# Create folder
if
os
.
path
.
exists
(
path
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论