Unverified 提交 7b31a531 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Add `tensorrt>=7.0.0` checks (#6193)

* Add `tensorrt>=7.0.0` checks * Update export.py * Update common.py * Update export.py
上级 a2f4a179
...@@ -61,8 +61,8 @@ from models.experimental import attempt_load ...@@ -61,8 +61,8 @@ from models.experimental import attempt_load
from models.yolo import Detect from models.yolo import Detect
from utils.activations import SiLU from utils.activations import SiLU
from utils.datasets import LoadImages from utils.datasets import LoadImages
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, colorstr, file_size, print_args, from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, colorstr,
url2file) file_size, print_args, url2file)
from utils.torch_utils import select_device from utils.torch_utils import select_device
...@@ -174,14 +174,14 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F ...@@ -174,14 +174,14 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
check_requirements(('tensorrt',)) check_requirements(('tensorrt',))
import tensorrt as trt import tensorrt as trt
opset = (12, 13)[trt.__version__[0] == '8'] # test on TensorRT 7.x and 8.x if trt.__version__[0] == 7: # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
if opset == 12: # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
grid = model.model[-1].anchor_grid grid = model.model[-1].anchor_grid
model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid] model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
export_onnx(model, im, file, opset, train, False, simplify) export_onnx(model, im, file, 12, train, False, simplify) # opset 12
model.model[-1].anchor_grid = grid model.model[-1].anchor_grid = grid
else: # TensorRT >= 8 else: # TensorRT >= 8
export_onnx(model, im, file, opset, train, False, simplify) check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=8.0.0
export_onnx(model, im, file, 13, train, False, simplify) # opset 13
onnx = file.with_suffix('.onnx') onnx = file.with_suffix('.onnx')
assert onnx.exists(), f'failed to export ONNX file: {onnx}' assert onnx.exists(), f'failed to export ONNX file: {onnx}'
......
...@@ -337,7 +337,7 @@ class DetectMultiBackend(nn.Module): ...@@ -337,7 +337,7 @@ class DetectMultiBackend(nn.Module):
elif engine: # TensorRT elif engine: # TensorRT
LOGGER.info(f'Loading {w} for TensorRT inference...') LOGGER.info(f'Loading {w} for TensorRT inference...')
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
check_version(trt.__version__, '8.0.0', verbose=True) # version requirement check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr')) Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
logger = trt.Logger(trt.Logger.INFO) logger = trt.Logger(trt.Logger.INFO)
with open(w, 'rb') as f, trt.Runtime(logger) as runtime: with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论