Unverified 提交 87b094bc authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Feature visualization update (#3920)

* Feature visualization update * Save to jpg (faster) * Save to png
上级 61047a2b
......@@ -40,6 +40,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
classes=None, # filter by class: --class 0, or --class 0 2 3
agnostic_nms=False, # class-agnostic NMS
augment=False, # augmented inference
visualize=False, # visualize features
update=False, # update all models
project='runs/detect', # save results to project/name
name='exp', # save results to project/name
......@@ -100,7 +101,9 @@ def run(weights='yolov5s.pt', # model.pt path(s)
# Inference
t1 = time_synchronized()
pred = model(img, augment=augment)[0]
pred = model(img,
augment=augment,
visualize=increment_path(save_dir / 'features', mkdir=True) if visualize else False)[0]
# Apply NMS
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
......@@ -201,6 +204,7 @@ def parse_opt():
parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--visualize', action='store_true', help='visualize features')
parser.add_argument('--update', action='store_true', help='update all models')
parser.add_argument('--project', default='runs/detect', help='save results to project/name')
parser.add_argument('--name', default='exp', help='save results to project/name')
......
......@@ -117,11 +117,10 @@ class Model(nn.Module):
self.info()
logger.info('')
def forward(self, x, augment=False, profile=False):
def forward(self, x, augment=False, profile=False, visualize=False):
if augment:
return self.forward_augment(x) # augmented inference, None
else:
return self.forward_once(x, profile) # single-scale inference, train
return self.forward_once(x, profile, visualize) # single-scale inference, train
def forward_augment(self, x):
img_size = x.shape[-2:] # height, width
......@@ -136,7 +135,7 @@ class Model(nn.Module):
y.append(yi)
return torch.cat(y, 1), None # augmented inference, train
def forward_once(self, x, profile=False, feature_vis=False):
def forward_once(self, x, profile=False, visualize=False):
y, dt = [], [] # outputs
for m in self.model:
if m.f != -1: # if not from previous layer
......@@ -155,8 +154,8 @@ class Model(nn.Module):
x = m(x) # run
y.append(x if m.i in self.save else None) # save output
if feature_vis and m.type == 'models.common.SPP':
feature_visualization(x, m.type, m.i)
if visualize:
feature_visualization(x, m.type, m.i, save_dir=visualize)
if profile:
logger.info('%.1fms total' % sum(dt))
......
# Plotting utils
import glob
import math
import os
from copy import copy
from pathlib import Path
import cv2
import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
......@@ -15,7 +15,6 @@ import seaborn as sn
import torch
import yaml
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
from utils.general import increment_path, xywh2xyxy, xyxy2xywh
from utils.metrics import fitness
......@@ -448,28 +447,26 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
fig.savefig(Path(save_dir) / 'results.png', dpi=200)
def feature_visualization(x, module_type, stage, n=64):
def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detect/exp')):
"""
x: Features to be visualized
module_type: Module type
stage: Module stage within model
n: Maximum number of feature maps to plot
save_dir: Directory to save results
"""
batch, channels, height, width = x.shape # batch, channels, height, width
if height > 1 and width > 1:
project, name = 'runs/features', 'exp'
save_dir = increment_path(Path(project) / name) # increment run
save_dir.mkdir(parents=True, exist_ok=True) # make dir
plt.figure(tight_layout=True)
blocks = torch.chunk(x, channels, dim=1) # block by channel dimension
n = min(n, len(blocks))
for i in range(n):
feature = transforms.ToPILImage()(blocks[i].squeeze())
ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1)
ax.axis('off')
plt.imshow(feature) # cmap='gray'
f = f"stage_{stage}_{module_type.split('.')[-1]}_features.png"
print(f'Saving {save_dir / f}...')
plt.savefig(save_dir / f, dpi=300)
if 'Detect' not in module_type:
batch, channels, height, width = x.shape # batch, channels, height, width
if height > 1 and width > 1:
f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
plt.figure(tight_layout=True)
blocks = torch.chunk(x[0], channels, dim=0) # select batch index 0, block by channels
n = min(n, channels) # number of plots
ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)[1].ravel() # 8 rows x n/8 cols
for i in range(n):
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
ax[i].axis('off')
print(f'Saving {save_dir / f}... ({n}/{channels})')
plt.savefig(save_dir / f, dpi=300)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论