Unverified 提交 f17c86b7 authored 作者: Zengyf-CVer's avatar Zengyf-CVer 提交者: GitHub

Save *.npy features on detect.py `--visualize` (#5701)

* Add feature map to save npy files Add feature map to save npy files,export npy files with 32 feature maps per layer. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update plots.py * Update plots.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update plots.py Co-authored-by: 's avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 d6ae1c83
...@@ -1104,4 +1104,4 @@ ...@@ -1104,4 +1104,4 @@
"outputs": [] "outputs": []
} }
] ]
} }
\ No newline at end of file
...@@ -132,7 +132,7 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec ...@@ -132,7 +132,7 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec
if 'Detect' not in module_type: if 'Detect' not in module_type:
batch, channels, height, width = x.shape # batch, channels, height, width batch, channels, height, width = x.shape # batch, channels, height, width
if height > 1 and width > 1: if height > 1 and width > 1:
f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
n = min(n, channels) # number of plots n = min(n, channels) # number of plots
...@@ -143,9 +143,10 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec ...@@ -143,9 +143,10 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec
ax[i].imshow(blocks[i].squeeze()) # cmap='gray' ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
ax[i].axis('off') ax[i].axis('off')
print(f'Saving {save_dir / f}... ({n}/{channels})') print(f'Saving {f}... ({n}/{channels})')
plt.savefig(save_dir / f, dpi=300, bbox_inches='tight') plt.savefig(f, dpi=300, bbox_inches='tight')
plt.close() plt.close()
np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save
def hist2d(x, y, n=100): def hist2d(x, y, n=100):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论