Unverified commit d669a746, authored by Gaz Iqbal, committed by GitHub

Detect.py supports running against a Triton container (#9228)

* update coco128-seg comments
* Enable detect.py to use Triton for inference. Triton Inference Server is open-source inference serving software that streamlines AI inferencing (https://github.com/triton-inference-server/server). The user can now provide a "--triton-url" argument to detect.py to use a local or remote Triton server for inference. For example, http://localhost:8000 will use HTTP over port 8000 and grpc://localhost:8001 will use gRPC over port 8001. Note that it is not necessary to specify a weights file to use Triton. A Triton container can be created by first exporting the YOLOv5 model to a Triton-supported runtime; ONNX, TorchScript and TensorRT are supported by both Triton and the export.py script. The exported model can then be containerized via the OctoML CLI. See https://github.com/octoml/octo-cli#getting-started for a guide.
* added triton client to requirements
* fixed support for TF SavedModels in Triton
* reverted change
* Test CoreML update (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Update ci-testing.yml (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Use pathlib (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Refactor DetectMultiBackend to directly accept a Triton URL as --weights http://... (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Deploy category (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Update detect.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Update common.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Update common.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Update predict.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Update predict.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Update predict.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Update triton.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Update triton.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Add printout and requirements check
* Cleanup (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* triton fixes
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fixed triton model query over grpc
* Update check_requirements('tritonclient[all]')
* group imports
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Fix likely remote URL bug
* update comment
* Update is_url()
* Fix 2x download attempt on http://path/to/model.pt (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)

Co-authored-by: glennjocher <glenn.jocher@ultralytics.com>
Co-authored-by: Gaz Iqbal <giqbal@octoml.ai>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent 1320ce18
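For context, here is a minimal usage sketch of the new feature from the Python side. It assumes a Triton server is already running and serving an exported YOLOv5 model (the server-side setup via the OctoML CLI guide above is not shown, and the addresses are placeholders), and it is not part of the diff itself:

# Minimal usage sketch (assumes a reachable Triton endpoint)
from detect import run

# http://localhost:8000 talks HTTP, grpc://localhost:8001 talks gRPC
run(weights='grpc://localhost:8001', source='data/images')

The equivalent command line is python detect.py --weights grpc://localhost:8001 --source data/images, since the final design routes the server address through --weights rather than a separate --triton-url flag.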
@@ -104,7 +104,7 @@ def run(
     seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
     for path, im, im0s, vid_cap, s in dataset:
         with dt[0]:
-            im = torch.Tensor(im).to(device)
+            im = torch.Tensor(im).to(model.device)
             im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
             if len(im.shape) == 3:
                 im = im[None]  # expand for batch dim
...
@@ -49,7 +49,7 @@ from utils.torch_utils import select_device, smart_inference_mode

 @smart_inference_mode()
 def run(
-        weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
+        weights=ROOT / 'yolov5s.pt',  # model path or triton URL
         source=ROOT / 'data/images',  # file/dir/URL/glob/screen/0(webcam)
         data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
         imgsz=(640, 640),  # inference size (height, width)
@@ -108,11 +108,11 @@ def run(
     vid_path, vid_writer = [None] * bs, [None] * bs

     # Run inference
-    model.warmup(imgsz=(1 if pt else bs, 3, *imgsz))  # warmup
+    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
     seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
     for path, im, im0s, vid_cap, s in dataset:
         with dt[0]:
-            im = torch.from_numpy(im).to(device)
+            im = torch.from_numpy(im).to(model.device)
             im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
             im /= 255  # 0 - 255 to 0.0 - 1.0
             if len(im.shape) == 3:
@@ -214,7 +214,7 @@ def run(

 def parse_opt():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path(s)')
+    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path or triton URL')
     parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)')
     parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
     parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
...
@@ -10,6 +10,7 @@ import warnings
 from collections import OrderedDict, namedtuple
 from copy import copy
 from pathlib import Path
+from urllib.parse import urlparse

 import cv2
 import numpy as np
@@ -327,11 +328,13 @@ class DetectMultiBackend(nn.Module):
         super().__init__()
         w = str(weights[0] if isinstance(weights, list) else weights)
-        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = self._model_type(w)  # type
-        w = attempt_download(w)  # download if not local
+        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
         fp16 &= pt or jit or onnx or engine  # FP16
+        nhwc = coreml or saved_model or pb or tflite or edgetpu  # BHWC formats (vs torch BCWH)
         stride = 32  # default stride
         cuda = torch.cuda.is_available() and device.type != 'cpu'  # use CUDA
+        if not (pt or triton):
+            w = attempt_download(w)  # download if not local
         if pt:  # PyTorch
             model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
@@ -342,7 +345,7 @@ class DetectMultiBackend(nn.Module):
         elif jit:  # TorchScript
             LOGGER.info(f'Loading {w} for TorchScript inference...')
             extra_files = {'config.txt': ''}  # model metadata
-            model = torch.jit.load(w, _extra_files=extra_files)
+            model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
             model.half() if fp16 else model.float()
             if extra_files['config.txt']:  # load metadata dict
                 d = json.loads(extra_files['config.txt'],
@@ -472,6 +475,12 @@ class DetectMultiBackend(nn.Module):
             predictor = pdi.create_predictor(config)
             input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
             output_names = predictor.get_output_names()
+        elif triton:  # NVIDIA Triton Inference Server
+            LOGGER.info(f'Using {w} as Triton Inference Server...')
+            check_requirements('tritonclient[all]')
+            from utils.triton import TritonRemoteModel
+            model = TritonRemoteModel(url=w)
+            nhwc = model.runtime.startswith("tensorflow")
         else:
             raise NotImplementedError(f'ERROR: {w} is not a supported format')
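As a rough illustration of the new branch above (not part of the diff), DetectMultiBackend now treats a URL-shaped weights string as a Triton endpoint rather than a file to download. The sketch below assumes a Triton server is actually reachable at the placeholder address, otherwise construction fails:

from models.common import DetectMultiBackend
from utils.torch_utils import select_device

device = select_device('')  # '' picks CUDA if available, otherwise CPU
model = DetectMultiBackend('http://localhost:8000', device=device)  # placeholder address
print(model.triton)  # True: the weights string parsed as an http/grpc URL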
@@ -488,6 +497,8 @@ class DetectMultiBackend(nn.Module):
         b, ch, h, w = im.shape  # batch, channel, height, width
         if self.fp16 and im.dtype != torch.float16:
             im = im.half()  # to FP16
+        if self.nhwc:
+            im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)

         if self.pt:  # PyTorch
             y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
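The new self.nhwc flag centralizes the channels-last conversion that the CoreML and TensorFlow backends expect. A quick shape check illustrates the permute (illustrative only, not from the diff):

import torch

x = torch.zeros(1, 3, 320, 192)     # BCHW, the layout YOLOv5 uses internally
print(x.permute(0, 2, 3, 1).shape)  # torch.Size([1, 320, 192, 3]) -> BHWC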
@@ -517,7 +528,7 @@ class DetectMultiBackend(nn.Module):
             self.context.execute_v2(list(self.binding_addrs.values()))
             y = [self.bindings[x].data for x in sorted(self.output_names)]
         elif self.coreml:  # CoreML
-            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
+            im = im.cpu().numpy()
             im = Image.fromarray((im[0] * 255).astype('uint8'))
             # im = im.resize((192, 320), Image.ANTIALIAS)
             y = self.model.predict({'image': im})  # coordinates are xywh normalized
@@ -532,8 +543,10 @@ class DetectMultiBackend(nn.Module):
             self.input_handle.copy_from_cpu(im)
             self.predictor.run()
             y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
+        elif self.triton:  # NVIDIA Triton Inference Server
+            y = self.model(im)
         else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
-            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
+            im = im.cpu().numpy()
             if self.saved_model:  # SavedModel
                 y = self.model(im, training=False) if self.keras else self.model(im)
             elif self.pb:  # GraphDef
@@ -566,8 +579,8 @@ class DetectMultiBackend(nn.Module):

     def warmup(self, imgsz=(1, 3, 640, 640)):
         # Warmup model by running inference once
-        warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb
-        if any(warmup_types) and self.device.type != 'cpu':
+        warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton
+        if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
             im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device)  # input
             for _ in range(2 if self.jit else 1):  #
                 self.forward(im)  # warmup
@@ -575,14 +588,17 @@ class DetectMultiBackend(nn.Module):
     @staticmethod
     def _model_type(p='path/to/model.pt'):
         # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
+        # types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
         from export import export_formats
-        sf = list(export_formats().Suffix) + ['.xml']  # export suffixes
-        check_suffix(p, sf)  # checks
-        p = Path(p).name  # eliminate trailing separators
-        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, xml2 = (s in p for s in sf)
-        xml |= xml2  # *_openvino_model or *.xml
-        tflite &= not edgetpu  # *.tflite
-        return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle
+        from utils.downloads import is_url
+        sf = list(export_formats().Suffix)  # export suffixes
+        if not is_url(p, check=False):
+            check_suffix(p, sf)  # checks
+        url = urlparse(p)  # if url may be Triton inference server
+        types = [s in Path(p).name for s in sf]
+        types[8] &= not types[9]  # tflite &= not edgetpu
+        triton = not any(types) and all([any(s in url.scheme for s in ["http", "grpc"]), url.netloc])
+        return types + [triton]

     @staticmethod
     def _load_metadata(f=Path('path/to/meta.yaml')):
...
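The Triton check in _model_type boils down to "no known export suffix, but the string parses as an http/grpc URL with a host". A standalone sketch of that test (simplified, not the actual method):

from urllib.parse import urlparse

for p in ('yolov5s.onnx', 'http://localhost:8000', 'grpc://localhost:8001'):
    url = urlparse(p)
    looks_like_triton = any(s in url.scheme for s in ('http', 'grpc')) and bool(url.netloc)
    print(p, looks_like_triton)  # False for the local file, True for both server addresses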
@@ -34,6 +34,9 @@ seaborn>=0.11.0
 # tensorflowjs>=3.9.0  # TF.js export
 # openvino-dev  # OpenVINO export

+# Deploy --------------------------------------
+# tritonclient[all]~=2.24.0
+
 # Extras --------------------------------------
 ipython  # interactive notebook
 psutil  # system utilization
...
@@ -114,7 +114,7 @@ def run(
     seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
     for path, im, im0s, vid_cap, s in dataset:
         with dt[0]:
-            im = torch.from_numpy(im).to(device)
+            im = torch.from_numpy(im).to(model.device)
             im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
             im /= 255  # 0 - 255 to 0.0 - 1.0
             if len(im.shape) == 3:
...
@@ -16,13 +16,13 @@ import requests
 import torch


-def is_url(url, check_exists=True):
+def is_url(url, check=True):
     # Check if string is URL and check if URL exists
     try:
         url = str(url)
         result = urllib.parse.urlparse(url)
         assert all([result.scheme, result.netloc, result.path])  # check if is url
-        return (urllib.request.urlopen(url).getcode() == 200) if check_exists else True  # check if exists online
+        return (urllib.request.urlopen(url).getcode() == 200) if check else True  # check if exists online
     except (AssertionError, urllib.request.HTTPError):
         return False
...
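The rename from check_exists to check does not change behavior; with check=False the function only validates the URL shape and never opens a connection. A small illustration (example inputs only, not from the diff):

from utils.downloads import is_url

print(is_url('https://ultralytics.com/images/zidane.jpg', check=False))  # True, no network request is made
print(is_url('yolov5s.pt', check=False))                                 # False, plain file path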
utils/triton.py (new file):

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
""" Utils to interact with the Triton Inference Server
"""

import typing
from urllib.parse import urlparse

import torch


class TritonRemoteModel:
    """ A wrapper over a model served by the Triton Inference Server. It can
    be configured to communicate over GRPC or HTTP. It accepts Torch Tensors
    as input and returns them as outputs.
    """

    def __init__(self, url: str):
        """
        Keyword arguments:
        url: Fully qualified address of the Triton server - for e.g. grpc://localhost:8000
        """
        parsed_url = urlparse(url)
        if parsed_url.scheme == "grpc":
            from tritonclient.grpc import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton GRPC client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository.models[0].name
            self.metadata = self.client.get_model_metadata(self.model_name, as_json=True)

            def create_input_placeholders() -> typing.List[InferInput]:
                return [
                    InferInput(i['name'], [int(s) for s in i["shape"]], i['datatype']) for i in self.metadata['inputs']]

        else:
            from tritonclient.http import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton HTTP client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository[0]['name']
            self.metadata = self.client.get_model_metadata(self.model_name)

            def create_input_placeholders() -> typing.List[InferInput]:
                return [
                    InferInput(i['name'], [int(s) for s in i["shape"]], i['datatype']) for i in self.metadata['inputs']]

        self._create_input_placeholders_fn = create_input_placeholders

    @property
    def runtime(self):
        """Returns the model runtime"""
        return self.metadata.get("backend", self.metadata.get("platform"))

    def __call__(self, *args, **kwargs) -> typing.Union[torch.Tensor, typing.Tuple[torch.Tensor, ...]]:
        """ Invokes the model. Parameters can be provided via args or kwargs.
        args, if provided, are assumed to match the order of inputs of the model.
        kwargs are matched with the model input names.
        """
        inputs = self._create_inputs(*args, **kwargs)
        response = self.client.infer(model_name=self.model_name, inputs=inputs)
        result = []
        for output in self.metadata['outputs']:
            tensor = torch.as_tensor(response.as_numpy(output['name']))
            result.append(tensor)
        return result[0] if len(result) == 1 else result

    def _create_inputs(self, *args, **kwargs):
        args_len, kwargs_len = len(args), len(kwargs)
        if not args_len and not kwargs_len:
            raise RuntimeError("No inputs provided.")
        if args_len and kwargs_len:
            raise RuntimeError("Cannot specify args and kwargs at the same time")

        placeholders = self._create_input_placeholders_fn()
        if args_len:
            if args_len != len(placeholders):
                raise RuntimeError(f"Expected {len(placeholders)} inputs, got {args_len}.")
            for input, value in zip(placeholders, args):
                input.set_data_from_numpy(value.cpu().numpy())
        else:
            for input in placeholders:
                value = kwargs[input.name]
                input.set_data_from_numpy(value.cpu().numpy())
        return placeholders
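A short usage sketch for the new wrapper, assuming a reachable Triton server with a single YOLOv5 model loaded (the address and input shape below are placeholders):

import torch

from utils.triton import TritonRemoteModel

model = TritonRemoteModel(url='grpc://localhost:8001')
print(model.runtime)                    # e.g. 'onnxruntime' or 'tensorrt', read from the model metadata
y = model(torch.zeros(1, 3, 640, 640))  # positional args are matched to the model inputs in order
print(y.shape if isinstance(y, torch.Tensor) else [t.shape for t in y])

This is the object DetectMultiBackend stores as self.model when triton is True, which is why the forward path above simply calls self.model(im).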