Unverified 提交 87b094bc authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Feature visualization update (#3920)

* Feature visualization update * Save to jpg (faster) * Save to png
上级 61047a2b
...@@ -40,6 +40,7 @@ def run(weights='yolov5s.pt', # model.pt path(s) ...@@ -40,6 +40,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
classes=None, # filter by class: --class 0, or --class 0 2 3 classes=None, # filter by class: --class 0, or --class 0 2 3
agnostic_nms=False, # class-agnostic NMS agnostic_nms=False, # class-agnostic NMS
augment=False, # augmented inference augment=False, # augmented inference
visualize=False, # visualize features
update=False, # update all models update=False, # update all models
project='runs/detect', # save results to project/name project='runs/detect', # save results to project/name
name='exp', # save results to project/name name='exp', # save results to project/name
...@@ -100,7 +101,9 @@ def run(weights='yolov5s.pt', # model.pt path(s) ...@@ -100,7 +101,9 @@ def run(weights='yolov5s.pt', # model.pt path(s)
# Inference # Inference
t1 = time_synchronized() t1 = time_synchronized()
pred = model(img, augment=augment)[0] pred = model(img,
augment=augment,
visualize=increment_path(save_dir / 'features', mkdir=True) if visualize else False)[0]
# Apply NMS # Apply NMS
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det) pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
...@@ -201,6 +204,7 @@ def parse_opt(): ...@@ -201,6 +204,7 @@ def parse_opt():
parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3') parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
parser.add_argument('--augment', action='store_true', help='augmented inference') parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--visualize', action='store_true', help='visualize features')
parser.add_argument('--update', action='store_true', help='update all models') parser.add_argument('--update', action='store_true', help='update all models')
parser.add_argument('--project', default='runs/detect', help='save results to project/name') parser.add_argument('--project', default='runs/detect', help='save results to project/name')
parser.add_argument('--name', default='exp', help='save results to project/name') parser.add_argument('--name', default='exp', help='save results to project/name')
......
...@@ -117,11 +117,10 @@ class Model(nn.Module): ...@@ -117,11 +117,10 @@ class Model(nn.Module):
self.info() self.info()
logger.info('') logger.info('')
def forward(self, x, augment=False, profile=False): def forward(self, x, augment=False, profile=False, visualize=False):
if augment: if augment:
return self.forward_augment(x) # augmented inference, None return self.forward_augment(x) # augmented inference, None
else: return self.forward_once(x, profile, visualize) # single-scale inference, train
return self.forward_once(x, profile) # single-scale inference, train
def forward_augment(self, x): def forward_augment(self, x):
img_size = x.shape[-2:] # height, width img_size = x.shape[-2:] # height, width
...@@ -136,7 +135,7 @@ class Model(nn.Module): ...@@ -136,7 +135,7 @@ class Model(nn.Module):
y.append(yi) y.append(yi)
return torch.cat(y, 1), None # augmented inference, train return torch.cat(y, 1), None # augmented inference, train
def forward_once(self, x, profile=False, feature_vis=False): def forward_once(self, x, profile=False, visualize=False):
y, dt = [], [] # outputs y, dt = [], [] # outputs
for m in self.model: for m in self.model:
if m.f != -1: # if not from previous layer if m.f != -1: # if not from previous layer
...@@ -155,8 +154,8 @@ class Model(nn.Module): ...@@ -155,8 +154,8 @@ class Model(nn.Module):
x = m(x) # run x = m(x) # run
y.append(x if m.i in self.save else None) # save output y.append(x if m.i in self.save else None) # save output
if feature_vis and m.type == 'models.common.SPP': if visualize:
feature_visualization(x, m.type, m.i) feature_visualization(x, m.type, m.i, save_dir=visualize)
if profile: if profile:
logger.info('%.1fms total' % sum(dt)) logger.info('%.1fms total' % sum(dt))
......
# Plotting utils # Plotting utils
import glob import glob
import math
import os import os
from copy import copy from copy import copy
from pathlib import Path from pathlib import Path
import cv2 import cv2
import math
import matplotlib import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
...@@ -15,7 +15,6 @@ import seaborn as sn ...@@ -15,7 +15,6 @@ import seaborn as sn
import torch import torch
import yaml import yaml
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
from utils.general import increment_path, xywh2xyxy, xyxy2xywh from utils.general import increment_path, xywh2xyxy, xyxy2xywh
from utils.metrics import fitness from utils.metrics import fitness
...@@ -448,28 +447,26 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): ...@@ -448,28 +447,26 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
fig.savefig(Path(save_dir) / 'results.png', dpi=200) fig.savefig(Path(save_dir) / 'results.png', dpi=200)
def feature_visualization(x, module_type, stage, n=64): def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detect/exp')):
""" """
x: Features to be visualized x: Features to be visualized
module_type: Module type module_type: Module type
stage: Module stage within model stage: Module stage within model
n: Maximum number of feature maps to plot n: Maximum number of feature maps to plot
save_dir: Directory to save results
""" """
batch, channels, height, width = x.shape # batch, channels, height, width if 'Detect' not in module_type:
if height > 1 and width > 1: batch, channels, height, width = x.shape # batch, channels, height, width
project, name = 'runs/features', 'exp' if height > 1 and width > 1:
save_dir = increment_path(Path(project) / name) # increment run f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
save_dir.mkdir(parents=True, exist_ok=True) # make dir
plt.figure(tight_layout=True)
plt.figure(tight_layout=True) blocks = torch.chunk(x[0], channels, dim=0) # select batch index 0, block by channels
blocks = torch.chunk(x, channels, dim=1) # block by channel dimension n = min(n, channels) # number of plots
n = min(n, len(blocks)) ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)[1].ravel() # 8 rows x n/8 cols
for i in range(n): for i in range(n):
feature = transforms.ToPILImage()(blocks[i].squeeze()) ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1) ax[i].axis('off')
ax.axis('off')
plt.imshow(feature) # cmap='gray' print(f'Saving {save_dir / f}... ({n}/{channels})')
plt.savefig(save_dir / f, dpi=300)
f = f"stage_{stage}_{module_type.split('.')[-1]}_features.png"
print(f'Saving {save_dir / f}...')
plt.savefig(save_dir / f, dpi=300)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论