Commit 01a67a97 authored Jul 11, 2020 by Glenn Jocher

Merge remote-tracking branch 'origin/master'

Parents: d3e786ed, 9006b85d

Showing 4 changed files with 64 additions and 69 deletions
models/export.py    +2   -2
train.py            +6   -7
utils/datasets.py   +52  -54
utils/utils.py      +4   -6
models/export.py

@@ -31,7 +31,7 @@ if __name__ == '__main__':

     # TorchScript export
     try:
         print('\nStarting TorchScript export with torch %s...' % torch.__version__)
-        f = opt.weights.replace('.pt', '.torchscript')  # filename
+        f = opt.weights.replace('.pt', '.torchscript.pt')  # filename
         ts = torch.jit.trace(model, img)
         ts.save(f)
         print('TorchScript export success, saved as %s' % f)
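For reference, the TorchScript path above traces the model against a fixed-shape dummy input and serializes the result next to the original weights. A minimal self-contained sketch of the same torch.jit.trace / save / load round trip, using a torchvision model as a stand-in for the YOLOv5 network:

    import torch
    import torchvision

    # Trace a model with a fixed-shape dummy input and save it as TorchScript.
    model = torchvision.models.resnet18().eval()
    img = torch.zeros(1, 3, 640, 640)  # dummy input, NCHW
    ts = torch.jit.trace(model, img)   # records the ops executed on this input
    ts.save('resnet18.torchscript.pt')

    # The saved module reloads without the original Python class definition:
    reloaded = torch.jit.load('resnet18.torchscript.pt')
    print(reloaded(img).shape)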
@@ -62,7 +62,7 @@ if __name__ == '__main__':

         print('\nStarting CoreML export with coremltools %s...' % ct.__version__)
         # convert model from torchscript and apply pixel scaling as per detect.py
         model = ct.convert(ts, inputs=[ct.ImageType(name='images', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
         f = opt.weights.replace('.pt', '.mlmodel')  # filename
         model.save(f)
         print('CoreML export success, saved as %s' % f)
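The CoreML hunk feeds the traced module straight to coremltools and bakes the 1/255 pixel scaling into the model input, so callers pass raw 0-255 images rather than pre-normalized tensors. A minimal sketch of that conversion, assuming coremltools 4+ (whose unified ct.convert accepts a traced TorchScript module) and again using a stand-in model:

    import coremltools as ct
    import torch
    import torchvision

    model = torchvision.models.resnet18().eval()  # stand-in for the detection model
    img = torch.zeros(1, 3, 640, 640)
    ts = torch.jit.trace(model, img)

    # ImageType bakes the 0-255 -> 0-1 pixel scaling into the CoreML model itself.
    mlmodel = ct.convert(ts, inputs=[ct.ImageType(name='images', shape=img.shape,
                                                  scale=1 / 255.0, bias=[0, 0, 0])])
    mlmodel.save('resnet18.mlmodel')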
train.py

@@ -44,7 +44,7 @@ hyp = {'optimizer': 'SGD',  # ['adam', 'SGD', None] if none, default is SGD

 def train(hyp):
     print(f'Hyperparameters {hyp}')
-    log_dir = tb_writer.log_dir  # run directory
+    log_dir = tb_writer.log_dir if tb_writer else 'runs/evolution'  # run directory
     wdir = str(Path(log_dir) / 'weights') + os.sep  # weights directory
     os.makedirs(wdir, exist_ok=True)
@@ -387,7 +387,10 @@ if __name__ == '__main__':

     opt.weights = last if opt.resume and not opt.weights else opt.weights
     opt.cfg = check_file(opt.cfg)  # check file
     opt.data = check_file(opt.data)  # check file
-    opt.hyp = check_file(opt.hyp) if opt.hyp else ''  # check file
+    if opt.hyp:  # update hyps
+        opt.hyp = check_file(opt.hyp)  # check file
+        with open(opt.hyp) as f:
+            hyp.update(yaml.load(f, Loader=yaml.FullLoader))  # update hyps
     print(opt)
     opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
     device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
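The new branch overlays values from a YAML file onto the built-in hyp dict, so a hyperparameter file only needs to list the keys it overrides. A small sketch of that pattern (the file name is a placeholder, and the defaults are abridged):

    import yaml

    hyp = {'optimizer': 'SGD', 'lr0': 0.01, 'momentum': 0.937}  # abridged defaults

    # The file only has to contain the keys it wants to change.
    with open('hyp.custom.yaml') as f:  # hypothetical file name
        hyp.update(yaml.load(f, Loader=yaml.FullLoader))

    print(hyp['lr0'])  # value from the file if present, else the default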
@@ -396,12 +399,8 @@ if __name__ == '__main__':

     # Train
     if not opt.evolve:
-        print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
         tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
-        if opt.hyp:  # update hyps
-            with open(opt.hyp) as f:
-                hyp.update(yaml.load(f, Loader=yaml.FullLoader))
+        print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
         train(hyp)

     # Evolve hyperparameters (optional)
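Note that SummaryWriter creates its log directory on construction, and during hyperparameter evolution no writer exists at all, which is what the tb_writer fallback in the first hunk guards against. A short sketch; 'runs/exp0' stands in for what the repo's increment_dir() helper would produce:

    from torch.utils.tensorboard import SummaryWriter

    # Constructing the writer creates the run directory on disk.
    tb_writer = SummaryWriter(log_dir='runs/exp0')
    tb_writer.add_scalar('train/loss', 0.5, global_step=0)

    # Hunk 1's fallback: during evolution there is no writer, so the run
    # directory falls back to a fixed path.
    log_dir = tb_writer.log_dir if tb_writer else 'runs/evolution'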
utils/datasets.py

@@ -26,6 +26,11 @@ for orientation in ExifTags.TAGS.keys():
         break

+def get_hash(files):
+    # Returns a single hash value of a list of files
+    return sum(os.path.getsize(f) for f in files if os.path.isfile(f))
+
+
 def exif_size(img):
     # Returns exif-corrected PIL size
     s = img.size  # (width, height)
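get_hash is a deliberately cheap change detector: the combined byte size of all label and image files. Adding, removing, or resizing any file changes the sum, though an edit that preserves file size would slip through. A sketch of how the value is meant to be used (paths are placeholders):

    import os

    def get_hash(files):
        # Returns a single hash value of a list of files (total size in bytes)
        return sum(os.path.getsize(f) for f in files if os.path.isfile(f))

    # Hypothetical usage mirroring the cache check below: store, recompute, compare.
    files = ['data/images/1.jpg', 'data/labels/1.txt']  # placeholder paths
    cached_hash = get_hash(files)
    # ... later, possibly in another run ...
    if get_hash(files) != cached_hash:
        print('dataset changed, rebuild the cache')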
@@ -280,7 +285,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing

     def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                  cache_images=False, single_cls=False, stride=32, pad=0.0):
         try:
-            f = []
+            f = []  # image files
             for p in path if isinstance(path, list) else [path]:
                 p = str(Path(p))  # os-agnostic
                 parent = str(Path(p).parent) + os.sep
@@ -292,7 +297,6 @@ class LoadImagesAndLabels(Dataset):  # for training/testing

                     f += glob.iglob(p + os.sep + '*.*')
                 else:
                     raise Exception('%s does not exist' % p)
-            path = p  # *.npy dir
             self.img_files = [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats]
         except Exception as e:
             raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))
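The loader accepts a directory (globbed with *.*) or a *.txt list of paths, then keeps only recognized image extensions. A condensed sketch of the directory case; the img_formats list here is an assumption mirroring the one defined elsewhere in datasets.py:

    import glob
    import os
    from pathlib import Path

    img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.dng']  # assumed list

    p = str(Path('data/images'))  # os-agnostic, placeholder directory
    f = list(glob.iglob(p + os.sep + '*.*'))  # every file directly under the directory
    img_files = [x.replace('/', os.sep) for x in f
                 if os.path.splitext(x)[-1].lower() in img_formats]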
@@ -314,20 +318,22 @@ class LoadImagesAndLabels(Dataset):  # for training/testing

         self.stride = stride

         # Define labels
         self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt')
                             for x in self.img_files]

-        # Read image shapes (wh)
-        sp = path.replace('.txt', '') + '.shapes'  # shapefile path
-        try:
-            with open(sp, 'r') as f:  # read existing shapefile
-                s = [x.split() for x in f.read().splitlines()]
-            assert len(s) == n, 'Shapefile out of sync'
-        except:
-            s = [exif_size(Image.open(f)) for f in tqdm(self.img_files, desc='Reading image shapes')]
-            np.savetxt(sp, s, fmt='%g')  # overwrites existing (if any)
-        self.shapes = np.array(s, dtype=np.float64)
+        # Check cache
+        cache_path = str(Path(self.label_files[0]).parent) + '.cache'  # cached labels
+        if os.path.isfile(cache_path):
+            cache = torch.load(cache_path)  # load
+            if cache['hash'] != get_hash(self.label_files + self.img_files):  # dataset changed
+                cache = self.cache_labels(cache_path)  # re-cache
+        else:
+            cache = self.cache_labels(cache_path)  # cache
+
+        # Get labels
+        labels, shapes = zip(*[cache[x] for x in self.img_files])
+        self.shapes = np.array(shapes, dtype=np.float64)
+        self.labels = list(labels)

         # Rectangular Training  https://github.com/ultralytics/yolov3/issues/232
         if self.rect:
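The new flow is load, validate, rebuild: torch.load the .cache file, compare its stored hash against the files currently on disk, and re-cache on any mismatch. The same pattern as a small hypothetical helper:

    import os
    import torch

    def load_or_build(cache_path, files, build_fn, get_hash):
        # Hypothetical helper mirroring the load -> validate -> rebuild pattern.
        if os.path.isfile(cache_path):
            cache = torch.load(cache_path)
            if cache['hash'] == get_hash(files):  # dataset unchanged
                return cache
        return build_fn(cache_path)  # (re-)build and save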
@@ -337,6 +343,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing

             irect = ar.argsort()
             self.img_files = [self.img_files[i] for i in irect]
             self.label_files = [self.label_files[i] for i in irect]
+            self.labels = [self.labels[i] for i in irect]
             self.shapes = s[irect]  # wh
             ar = ar[irect]
@@ -353,33 +360,11 @@ class LoadImagesAndLabels(Dataset):  # for training/testing

         self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride

         # Cache labels
-        self.imgs = [None] * n
-        self.labels = [np.zeros((0, 5), dtype=np.float32)] * n
         create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False
         nm, nf, ne, ns, nd = 0, 0, 0, 0, 0  # number missing, found, empty, datasubset, duplicate
-        np_labels_path = str(Path(self.label_files[0]).parent) + '.npy'  # saved labels in *.npy file
-        if os.path.isfile(np_labels_path):
-            s = np_labels_path  # print string
-            x = np.load(np_labels_path, allow_pickle=True)
-            if len(x) == n:
-                self.labels = x
-                labels_loaded = True
-        else:
-            s = path.replace('images', 'labels')
-
         pbar = tqdm(self.label_files)
         for i, file in enumerate(pbar):
-            if labels_loaded:
-                l = self.labels[i]  # label
-            else:
-                try:
-                    with open(file, 'r') as f:
-                        l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
-                except:
-                    nm += 1  # print('missing labels for image %s' % self.img_files[i])  # file missing
-                    continue
+            l = self.labels[i]  # np.savetxt(file, l, '%g')  # save *.txt from *.npy file
             if l.shape[0]:
                 assert l.shape[1] == 5, '> 5 label columns: %s' % file
                 assert (l >= 0).all(), 'negative labels: %s' % file
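The per-file sanity checks that survive the rewrite assume YOLO-format labels: one 'class x_center y_center width height' row per object, with coordinates normalized to [0, 1]. A runnable sketch of those checks on sample data standing in for a label file:

    import numpy as np

    # Two objects in YOLO label format.
    text = '0 0.5 0.5 0.2 0.3\n1 0.1 0.9 0.05 0.1'
    l = np.array([x.split() for x in text.splitlines()], dtype=np.float32)

    # The same sanity checks the loader applies per file:
    if l.shape[0]:
        assert l.shape[1] == 5, '> 5 label columns'
        assert (l >= 0).all(), 'negative labels'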
@@ -425,15 +410,13 @@ class LoadImagesAndLabels(Dataset):  # for training/testing

                 ne += 1  # print('empty labels for image %s' % self.img_files[i])  # file empty
                 # os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i]))  # remove

-            pbar.desc = 'Caching labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
-                s, nf, nm, ne, nd, n)
+            pbar.desc = 'Scanning labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
+                cache_path, nf, nm, ne, nd, n)
-        assert nf > 0 or n == 20288, 'No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
+        assert nf > 0, 'No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
-        if not labels_loaded and n > 1000:
-            print('Saving labels to %s for faster future loading' % np_labels_path)
-            np.save(np_labels_path, self.labels)  # save for next time

         # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
-        if cache_images:  # if training
+        self.imgs = [None] * n
+        if cache_images:
             gb = 0  # Gigabytes of cached images
             pbar = tqdm(range(len(self.img_files)), desc='Caching images')
             self.img_hw0, self.img_hw = [None] * n, [None] * n
@@ -442,15 +425,30 @@ class LoadImagesAndLabels(Dataset):  # for training/testing

                 gb += self.imgs[i].nbytes
                 pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)

-        # Detect corrupted images https://medium.com/joelthchao/programmatically-detect-corrupted-image-8c1b2006c3d3
-        detect_corrupted_images = False
-        if detect_corrupted_images:
-            from skimage import io  # conda install -c conda-forge scikit-image
-            for file in tqdm(self.img_files, desc='Detecting corrupted images'):
-                try:
-                    _ = io.imread(file)
-                except:
-                    print('Corrupted image detected: %s' % file)
+    def cache_labels(self, path='labels.cache'):
+        # Cache dataset labels, check images and read shapes
+        x = {}  # dict
+        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
+        for (img, label) in pbar:
+            try:
+                l = []
+                image = Image.open(img)
+                image.verify()  # PIL verify
+                # _ = io.imread(img)  # skimage verify (from skimage import io)
+                shape = exif_size(image)  # image size
+                if os.path.isfile(label):
+                    with open(label, 'r') as f:
+                        l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)  # labels
+                if len(l) == 0:
+                    l = np.zeros((0, 5), dtype=np.float32)
+                x[img] = [l, shape]
+            except Exception as e:
+                x[img] = None
+                print('WARNING: %s: %s' % (img, e))
+
+        x['hash'] = get_hash(self.label_files + self.img_files)
+        torch.save(x, path)  # save for next time
+        return x

     def __len__(self):
         return len(self.img_files)
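cache_labels validates each image with PIL: Image.open is lazy, verify() checks file integrity without a full decode, and exif_size corrects the width/height for EXIF rotation. A self-contained sketch of that verification path; the exif_size body is a reconstruction consistent with the repo's version, which sits outside this diff's context lines:

    from PIL import Image, ExifTags

    # Locate the EXIF Orientation tag id, as done at the top of datasets.py.
    for orientation in ExifTags.TAGS.keys():
        if ExifTags.TAGS[orientation] == 'Orientation':
            break

    def exif_size(img):
        # Returns exif-corrected PIL size (width, height)
        s = img.size
        try:
            rotation = dict(img._getexif().items())[orientation]
            if rotation in (6, 8):  # 270 or 90 degree rotation swaps w/h
                s = (s[1], s[0])
        except Exception:
            pass  # no EXIF data
        return s

    # Round-trip a small image to demonstrate the verify() pattern.
    Image.new('RGB', (640, 480)).save('tmp.jpg')
    image = Image.open('tmp.jpg')  # lazy: reads the header only
    image.verify()                 # integrity check without a full decode
    print(exif_size(Image.open('tmp.jpg')))  # verify() exhausts the handle, so reopen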
utils/utils.py

@@ -45,7 +45,7 @@ def get_latest_run(search_dir='./runs'):

 def check_git_status():
     # Suggest 'git pull' if repo is out of date
-    if platform in ['linux', 'darwin']:
+    if platform in ['linux', 'darwin'] and not os.path.isfile('/.dockerenv'):
         s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
         if 'Your branch is behind' in s:
             print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')
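Docker creates /.dockerenv at the container's filesystem root, so testing for that file is a common (though not guaranteed) way to detect a container; here it skips the git check inside Docker images, whose working copies are typically fixed snapshots. A sketch of the check in isolation:

    import os

    def in_docker():
        # /.dockerenv exists at the root of Docker containers; treating its
        # presence as "inside a container" is a heuristic, not a contract.
        return os.path.isfile('/.dockerenv')

    if not in_docker():
        print('running on the host; safe to suggest git pull')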
@@ -636,14 +636,12 @@ def strip_optimizer(f='weights/best.pt'):  # from utils.utils import *; strip_op

     x['optimizer'] = None
     x['model'].half()  # to FP16
     torch.save(x, f)
-    print('Optimizer stripped from %s' % f)
+    print('Optimizer stripped from %s, %.1fMB' % (f, os.path.getsize(f) / 1E6))


 def create_pretrained(f='weights/best.pt', s='weights/pretrained.pt'):  # from utils.utils import *; create_pretrained()
     # create pretrained checkpoint 's' from 'f' (create_pretrained(x, x) for x in glob.glob('./*.pt'))
-    device = torch.device('cpu')
-    x = torch.load(s, map_location=device)
+    x = torch.load(f, map_location=torch.device('cpu'))
     x['optimizer'] = None
     x['training_results'] = None
     x['epoch'] = -1

@@ -651,7 +649,7 @@ def create_pretrained(f='weights/best.pt', s='weights/pretrained.pt'):  # from u

     for p in x['model'].parameters():
         p.requires_grad = True
     torch.save(x, s)
-    print('%s saved as pretrained checkpoint %s' % (f, s))
+    print('%s saved as pretrained checkpoint %s, %.1fMB' % (f, s, os.path.getsize(s) / 1E6))


 def coco_class_count(path='../coco/labels/train2014/'):
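Both print changes report the checkpoint's on-disk size; os.path.getsize returns bytes, so dividing by 1E6 yields decimal megabytes. The size drop itself comes from discarding optimizer state and casting weights to FP16, which stores half the bytes of FP32. A small demonstration with placeholder file names:

    import os
    import torch

    # FP32 tensors store 4 bytes per element, FP16 stores 2, so halving the
    # precision roughly halves the checkpoint on disk.
    w32 = torch.zeros(1_000_000)           # FP32
    w16 = w32.half()                       # FP16
    torch.save({'model': w32}, 'w32.pt')   # placeholder file names
    torch.save({'model': w16}, 'w16.pt')
    for f in ('w32.pt', 'w16.pt'):
        print('%s: %.1fMB' % (f, os.path.getsize(f) / 1E6))  # ~4.0MB vs ~2.0MB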