Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
Y
yolov5
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
Administrator
yolov5
Commits
5bd6a97b
Unverified
提交
5bd6a97b
authored
1月 04, 2022
作者:
Glenn Jocher
提交者:
GitHub
1月 04, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Global export format sort (#6182)
* Global export sort * Cleanup
上级
7cad6597
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
46 行增加
和
46 行删除
+46
-46
detect.py
detect.py
+3
-3
export.py
export.py
+0
-0
common.py
models/common.py
+40
-40
val.py
val.py
+3
-3
没有找到文件。
detect.py
浏览文件 @
5bd6a97b
...
...
@@ -15,13 +15,13 @@ Usage - formats:
$ python path/to/detect.py --weights yolov5s.pt # PyTorch
yolov5s.torchscript # TorchScript
yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn
yolov5s.mlmodel # CoreML (under development)
yolov5s.xml # OpenVINO
yolov5s.engine # TensorRT
yolov5s.mlmodel # CoreML (under development)
yolov5s_saved_model # TensorFlow SavedModel
yolov5s.pb # TensorFlow
protobu
f
yolov5s.pb # TensorFlow
GraphDe
f
yolov5s.tflite # TensorFlow Lite
yolov5s_edgetpu.tflite # TensorFlow Edge TPU
yolov5s.engine # TensorRT
"""
import
argparse
...
...
export.py
浏览文件 @
5bd6a97b
差异被折叠。
点击展开。
models/common.py
浏览文件 @
5bd6a97b
...
...
@@ -316,17 +316,6 @@ class DetectMultiBackend(nn.Module):
if
extra_files
[
'config.txt'
]:
d
=
json
.
loads
(
extra_files
[
'config.txt'
])
# extra_files dict
stride
,
names
=
int
(
d
[
'stride'
]),
d
[
'names'
]
elif
coreml
:
# CoreML
LOGGER
.
info
(
f
'Loading {w} for CoreML inference...'
)
import
coremltools
as
ct
model
=
ct
.
models
.
MLModel
(
w
)
elif
xml
:
# OpenVINO
LOGGER
.
info
(
f
'Loading {w} for OpenVINO inference...'
)
check_requirements
((
'openvino-dev'
,))
# requires openvino-dev: https://pypi.org/project/openvino-dev/
import
openvino.inference_engine
as
ie
core
=
ie
.
IECore
()
network
=
core
.
read_network
(
model
=
w
,
weights
=
Path
(
w
)
.
with_suffix
(
'.bin'
))
# *.xml, *.bin paths
executable_network
=
core
.
load_network
(
network
,
device_name
=
'CPU'
,
num_requests
=
1
)
elif
dnn
:
# ONNX OpenCV DNN
LOGGER
.
info
(
f
'Loading {w} for ONNX OpenCV DNN inference...'
)
check_requirements
((
'opencv-python>=4.5.4'
,))
...
...
@@ -338,6 +327,13 @@ class DetectMultiBackend(nn.Module):
import
onnxruntime
providers
=
[
'CUDAExecutionProvider'
,
'CPUExecutionProvider'
]
if
cuda
else
[
'CPUExecutionProvider'
]
session
=
onnxruntime
.
InferenceSession
(
w
,
providers
=
providers
)
elif
xml
:
# OpenVINO
LOGGER
.
info
(
f
'Loading {w} for OpenVINO inference...'
)
check_requirements
((
'openvino-dev'
,))
# requires openvino-dev: https://pypi.org/project/openvino-dev/
import
openvino.inference_engine
as
ie
core
=
ie
.
IECore
()
network
=
core
.
read_network
(
model
=
w
,
weights
=
Path
(
w
)
.
with_suffix
(
'.bin'
))
# *.xml, *.bin paths
executable_network
=
core
.
load_network
(
network
,
device_name
=
'CPU'
,
num_requests
=
1
)
elif
engine
:
# TensorRT
LOGGER
.
info
(
f
'Loading {w} for TensorRT inference...'
)
import
tensorrt
as
trt
# https://developer.nvidia.com/nvidia-tensorrt-download
...
...
@@ -356,9 +352,17 @@ class DetectMultiBackend(nn.Module):
binding_addrs
=
OrderedDict
((
n
,
d
.
ptr
)
for
n
,
d
in
bindings
.
items
())
context
=
model
.
create_execution_context
()
batch_size
=
bindings
[
'images'
]
.
shape
[
0
]
else
:
# TensorFlow (TFLite, pb, saved_model)
if
pb
:
# https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
LOGGER
.
info
(
f
'Loading {w} for TensorFlow *.pb inference...'
)
elif
coreml
:
# CoreML
LOGGER
.
info
(
f
'Loading {w} for CoreML inference...'
)
import
coremltools
as
ct
model
=
ct
.
models
.
MLModel
(
w
)
else
:
# TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
if
saved_model
:
# SavedModel
LOGGER
.
info
(
f
'Loading {w} for TensorFlow SavedModel inference...'
)
import
tensorflow
as
tf
model
=
tf
.
keras
.
models
.
load_model
(
w
)
elif
pb
:
# GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
LOGGER
.
info
(
f
'Loading {w} for TensorFlow GraphDef inference...'
)
import
tensorflow
as
tf
def
wrap_frozen_graph
(
gd
,
inputs
,
outputs
):
...
...
@@ -369,19 +373,15 @@ class DetectMultiBackend(nn.Module):
graph_def
=
tf
.
Graph
()
.
as_graph_def
()
graph_def
.
ParseFromString
(
open
(
w
,
'rb'
)
.
read
())
frozen_func
=
wrap_frozen_graph
(
gd
=
graph_def
,
inputs
=
"x:0"
,
outputs
=
"Identity:0"
)
elif
saved_model
:
LOGGER
.
info
(
f
'Loading {w} for TensorFlow saved_model inference...'
)
import
tensorflow
as
tf
model
=
tf
.
keras
.
models
.
load_model
(
w
)
elif
tflite
:
# https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
if
'edgetpu'
in
w
.
lower
():
if
'edgetpu'
in
w
.
lower
():
# Edge TPU
LOGGER
.
info
(
f
'Loading {w} for TensorFlow Lite Edge TPU inference...'
)
import
tflite_runtime.interpreter
as
tfli
delegate
=
{
'Linux'
:
'libedgetpu.so.1'
,
# install https://coral.ai/software/#edgetpu-runtime
'Darwin'
:
'libedgetpu.1.dylib'
,
'Windows'
:
'edgetpu.dll'
}[
platform
.
system
()]
interpreter
=
tfli
.
Interpreter
(
model_path
=
w
,
experimental_delegates
=
[
tfli
.
load_delegate
(
delegate
)])
else
:
else
:
# Lite
LOGGER
.
info
(
f
'Loading {w} for TensorFlow Lite inference...'
)
import
tensorflow
as
tf
interpreter
=
tf
.
lite
.
Interpreter
(
model_path
=
w
)
# load TFLite model
...
...
@@ -396,21 +396,13 @@ class DetectMultiBackend(nn.Module):
if
self
.
pt
or
self
.
jit
:
# PyTorch
y
=
self
.
model
(
im
)
if
self
.
jit
else
self
.
model
(
im
,
augment
=
augment
,
visualize
=
visualize
)
return
y
if
val
else
y
[
0
]
elif
self
.
coreml
:
# CoreML
im
=
im
.
permute
(
0
,
2
,
3
,
1
)
.
cpu
()
.
numpy
()
# torch BCHW to numpy BHWC shape(1,320,192,3)
im
=
Image
.
fromarray
((
im
[
0
]
*
255
)
.
astype
(
'uint8'
))
# im = im.resize((192, 320), Image.ANTIALIAS)
y
=
self
.
model
.
predict
({
'image'
:
im
})
# coordinates are xywh normalized
box
=
xywh2xyxy
(
y
[
'coordinates'
]
*
[[
w
,
h
,
w
,
h
]])
# xyxy pixels
conf
,
cls
=
y
[
'confidence'
]
.
max
(
1
),
y
[
'confidence'
]
.
argmax
(
1
)
.
astype
(
np
.
float
)
y
=
np
.
concatenate
((
box
,
conf
.
reshape
(
-
1
,
1
),
cls
.
reshape
(
-
1
,
1
)),
1
)
elif
self
.
onnx
:
# ONNX
elif
self
.
dnn
:
# ONNX OpenCV DNN
im
=
im
.
cpu
()
.
numpy
()
# torch to numpy
if
self
.
dnn
:
# ONNX OpenCV DNN
self
.
net
.
setInput
(
im
)
y
=
self
.
net
.
forward
()
else
:
# ONNX Runtime
y
=
self
.
session
.
run
([
self
.
session
.
get_outputs
()[
0
]
.
name
],
{
self
.
session
.
get_inputs
()[
0
]
.
name
:
im
})[
0
]
self
.
net
.
setInput
(
im
)
y
=
self
.
net
.
forward
(
)
elif
self
.
onnx
:
# ONNX Runtime
im
=
im
.
cpu
()
.
numpy
()
# torch to numpy
y
=
self
.
session
.
run
([
self
.
session
.
get_outputs
()[
0
]
.
name
],
{
self
.
session
.
get_inputs
()[
0
]
.
name
:
im
})[
0
]
elif
self
.
xml
:
# OpenVINO
im
=
im
.
cpu
()
.
numpy
()
# FP32
desc
=
self
.
ie
.
TensorDesc
(
precision
=
'FP32'
,
dims
=
im
.
shape
,
layout
=
'NCHW'
)
# Tensor Description
...
...
@@ -423,13 +415,21 @@ class DetectMultiBackend(nn.Module):
self
.
binding_addrs
[
'images'
]
=
int
(
im
.
data_ptr
())
self
.
context
.
execute_v2
(
list
(
self
.
binding_addrs
.
values
()))
y
=
self
.
bindings
[
'output'
]
.
data
el
se
:
# TensorFlow model (TFLite, pb, saved_model)
el
if
self
.
coreml
:
# CoreML
im
=
im
.
permute
(
0
,
2
,
3
,
1
)
.
cpu
()
.
numpy
()
# torch BCHW to numpy BHWC shape(1,320,192,3)
if
self
.
pb
:
y
=
self
.
frozen_func
(
x
=
self
.
tf
.
constant
(
im
))
.
numpy
()
elif
self
.
saved_model
:
im
=
Image
.
fromarray
((
im
[
0
]
*
255
)
.
astype
(
'uint8'
))
# im = im.resize((192, 320), Image.ANTIALIAS)
y
=
self
.
model
.
predict
({
'image'
:
im
})
# coordinates are xywh normalized
box
=
xywh2xyxy
(
y
[
'coordinates'
]
*
[[
w
,
h
,
w
,
h
]])
# xyxy pixels
conf
,
cls
=
y
[
'confidence'
]
.
max
(
1
),
y
[
'confidence'
]
.
argmax
(
1
)
.
astype
(
np
.
float
)
y
=
np
.
concatenate
((
box
,
conf
.
reshape
(
-
1
,
1
),
cls
.
reshape
(
-
1
,
1
)),
1
)
else
:
# TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
im
=
im
.
permute
(
0
,
2
,
3
,
1
)
.
cpu
()
.
numpy
()
# torch BCHW to numpy BHWC shape(1,320,192,3)
if
self
.
saved_model
:
# SavedModel
y
=
self
.
model
(
im
,
training
=
False
)
.
numpy
()
elif
self
.
tflite
:
elif
self
.
pb
:
# GraphDef
y
=
self
.
frozen_func
(
x
=
self
.
tf
.
constant
(
im
))
.
numpy
()
elif
self
.
tflite
:
# Lite
input
,
output
=
self
.
input_details
[
0
],
self
.
output_details
[
0
]
int8
=
input
[
'dtype'
]
==
np
.
uint8
# is TFLite quantized uint8 model
if
int8
:
...
...
@@ -451,7 +451,7 @@ class DetectMultiBackend(nn.Module):
def
warmup
(
self
,
imgsz
=
(
1
,
3
,
640
,
640
),
half
=
False
):
# Warmup model by running inference once
if
self
.
pt
or
self
.
engine
or
self
.
onnx
:
# warmup types
if
self
.
pt
or
self
.
jit
or
self
.
onnx
or
self
.
engine
:
# warmup types
if
isinstance
(
self
.
device
,
torch
.
device
)
and
self
.
device
.
type
!=
'cpu'
:
# only warmup GPU models
im
=
torch
.
zeros
(
*
imgsz
)
.
to
(
self
.
device
)
.
type
(
torch
.
half
if
half
else
torch
.
float
)
# input image
self
.
forward
(
im
)
# warmup
...
...
val.py
浏览文件 @
5bd6a97b
...
...
@@ -9,13 +9,13 @@ Usage - formats:
$ python path/to/val.py --weights yolov5s.pt # PyTorch
yolov5s.torchscript # TorchScript
yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn
yolov5s.mlmodel # CoreML (under development)
yolov5s.xml # OpenVINO
yolov5s.engine # TensorRT
yolov5s.mlmodel # CoreML (under development)
yolov5s_saved_model # TensorFlow SavedModel
yolov5s.pb # TensorFlow
protobu
f
yolov5s.pb # TensorFlow
GraphDe
f
yolov5s.tflite # TensorFlow Lite
yolov5s_edgetpu.tflite # TensorFlow Edge TPU
yolov5s.engine # TensorRT
"""
import
argparse
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论