Unverified 提交 2373d547 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

NMS MPS device wrapper (#9620)

* NMS MPS device wrapper May resolve https://github.com/ultralytics/yolov5/issues/9613Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> * Update general.py Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 7314363f
......@@ -843,7 +843,9 @@ def non_max_suppression(
if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out)
prediction = prediction[0] # select only inference output
if 'mps' in prediction.device.type: # MPS not fully supported yet, convert tensors to CPU before NMS
device = prediction.device
mps = 'mps' in device.type # Apple MPS
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
prediction = prediction.cpu()
bs = prediction.shape[0] # batch size
nc = prediction.shape[2] - nm - 5 # number of classes
......@@ -930,6 +932,8 @@ def non_max_suppression(
i = i[iou.sum(1) > 1] # require redundancy
output[xi] = x[i]
if mps:
output[xi] = output[xi].to(device)
if (time.time() - t) > time_limit:
LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
break # time limit exceeded
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论