Unverified 提交 2e538443 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

ONNX inference update (#4073)

上级 39ef6c7a
...@@ -64,18 +64,23 @@ def run(weights='yolov5s.pt', # model.pt path(s) ...@@ -64,18 +64,23 @@ def run(weights='yolov5s.pt', # model.pt path(s)
half &= device.type != 'cpu' # half precision only supported on CUDA half &= device.type != 'cpu' # half precision only supported on CUDA
# Load model # Load model
model = attempt_load(weights, map_location=device) # load FP32 model w = weights[0] if isinstance(weights, list) else weights
stride = int(model.stride.max()) # model stride classify, pt, onnx = False, w.endswith('.pt'), w.endswith('.onnx') # inference type
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
if pt:
model = attempt_load(weights, map_location=device) # load FP32 model
stride = int(model.stride.max()) # model stride
names = model.module.names if hasattr(model, 'module') else model.names # get class names
if half:
model.half() # to FP16
if classify: # second-stage classifier
modelc = load_classifier(name='resnet50', n=2) # initialize
modelc.load_state_dict(torch.load('resnet50.pt', map_location=device)['model']).to(device).eval()
elif onnx:
check_requirements(('onnx', 'onnxruntime'))
import onnxruntime
session = onnxruntime.InferenceSession(w, None)
imgsz = check_img_size(imgsz, s=stride) # check image size imgsz = check_img_size(imgsz, s=stride) # check image size
names = model.module.names if hasattr(model, 'module') else model.names # get class names
if half:
model.half() # to FP16
# Second-stage classifier
classify = False
if classify:
modelc = load_classifier(name='resnet50', n=2) # initialize
modelc.load_state_dict(torch.load('resnet50.pt', map_location=device)['model']).to(device).eval()
# Dataloader # Dataloader
if webcam: if webcam:
...@@ -89,31 +94,36 @@ def run(weights='yolov5s.pt', # model.pt path(s) ...@@ -89,31 +94,36 @@ def run(weights='yolov5s.pt', # model.pt path(s)
vid_path, vid_writer = [None] * bs, [None] * bs vid_path, vid_writer = [None] * bs, [None] * bs
# Run inference # Run inference
if device.type != 'cpu': if pt and device.type != 'cpu':
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
t0 = time.time() t0 = time.time()
for path, img, im0s, vid_cap in dataset: for path, img, im0s, vid_cap in dataset:
img = torch.from_numpy(img).to(device) if pt:
img = img.half() if half else img.float() # uint8 to fp16/32 img = torch.from_numpy(img).to(device)
img = img.half() if half else img.float() # uint8 to fp16/32
elif onnx:
img = img.astype('float32')
img /= 255.0 # 0 - 255 to 0.0 - 1.0 img /= 255.0 # 0 - 255 to 0.0 - 1.0
if img.ndimension() == 3: if len(img.shape) == 3:
img = img.unsqueeze(0) img = img[None] # expand for batch dim
# Inference # Inference
t1 = time_sync() t1 = time_sync()
pred = model(img, if pt:
augment=augment, visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
visualize=increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False)[0] pred = model(img, augment=augment, visualize=visualize)[0]
elif onnx:
pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img}))
# Apply NMS # NMS
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det) pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
t2 = time_sync() t2 = time_sync()
# Apply Classifier # Second-stage classifier (optional)
if classify: if classify:
pred = apply_classifier(pred, modelc, img, im0s) pred = apply_classifier(pred, modelc, img, im0s)
# Process detections # Process predictions
for i, det in enumerate(pred): # detections per image for i, det in enumerate(pred): # detections per image
if webcam: # batch_size >= 1 if webcam: # batch_size >= 1
p, s, im0, frame = path[i], f'{i}: ', im0s[i].copy(), dataset.count p, s, im0, frame = path[i], f'{i}: ', im0s[i].copy(), dataset.count
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论