Unverified 提交 a4207a20 authored 作者: imyhxy's avatar imyhxy 提交者: GitHub

Fix TensorRT potential unordered binding addresses (#5826)

* feat: change file suffix in pythonic way * fix: enforce binding addresses order * fix: enforce binding addresses order
上级 5ca5dd4c
...@@ -276,7 +276,7 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F ...@@ -276,7 +276,7 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
assert onnx.exists(), f'failed to export ONNX file: {onnx}' assert onnx.exists(), f'failed to export ONNX file: {onnx}'
LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...') LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
f = str(file).replace('.pt', '.engine') # TensorRT engine file f = file.with_suffix('.engine') # TensorRT engine file
logger = trt.Logger(trt.Logger.INFO) logger = trt.Logger(trt.Logger.INFO)
if verbose: if verbose:
logger.min_severity = trt.Logger.Severity.VERBOSE logger.min_severity = trt.Logger.Severity.VERBOSE
...@@ -310,6 +310,7 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F ...@@ -310,6 +310,7 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
except Exception as e: except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}') LOGGER.info(f'\n{prefix} export failure: {e}')
@torch.no_grad() @torch.no_grad()
def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
weights=ROOT / 'yolov5s.pt', # weights path weights=ROOT / 'yolov5s.pt', # weights path
......
...@@ -7,7 +7,7 @@ import json ...@@ -7,7 +7,7 @@ import json
import math import math
import platform import platform
import warnings import warnings
from collections import namedtuple from collections import OrderedDict, namedtuple
from copy import copy from copy import copy
from pathlib import Path from pathlib import Path
...@@ -326,14 +326,14 @@ class DetectMultiBackend(nn.Module): ...@@ -326,14 +326,14 @@ class DetectMultiBackend(nn.Module):
logger = trt.Logger(trt.Logger.INFO) logger = trt.Logger(trt.Logger.INFO)
with open(w, 'rb') as f, trt.Runtime(logger) as runtime: with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
model = runtime.deserialize_cuda_engine(f.read()) model = runtime.deserialize_cuda_engine(f.read())
bindings = dict() bindings = OrderedDict()
for index in range(model.num_bindings): for index in range(model.num_bindings):
name = model.get_binding_name(index) name = model.get_binding_name(index)
dtype = trt.nptype(model.get_binding_dtype(index)) dtype = trt.nptype(model.get_binding_dtype(index))
shape = tuple(model.get_binding_shape(index)) shape = tuple(model.get_binding_shape(index))
data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device) data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr())) bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
binding_addrs = {n: d.ptr for n, d in bindings.items()} binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
context = model.create_execution_context() context = model.create_execution_context()
batch_size = bindings['images'].shape[0] batch_size = bindings['images'].shape[0]
else: # TensorFlow model (TFLite, pb, saved_model) else: # TensorFlow model (TFLite, pb, saved_model)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论