Unverified 提交 f62609eb authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Update check_requirements() with `cmds=()` argument (#7543)

上级 4b284a12
...@@ -218,13 +218,7 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F ...@@ -218,13 +218,7 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
# YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
try: try:
assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`' assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
try: check_requirements(('nvidia-tensorrt',), cmds=('-U --index-url https://pypi.ngc.nvidia.com',))
import tensorrt as trt
except Exception:
s = f"\n{prefix} tensorrt not found and is required by YOLOv5"
LOGGER.info(f"{s}, attempting auto-update...")
r = '-U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com'
LOGGER.info(subprocess.check_output(f"pip install {r}", shell=True).decode())
import tensorrt as trt import tensorrt as trt
if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012 if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
......
...@@ -321,7 +321,7 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals ...@@ -321,7 +321,7 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals
@try_except @try_except
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True): def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
# Check installed dependencies meet requirements (pass *.txt file or list of packages) # Check installed dependencies meet requirements (pass *.txt file or list of packages)
prefix = colorstr('red', 'bold', 'requirements:') prefix = colorstr('red', 'bold', 'requirements:')
check_python() # check python version check_python() # check python version
...@@ -334,7 +334,7 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta ...@@ -334,7 +334,7 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta
requirements = [x for x in requirements if x not in exclude] requirements = [x for x in requirements if x not in exclude]
n = 0 # number of packages updates n = 0 # number of packages updates
for r in requirements: for i, r in enumerate(requirements):
try: try:
pkg.require(r) pkg.require(r)
except Exception: # DistributionNotFound or VersionConflict if requirements not met except Exception: # DistributionNotFound or VersionConflict if requirements not met
...@@ -343,7 +343,7 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta ...@@ -343,7 +343,7 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta
LOGGER.info(f"{s}, attempting auto-update...") LOGGER.info(f"{s}, attempting auto-update...")
try: try:
assert check_online(), f"'pip install {r}' skipped (offline)" assert check_online(), f"'pip install {r}' skipped (offline)"
LOGGER.info(check_output(f"pip install '{r}'", shell=True).decode()) LOGGER.info(check_output(f"pip install '{r}' {cmds[i] if cmds else ''}", shell=True).decode())
n += 1 n += 1
except Exception as e: except Exception as e:
LOGGER.warning(f'{prefix} {e}') LOGGER.warning(f'{prefix} {e}')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论