Unverified 提交 9776e709 authored 作者: rafale77's avatar rafale77 提交者: GitHub

torch.ops.torchvision.nms (#860)

Don't load the entire torchvision library just for nms when the function is already in the torch library.
上级 0c01afc3
...@@ -17,7 +17,6 @@ import matplotlib.pyplot as plt ...@@ -17,7 +17,6 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision
import yaml import yaml
from scipy.cluster.vq import kmeans from scipy.cluster.vq import kmeans
from scipy.signal import butter, filtfilt from scipy.signal import butter, filtfilt
...@@ -651,7 +650,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, ...@@ -651,7 +650,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False,
# Batched NMS # Batched NMS
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) i = torch.ops.torchvision.nms(boxes, scores, iou_thres)
if i.shape[0] > max_det: # limit detections if i.shape[0] > max_det: # limit detections
i = i[:max_det] i = i[:max_det]
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean) if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论