Unverified 提交 d2e698c7 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Reduce val device transfers (#7525)

上级 23718df1
...@@ -220,14 +220,14 @@ def run( ...@@ -220,14 +220,14 @@ def run(
# Metrics # Metrics
for si, pred in enumerate(out): for si, pred in enumerate(out):
labels = targets[targets[:, 0] == si, 1:] labels = targets[targets[:, 0] == si, 1:]
nl = len(labels) nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
tcls = labels[:, 0].tolist() if nl else [] # target class
path, shape = Path(paths[si]), shapes[si][0] path, shape = Path(paths[si]), shapes[si][0]
correct = torch.zeros(npr, niou, dtype=torch.bool, device=device) # init
seen += 1 seen += 1
if len(pred) == 0: if npr == 0:
if nl: if nl:
stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls)) stats.append((correct, *torch.zeros((3, 0))))
continue continue
# Predictions # Predictions
...@@ -244,9 +244,7 @@ def run( ...@@ -244,9 +244,7 @@ def run(
correct = process_batch(predn, labelsn, iouv) correct = process_batch(predn, labelsn, iouv)
if plots: if plots:
confusion_matrix.process_batch(predn, labelsn) confusion_matrix.process_batch(predn, labelsn)
else: stats.append((correct, pred[:, 4], pred[:, 5], labels[:, 0])) # (correct, conf, pcls, tcls)
correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool)
stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) # (correct, conf, pcls, tcls)
# Save/log # Save/log
if save_txt: if save_txt:
...@@ -265,7 +263,7 @@ def run( ...@@ -265,7 +263,7 @@ def run(
callbacks.run('on_val_batch_end') callbacks.run('on_val_batch_end')
# Compute metrics # Compute metrics
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] # to numpy
if len(stats) and stats[0].any(): if len(stats) and stats[0].any():
tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names) tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95 ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论