Unverified commit f639e14e, authored by Glenn Jocher, committed by GitHub

Metric-Confidence plots feature addition (#2057)

* Metric-Confidence plots feature addition
* cleanup
* Metric-Confidence plots feature addition
* cleanup
* Update run-once lines
* cleanup
* save all 4 curves to wandb
Parent: 2a835c79
test.py
@@ -215,7 +215,7 @@ def test(data,
     stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
     if len(stats) and stats[0].any():
         p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
-        p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1)  # [P, R, AP@0.5, AP@0.5:0.95]
+        ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
         mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
         nt = np.bincount(stats[3].astype(np.int64), minlength=nc)  # number of targets per class
     else:
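Reviewer note: the slicing change above follows from a new return contract in ap_per_class (see the utils/metrics.py hunks below): p, r and f1 now come back already evaluated at the single confidence that maximizes mean F1 across classes, rather than as full matrices sliced by the caller. A minimal sketch of that selection with synthetic arrays (all values made up):

import numpy as np

nc, npts = 3, 1000                      # classes, confidence grid points
p = np.random.rand(nc, npts)            # stand-in precision curves
r = np.random.rand(nc, npts)            # stand-in recall curves
f1 = 2 * p * r / (p + r + 1e-16)        # harmonic mean, as in metrics.py
i = f1.mean(0).argmax()                 # confidence index with the best mean F1
p_best, r_best, f1_best = p[:, i], r[:, i], f1[:, i]  # each has shape (nc,)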
train.py
@@ -403,7 +403,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     if plots:
         plot_results(save_dir=save_dir)  # save as results.png
         if wandb:
-            files = ['results.png', 'precision_recall_curve.png', 'confusion_matrix.png']
+            files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
             wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
                                    if (save_dir / f).exists()]})
     if opt.log_artifacts:
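For reference, the new comprehension expands to the two fixed images plus the four curve plots; a quick REPL check:

>>> ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
['results.png', 'confusion_matrix.png', 'F1_curve.png', 'PR_curve.png', 'P_curve.png', 'R_curve.png']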
utils/metrics.py
@@ -15,7 +15,7 @@ def fitness(x):
     return (x[:, :4] * w).sum(1)


-def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision-recall_curve.png', names=[]):
+def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()):
     """ Compute the average precision, given the recall and precision curves.
     Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
     # Arguments
@@ -35,12 +35,11 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision
     # Find unique classes
     unique_classes = np.unique(target_cls)
+    nc = unique_classes.shape[0]  # number of classes, number of detections

     # Create Precision-Recall curve and compute AP for each class
     px, py = np.linspace(0, 1, 1000), []  # for plotting
-    pr_score = 0.1  # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
-    s = [unique_classes.shape[0], tp.shape[1]]  # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95)
-    ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s)
+    ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
     for ci, c in enumerate(unique_classes):
         i = pred_cls == c
         n_l = (target_cls == c).sum()  # number of labels
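A sketch of the array shapes this hunk sets up; the class and IoU counts are illustrative (e.g. 80 COCO classes, 10 IoU thresholds), not from the commit:

import numpy as np

nc, niou = 80, 10                  # illustrative: classes, IoU thresholds 0.5:0.95
px = np.linspace(0, 1, 1000)       # confidence grid used for plotting
ap = np.zeros((nc, niou))          # one AP per class per IoU threshold
p = np.zeros((nc, 1000))           # full precision curve per class over px
r = np.zeros((nc, 1000))           # full recall curve per class over px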
@@ -55,25 +54,28 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision
             # Recall
             recall = tpc / (n_l + 1e-16)  # recall curve
-            r[ci] = np.interp(-pr_score, -conf[i], recall[:, 0])  # r at pr_score, negative x, xp because xp decreases
+            r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0)  # negative x, xp because xp decreases

             # Precision
             precision = tpc / (tpc + fpc)  # precision curve
-            p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0])  # p at pr_score
+            p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score

             # AP from recall-precision curve
             for j in range(tp.shape[1]):
                 ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
-                if plot and (j == 0):
+                if plot and j == 0:
                     py.append(np.interp(px, mrec, mpre))  # precision at mAP@0.5

-    # Compute F1 score (harmonic mean of precision and recall)
+    # Compute F1 (harmonic mean of precision and recall)
     f1 = 2 * p * r / (p + r + 1e-16)
     if plot:
-        plot_pr_curve(px, py, ap, save_dir, names)
+        plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names)
+        plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1')
+        plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision')
+        plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')

-    return p, r, ap, f1, unique_classes.astype('int32')
+    i = f1.mean(0).argmax()  # max F1 index
+    return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32')


 def compute_ap(recall, precision):
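Why both interp arguments are negated: np.interp requires ascending xp, but detection confidences arrive sorted in descending order, so negating x and xp restores ascending order; left=0 (left=1 for precision) fixes the curve's value at thresholds above the highest observed confidence. A self-contained check with made-up numbers:

import numpy as np

conf = np.array([0.9, 0.6, 0.3])      # detection confidences, descending
recall = np.array([0.2, 0.5, 0.8])    # cumulative recall at each confidence
px = np.linspace(0, 1, 5)             # coarse confidence grid for the demo

# negate x and xp so xp ascends; left=0 gives zero recall above the top confidence
r_curve = np.interp(-px, -conf, recall, left=0)
print(r_curve)                        # [0.8  0.8  0.6  0.35 0.  ]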
@@ -181,13 +183,14 @@ class ConfusionMatrix:
 # Plots ----------------------------------------------------------------------------------------------------------------

-def plot_pr_curve(px, py, ap, save_dir='.', names=()):
+def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()):
+    # Precision-recall curve
     fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
     py = np.stack(py, axis=1)

-    if 0 < len(names) < 21:  # show mAP in legend if < 10 classes
+    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
         for i, y in enumerate(py.T):
-            ax.plot(px, y, linewidth=1, label=f'{names[i]} %.3f' % ap[i, 0])  # plot(recall, precision)
+            ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}')  # plot(recall, precision)
     else:
         ax.plot(px, py, linewidth=1, color='grey')  # plot(recall, precision)
@@ -197,4 +200,24 @@ def plot_pr_curve(px, py, ap, save_dir='.', names=()):
     ax.set_xlim(0, 1)
     ax.set_ylim(0, 1)
     plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
-    fig.savefig(Path(save_dir) / 'precision_recall_curve.png', dpi=250)
+    fig.savefig(Path(save_dir), dpi=250)
+
+
+def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'):
+    # Metric-confidence curve
+    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
+
+    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
+        for i, y in enumerate(py):
+            ax.plot(px, y, linewidth=1, label=f'{names[i]}')  # plot(confidence, metric)
+    else:
+        ax.plot(px, py.T, linewidth=1, color='grey')  # plot(confidence, metric)
+
+    y = py.mean(0)
+    ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel(ylabel)
+    ax.set_xlim(0, 1)
+    ax.set_ylim(0, 1)
+    plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
+    fig.savefig(Path(save_dir), dpi=250)
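The new helper can be exercised in isolation with synthetic curves; a minimal sketch, assuming utils/metrics.py from this commit is importable and with made-up class names:

import numpy as np
from utils.metrics import plot_mc_curve

px = np.linspace(0, 1, 1000)                        # confidence grid
py = np.stack([np.clip(1.1 - px - 0.1 * k, 0, 1)    # one synthetic metric curve
               for k in range(3)])                  # per class -> shape (3, 1000)
plot_mc_curve(px, py, 'F1_curve.png', names=('car', 'bus', 'truck'), ylabel='F1')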