Unverified 提交 53bfcbe0 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Update AP calculation (#4260)

* Update AP calculation * Cleanup * Remove original
上级 cd540d86
......@@ -50,26 +50,27 @@ def save_one_json(predn, jdict, path, class_map):
'score': round(p[4], 5)})
def process_batch(predictions, labels, iouv):
# Evaluate 1 batch of predictions
correct = torch.zeros(predictions.shape[0], len(iouv), dtype=torch.bool, device=iouv.device)
detected = [] # label indices
tcls, pcls = labels[:, 0], predictions[:, 5]
nl = labels.shape[0] # number of labels
for cls in torch.unique(tcls):
ti = (cls == tcls).nonzero().view(-1) # label indices
pi = (cls == pcls).nonzero().view(-1) # prediction indices
if pi.shape[0]: # find detections
ious, i = box_iou(predictions[pi, 0:4], labels[ti, 1:5]).max(1) # best ious, indices
detected_set = set()
for j in (ious > iouv[0]).nonzero():
d = ti[i[j]] # detected label
if d.item() not in detected_set:
detected_set.add(d.item())
detected.append(d) # append detections
correct[pi[j]] = ious[j] > iouv # iou_thres is 1xn
if len(detected) == nl: # all labels already located in image
break
def process_batch(detections, labels, iouv):
"""
Return correct predictions matrix. Both sets of boxes are in (x1, y1, x2, y2) format.
Arguments:
detections (Array[N, 6]), x1, y1, x2, y2, conf, class
labels (Array[M, 5]), class, x1, y1, x2, y2
Returns:
correct (Array[N, 10]), for 10 IoU levels
"""
correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
iou = box_iou(labels[:, 1:], detections[:, :4])
x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5])) # IoU above threshold and classes match
if x[0].shape[0]:
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detection, iou]
if x[0].shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
# matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
matches = torch.Tensor(matches).to(iouv.device)
correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv
return correct
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论