yolov5 · Commits · 5bd6a97b
Commit 5bd6a97b (Unverified)
Authored Jan 04, 2022 by Glenn Jocher; committed by GitHub, Jan 04, 2022
Global export format sort (#6182)

* Global export sort
* Cleanup

Parent: 7cad6597
Showing 4 changed files with 117 additions and 117 deletions (+117 -117)

detect.py           +3   -3
export.py           +74  -74
models/common.py    +37  -37
val.py              +3   -3
detect.py @ 5bd6a97b

@@ -15,13 +15,13 @@ Usage - formats:
     $ python path/to/detect.py --weights yolov5s.pt                 # PyTorch
                                          yolov5s.torchscript        # TorchScript
                                          yolov5s.onnx               # ONNX Runtime or OpenCV DNN with --dnn
-                                         yolov5s.mlmodel            # CoreML (under development)
                                          yolov5s.xml                # OpenVINO
+                                         yolov5s.engine             # TensorRT
+                                         yolov5s.mlmodel            # CoreML (under development)
                                          yolov5s_saved_model        # TensorFlow SavedModel
-                                         yolov5s.pb                 # TensorFlow protobuf
+                                         yolov5s.pb                 # TensorFlow GraphDef
                                          yolov5s.tflite             # TensorFlow Lite
                                          yolov5s_edgetpu.tflite     # TensorFlow Edge TPU
-                                         yolov5s.engine             # TensorRT
 """

 import argparse
export.py @ 5bd6a97b

@@ -2,19 +2,19 @@
 """
 Export a YOLOv5 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit

-Format                      | Example                       | `--include ...` argument
+Format                      | `export.py --include`         | Model
 ---                         | ---                           | ---
-PyTorch                     | yolov5s.pt                    | -
-TorchScript                 | yolov5s.torchscript           | `torchscript`
-ONNX                        | yolov5s.onnx                  | `onnx`
-CoreML                      | yolov5s.mlmodel               | `coreml`
-OpenVINO                    | yolov5s_openvino_model/       | `openvino`
-TensorFlow SavedModel       | yolov5s_saved_model/          | `saved_model`
-TensorFlow GraphDef         | yolov5s.pb                    | `pb`
-TensorFlow Lite             | yolov5s.tflite                | `tflite`
-TensorFlow Edge TPU         | yolov5s_edgetpu.tflite        | `edgetpu`
-TensorFlow.js               | yolov5s_web_model/            | `tfjs`
-TensorRT                    | yolov5s.engine                | `engine`
+PyTorch                     | -                             | yolov5s.pt
+TorchScript                 | `torchscript`                 | yolov5s.torchscript
+ONNX                        | `onnx`                        | yolov5s.onnx
+OpenVINO                    | `openvino`                    | yolov5s_openvino_model/
+TensorRT                    | `engine`                      | yolov5s.engine
+CoreML                      | `coreml`                      | yolov5s.mlmodel
+TensorFlow SavedModel       | `saved_model`                 | yolov5s_saved_model/
+TensorFlow GraphDef         | `pb`                          | yolov5s.pb
+TensorFlow Lite             | `tflite`                      | yolov5s.tflite
+TensorFlow Edge TPU         | `edgetpu`                     | yolov5s_edgetpu.tflite
+TensorFlow.js               | `tfjs`                        | yolov5s_web_model/

 Usage:
     $ python path/to/export.py --weights yolov5s.pt --include torchscript onnx coreml openvino saved_model tflite tfjs
@@ -23,13 +23,13 @@ Inference:
     $ python path/to/detect.py --weights yolov5s.pt                 # PyTorch
                                          yolov5s.torchscript        # TorchScript
                                          yolov5s.onnx               # ONNX Runtime or OpenCV DNN with --dnn
-                                         yolov5s.mlmodel            # CoreML (under development)
                                          yolov5s.xml                # OpenVINO
+                                         yolov5s.engine             # TensorRT
+                                         yolov5s.mlmodel            # CoreML (under development)
                                          yolov5s_saved_model        # TensorFlow SavedModel
-                                         yolov5s.pb                 # TensorFlow protobuf
+                                         yolov5s.pb                 # TensorFlow GraphDef
                                          yolov5s.tflite             # TensorFlow Lite
                                          yolov5s_edgetpu.tflite     # TensorFlow Edge TPU
-                                         yolov5s.engine             # TensorRT

 TensorFlow.js:
     $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
@@ -126,6 +126,23 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
         LOGGER.info(f'{prefix} export failure: {e}')


+def export_openvino(model, im, file, prefix=colorstr('OpenVINO:')):
+    # YOLOv5 OpenVINO export
+    try:
+        check_requirements(('openvino-dev',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
+        import openvino.inference_engine as ie
+
+        LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
+        f = str(file).replace('.pt', '_openvino_model' + os.sep)
+
+        cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f}"
+        subprocess.check_output(cmd, shell=True)
+
+        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
+    except Exception as e:
+        LOGGER.info(f'\n{prefix} export failure: {e}')
+
+
 def export_coreml(model, im, file, prefix=colorstr('CoreML:')):
     # YOLOv5 CoreML export
     ct_model = None
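Note: export_openvino() converts the intermediate ONNX file with OpenVINO's Model Optimizer (`mo`), so run() only dispatches to it after export_onnx() (see the dispatch hunk further down). A minimal sketch of exercising the new entry point from Python, assuming the yolov5 repo root is the working directory and openvino-dev is installed; the keyword values here are illustrative, not from this commit:

    # hedged sketch: 'openvino' in include triggers an ONNX export first,
    # then writes yolov5s_openvino_model/ next to the weights
    from export import run

    run(weights='yolov5s.pt', include=('onnx', 'openvino'))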
@@ -148,27 +165,57 @@ def export_coreml(model, im, file, prefix=colorstr('CoreML:')):
     return ct_model


-def export_openvino(model, im, file, prefix=colorstr('OpenVINO:')):
-    # YOLOv5 OpenVINO export
+def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
+    # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
     try:
-        check_requirements(('openvino-dev',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
-        import openvino.inference_engine as ie
+        check_requirements(('tensorrt',))
+        import tensorrt as trt

-        LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
-        f = str(file).replace('.pt', '_openvino_model' + os.sep)
+        opset = (12, 13)[trt.__version__[0] == '8']  # test on TensorRT 7.x and 8.x
+        export_onnx(model, im, file, opset, train, False, simplify)
+        onnx = file.with_suffix('.onnx')
+        assert onnx.exists(), f'failed to export ONNX file: {onnx}'

-        cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f}"
-        subprocess.check_output(cmd, shell=True)
+        LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
+        f = file.with_suffix('.engine')  # TensorRT engine file
+        logger = trt.Logger(trt.Logger.INFO)
+        if verbose:
+            logger.min_severity = trt.Logger.Severity.VERBOSE
+
+        builder = trt.Builder(logger)
+        config = builder.create_builder_config()
+        config.max_workspace_size = workspace * 1 << 30
+
+        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
+        network = builder.create_network(flag)
+        parser = trt.OnnxParser(network, logger)
+        if not parser.parse_from_file(str(onnx)):
+            raise RuntimeError(f'failed to load ONNX file: {onnx}')
+
+        inputs = [network.get_input(i) for i in range(network.num_inputs)]
+        outputs = [network.get_output(i) for i in range(network.num_outputs)]
+        LOGGER.info(f'{prefix} Network Description:')
+        for inp in inputs:
+            LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
+        for out in outputs:
+            LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
+
+        half &= builder.platform_has_fast_fp16
+        LOGGER.info(f'{prefix} building FP{16 if half else 32} engine in {f}')
+        if half:
+            config.set_flag(trt.BuilderFlag.FP16)
+        with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
+            t.write(engine.serialize())

         LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
     except Exception as e:
         LOGGER.info(f'\n{prefix} export failure: {e}')


 def export_saved_model(model, im, file, dynamic,
                        tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
-                       conf_thres=0.25, prefix=colorstr('TensorFlow saved_model:')):
-    # YOLOv5 TensorFlow saved_model export
+                       conf_thres=0.25, prefix=colorstr('TensorFlow SavedModel:')):
+    # YOLOv5 TensorFlow SavedModel export
     keras_model = None
     try:
         import tensorflow as tf
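Note: export_engine() writes the serialized plan to a .engine file. A hedged sketch of loading it back with the TensorRT runtime, mirroring what DetectMultiBackend does at inference time; assumes TensorRT is installed and the engine was built on the same GPU and TensorRT version that will run it:

    # hedged sketch: deserialize the engine produced by export_engine()
    import tensorrt as trt

    logger = trt.Logger(trt.Logger.INFO)
    with open('yolov5s.engine', 'rb') as f, trt.Runtime(logger) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())  # ICudaEngine ready for execution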
@@ -304,53 +351,6 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
         LOGGER.info(f'\n{prefix} export failure: {e}')


-def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
-    # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
-    try:
-        check_requirements(('tensorrt',))
-        import tensorrt as trt
-
-        opset = (12, 13)[trt.__version__[0] == '8']  # test on TensorRT 7.x and 8.x
-        export_onnx(model, im, file, opset, train, False, simplify)
-        onnx = file.with_suffix('.onnx')
-        assert onnx.exists(), f'failed to export ONNX file: {onnx}'
-
-        LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
-        f = file.with_suffix('.engine')  # TensorRT engine file
-        logger = trt.Logger(trt.Logger.INFO)
-        if verbose:
-            logger.min_severity = trt.Logger.Severity.VERBOSE
-
-        builder = trt.Builder(logger)
-        config = builder.create_builder_config()
-        config.max_workspace_size = workspace * 1 << 30
-
-        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
-        network = builder.create_network(flag)
-        parser = trt.OnnxParser(network, logger)
-        if not parser.parse_from_file(str(onnx)):
-            raise RuntimeError(f'failed to load ONNX file: {onnx}')
-
-        inputs = [network.get_input(i) for i in range(network.num_inputs)]
-        outputs = [network.get_output(i) for i in range(network.num_outputs)]
-        LOGGER.info(f'{prefix} Network Description:')
-        for inp in inputs:
-            LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
-        for out in outputs:
-            LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
-
-        half &= builder.platform_has_fast_fp16
-        LOGGER.info(f'{prefix} building FP{16 if half else 32} engine in {f}')
-        if half:
-            config.set_flag(trt.BuilderFlag.FP16)
-        with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
-            t.write(engine.serialize())
-
-        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
-    except Exception as e:
-        LOGGER.info(f'\n{prefix} export failure: {e}')
-
-
 @torch.no_grad()
 def run(data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'
         weights=ROOT / 'yolov5s.pt',  # weights path
@@ -417,12 +417,12 @@ def run(data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'
         export_torchscript(model, im, file, optimize)
     if ('onnx' in include) or ('openvino' in include):  # OpenVINO requires ONNX
         export_onnx(model, im, file, opset, train, dynamic, simplify)
+    if 'openvino' in include:
+        export_openvino(model, im, file)
     if 'engine' in include:
         export_engine(model, im, file, train, half, simplify, workspace, verbose)
     if 'coreml' in include:
         export_coreml(model, im, file)
-    if 'openvino' in include:
-        export_openvino(model, im, file)

     # TensorFlow Exports
     if any(tf_exports):
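With OpenVINO now dispatched between ONNX and TensorRT, a single invocation can produce all three artifacts in sequence, in the same style as the Usage docstring above (flag combination illustrative, not from this commit):

    $ python path/to/export.py --weights yolov5s.pt --include onnx openvino engine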
models/common.py @ 5bd6a97b
@@ -316,17 +316,6 @@ class DetectMultiBackend(nn.Module):
             if extra_files['config.txt']:
                 d = json.loads(extra_files['config.txt'])  # extra_files dict
                 stride, names = int(d['stride']), d['names']
-        elif coreml:  # CoreML
-            LOGGER.info(f'Loading {w} for CoreML inference...')
-            import coremltools as ct
-            model = ct.models.MLModel(w)
-        elif xml:  # OpenVINO
-            LOGGER.info(f'Loading {w} for OpenVINO inference...')
-            check_requirements(('openvino-dev',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
-            import openvino.inference_engine as ie
-            core = ie.IECore()
-            network = core.read_network(model=w, weights=Path(w).with_suffix('.bin'))  # *.xml, *.bin paths
-            executable_network = core.load_network(network, device_name='CPU', num_requests=1)
         elif dnn:  # ONNX OpenCV DNN
             LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
             check_requirements(('opencv-python>=4.5.4',))
@@ -338,6 +327,13 @@ class DetectMultiBackend(nn.Module):
             import onnxruntime
             providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
             session = onnxruntime.InferenceSession(w, providers=providers)
+        elif xml:  # OpenVINO
+            LOGGER.info(f'Loading {w} for OpenVINO inference...')
+            check_requirements(('openvino-dev',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
+            import openvino.inference_engine as ie
+            core = ie.IECore()
+            network = core.read_network(model=w, weights=Path(w).with_suffix('.bin'))  # *.xml, *.bin paths
+            executable_network = core.load_network(network, device_name='CPU', num_requests=1)
 elif engine:  # TensorRT
             LOGGER.info(f'Loading {w} for TensorRT inference...')
             import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
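Note: the relocated OpenVINO branch leaves an ExecutableNetwork ready for synchronous inference. A hedged sketch of what a single request on it looks like, assuming the IR's input layer is named 'images' as in YOLOv5 ONNX exports (the dummy input is illustrative):

    # hedged sketch: one synchronous inference on the loaded ExecutableNetwork
    import numpy as np

    im = np.zeros((1, 3, 640, 640), dtype=np.float32)      # dummy NCHW input
    out = executable_network.infer(inputs={'images': im})  # dict of output blobs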
@@ -356,9 +352,17 @@ class DetectMultiBackend(nn.Module):
             binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
             context = model.create_execution_context()
             batch_size = bindings['images'].shape[0]
-        else:  # TensorFlow (TFLite, pb, saved_model)
-            if pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
-                LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...')
+        elif coreml:  # CoreML
+            LOGGER.info(f'Loading {w} for CoreML inference...')
+            import coremltools as ct
+            model = ct.models.MLModel(w)
+        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
+            if saved_model:  # SavedModel
+                LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
+                import tensorflow as tf
+                model = tf.keras.models.load_model(w)
+            elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
+                LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
                 import tensorflow as tf

                 def wrap_frozen_graph(gd, inputs, outputs):
@@ -369,19 +373,15 @@ class DetectMultiBackend(nn.Module):
                 graph_def = tf.Graph().as_graph_def()
                 graph_def.ParseFromString(open(w, 'rb').read())
                 frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
-            elif saved_model:
-                LOGGER.info(f'Loading {w} for TensorFlow saved_model inference...')
-                import tensorflow as tf
-                model = tf.keras.models.load_model(w)
             elif tflite:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
-                if 'edgetpu' in w.lower():
+                if 'edgetpu' in w.lower():  # Edge TPU
                     LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
                     import tflite_runtime.interpreter as tfli
                     delegate = {'Linux': 'libedgetpu.so.1',  # install https://coral.ai/software/#edgetpu-runtime
                                 'Darwin': 'libedgetpu.1.dylib',
                                 'Windows': 'edgetpu.dll'}[platform.system()]
                     interpreter = tfli.Interpreter(model_path=w, experimental_delegates=[tfli.load_delegate(delegate)])
-                else:
+                else:  # Lite
                     LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
                     import tensorflow as tf
                     interpreter = tf.lite.Interpreter(model_path=w)  # load TFLite model
@@ -396,20 +396,12 @@ class DetectMultiBackend(nn.Module):
         if self.pt or self.jit:  # PyTorch
             y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize)
             return y if val else y[0]
-        elif self.coreml:  # CoreML
-            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
-            im = Image.fromarray((im[0] * 255).astype('uint8'))
-            # im = im.resize((192, 320), Image.ANTIALIAS)
-            y = self.model.predict({'image': im})  # coordinates are xywh normalized
-            box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
-            conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
-            y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
-        elif self.onnx:  # ONNX
+        elif self.dnn:  # ONNX OpenCV DNN
             im = im.cpu().numpy()  # torch to numpy
-            if self.dnn:  # ONNX OpenCV DNN
-                self.net.setInput(im)
-                y = self.net.forward()
-            else:  # ONNX Runtime
-                y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
+            self.net.setInput(im)
+            y = self.net.forward()
+        elif self.onnx:  # ONNX Runtime
+            im = im.cpu().numpy()  # torch to numpy
+            y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
         elif self.xml:  # OpenVINO
             im = im.cpu().numpy()  # FP32
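Note: splitting the old nested if/else into flat self.dnn and self.onnx branches makes each backend self-contained. A hedged sketch of the OpenCV DNN path in isolation, mirroring the setInput/forward calls above; the file name and dummy input are illustrative:

    # hedged sketch: load an ONNX file with OpenCV DNN and run one forward pass
    import cv2
    import numpy as np

    net = cv2.dnn.readNetFromONNX('yolov5s.onnx')
    im = np.zeros((1, 3, 640, 640), dtype=np.float32)  # NCHW, matches export shape
    net.setInput(im)
    y = net.forward()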
@@ -423,13 +415,21 @@ class DetectMultiBackend(nn.Module):
             self.binding_addrs['images'] = int(im.data_ptr())
             self.context.execute_v2(list(self.binding_addrs.values()))
             y = self.bindings['output'].data
-        else:  # TensorFlow model (TFLite, pb, saved_model)
-            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
-            if self.pb:
-                y = self.frozen_func(x=self.tf.constant(im)).numpy()
-            elif self.saved_model:
-                y = self.model(im, training=False).numpy()
-            elif self.tflite:
+        elif self.coreml:  # CoreML
+            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
+            im = Image.fromarray((im[0] * 255).astype('uint8'))
+            # im = im.resize((192, 320), Image.ANTIALIAS)
+            y = self.model.predict({'image': im})  # coordinates are xywh normalized
+            box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
+            conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
+            y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
+        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
+            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
+            if self.saved_model:  # SavedModel
+                y = self.model(im, training=False).numpy()
+            elif self.pb:  # GraphDef
+                y = self.frozen_func(x=self.tf.constant(im)).numpy()
+            elif self.tflite:  # Lite
                 input, output = self.input_details[0], self.output_details[0]
                 int8 = input['dtype'] == np.uint8  # is TFLite quantized uint8 model
                 if int8:
@@ -451,7 +451,7 @@ class DetectMultiBackend(nn.Module):
     def warmup(self, imgsz=(1, 3, 640, 640), half=False):
         # Warmup model by running inference once
-        if self.pt or self.engine or self.onnx:  # warmup types
+        if self.pt or self.jit or self.onnx or self.engine:  # warmup types
             if isinstance(self.device, torch.device) and self.device.type != 'cpu':  # only warmup GPU models
                 im = torch.zeros(*imgsz).to(self.device).type(torch.half if half else torch.float)  # input image
                 self.forward(im)  # warmup
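Note: warmup now also covers TorchScript (jit) models, and only runs the dummy forward pass on non-CPU devices. A hedged usage sketch, assuming a CUDA device, the yolov5 repo root on sys.path, and an engine file built as above:

    # hedged sketch: warm up a TensorRT backend before timing inference
    import torch
    from models.common import DetectMultiBackend

    model = DetectMultiBackend('yolov5s.engine', device=torch.device('cuda:0'))
    model.warmup(imgsz=(1, 3, 640, 640), half=False)  # one dummy forward on GPU backends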
val.py @ 5bd6a97b
@@ -9,13 +9,13 @@ Usage - formats:
     $ python path/to/val.py --weights yolov5s.pt                 # PyTorch
                                       yolov5s.torchscript        # TorchScript
                                       yolov5s.onnx               # ONNX Runtime or OpenCV DNN with --dnn
-                                      yolov5s.mlmodel            # CoreML (under development)
                                       yolov5s.xml                # OpenVINO
+                                      yolov5s.engine             # TensorRT
+                                      yolov5s.mlmodel            # CoreML (under development)
                                       yolov5s_saved_model        # TensorFlow SavedModel
-                                      yolov5s.pb                 # TensorFlow protobuf
+                                      yolov5s.pb                 # TensorFlow GraphDef
                                       yolov5s.tflite             # TensorFlow Lite
                                       yolov5s_edgetpu.tflite     # TensorFlow Edge TPU
-                                      yolov5s.engine             # TensorRT
 """

 import argparse