Unverified 提交 53bfcbe0 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Update AP calculation (#4260)

* Update AP calculation * Cleanup * Remove original
上级 cd540d86
...@@ -50,26 +50,27 @@ def save_one_json(predn, jdict, path, class_map): ...@@ -50,26 +50,27 @@ def save_one_json(predn, jdict, path, class_map):
'score': round(p[4], 5)}) 'score': round(p[4], 5)})
def process_batch(predictions, labels, iouv): def process_batch(detections, labels, iouv):
# Evaluate 1 batch of predictions """
correct = torch.zeros(predictions.shape[0], len(iouv), dtype=torch.bool, device=iouv.device) Return correct predictions matrix. Both sets of boxes are in (x1, y1, x2, y2) format.
detected = [] # label indices Arguments:
tcls, pcls = labels[:, 0], predictions[:, 5] detections (Array[N, 6]), x1, y1, x2, y2, conf, class
nl = labels.shape[0] # number of labels labels (Array[M, 5]), class, x1, y1, x2, y2
for cls in torch.unique(tcls): Returns:
ti = (cls == tcls).nonzero().view(-1) # label indices correct (Array[N, 10]), for 10 IoU levels
pi = (cls == pcls).nonzero().view(-1) # prediction indices """
if pi.shape[0]: # find detections correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
ious, i = box_iou(predictions[pi, 0:4], labels[ti, 1:5]).max(1) # best ious, indices iou = box_iou(labels[:, 1:], detections[:, :4])
detected_set = set() x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5])) # IoU above threshold and classes match
for j in (ious > iouv[0]).nonzero(): if x[0].shape[0]:
d = ti[i[j]] # detected label matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detection, iou]
if d.item() not in detected_set: if x[0].shape[0] > 1:
detected_set.add(d.item()) matches = matches[matches[:, 2].argsort()[::-1]]
detected.append(d) # append detections matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
correct[pi[j]] = ious[j] > iouv # iou_thres is 1xn # matches = matches[matches[:, 2].argsort()[::-1]]
if len(detected) == nl: # all labels already located in image matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
break matches = torch.Tensor(matches).to(iouv.device)
correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv
return correct return correct
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论