Unverified 提交 02719dde authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Update `feature_visualization()` (#3807)

* Update `feature_visualization()` Only plot for data with height, width > 1 * cleanup * Cleanup
上级 20d45aa4
...@@ -448,26 +448,28 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): ...@@ -448,26 +448,28 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
fig.savefig(Path(save_dir) / 'results.png', dpi=200) fig.savefig(Path(save_dir) / 'results.png', dpi=200)
def feature_visualization(features, module_type, module_idx, n=64): def feature_visualization(x, module_type, stage, n=64):
""" """
features: Features to be visualized x: Features to be visualized
module_type: Module type module_type: Module type
module_idx: Module layer index within model stage: Module stage within model
n: Maximum number of feature maps to plot n: Maximum number of feature maps to plot
""" """
project, name = 'runs/features', 'exp' batch, channels, height, width = x.shape # batch, channels, height, width
save_dir = increment_path(Path(project) / name) # increment run if height > 1 and width > 1:
save_dir.mkdir(parents=True, exist_ok=True) # make dir project, name = 'runs/features', 'exp'
save_dir = increment_path(Path(project) / name) # increment run
plt.figure(tight_layout=True) save_dir.mkdir(parents=True, exist_ok=True) # make dir
blocks = torch.chunk(features, features.shape[1], dim=1) # block by channel dimension
n = min(n, len(blocks)) plt.figure(tight_layout=True)
for i in range(n): blocks = torch.chunk(x, channels, dim=1) # block by channel dimension
feature = transforms.ToPILImage()(blocks[i].squeeze()) n = min(n, len(blocks))
ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1) for i in range(n):
ax.axis('off') feature = transforms.ToPILImage()(blocks[i].squeeze())
plt.imshow(feature) # cmap='gray' ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1)
ax.axis('off')
f = f"layer_{module_idx}_{module_type.split('.')[-1]}_features.png" plt.imshow(feature) # cmap='gray'
print(f'Saving {save_dir / f}...')
plt.savefig(save_dir / f, dpi=300) f = f"stage_{stage}_{module_type.split('.')[-1]}_features.png"
print(f'Saving {save_dir / f}...')
plt.savefig(save_dir / f, dpi=300)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论