Unverified commit f7d85620, authored by Glenn Jocher, committed by GitHub

`val.py` refactor (#4053)

* val.py refactor
* cleanup
* cleanup
* cleanup
* cleanup
* save after eval
* opt.imgsz bug fix
* wandb refactor
* dataloader to train_loader
* capitalize global variables
* runs/hub/exp to runs/detect/exp
* refactor wandb logging
* Refactor wandb operations (#4061)

Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>

Parent 9dd33fd2
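Several hunks below replace bare `print()` calls and a lowercase module `logger` with a capitalized module-level `LOGGER`. For reference, a minimal sketch of that pattern (standard-library `logging` only, matching the imports the hunks add; the function body here is illustrative, borrowed from the `fuse` message in the diff):

```python
import logging

LOGGER = logging.getLogger(__name__)  # module-level constant, named after the module

def fuse():
    # library code logs through LOGGER rather than print(), so applications
    # embedding YOLOv5 can silence or redirect its output via logging config
    LOGGER.info('Fusing layers... ')
```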
@@ -21,7 +21,7 @@ from utils.datasets import LoadStreams, LoadImages
 from utils.general import check_img_size, check_requirements, check_imshow, colorstr, non_max_suppression, \
     apply_classifier, scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
 from utils.plots import colors, plot_one_box
-from utils.torch_utils import select_device, load_classifier, time_synchronized
+from utils.torch_utils import select_device, load_classifier, time_sync
 @torch.no_grad()
@@ -100,14 +100,14 @@ def run(weights='yolov5s.pt',  # model.pt path(s)
             img = img.unsqueeze(0)
         # Inference
-        t1 = time_synchronized()
+        t1 = time_sync()
         pred = model(img,
                      augment=augment,
                      visualize=increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False)[0]
         # Apply NMS
         pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
-        t2 = time_synchronized()
+        t2 = time_sync()
         # Apply Classifier
         if classify:
...
 # YOLOv5 common modules
+import logging
 from copy import copy
 from pathlib import Path, PosixPath
@@ -15,7 +16,9 @@ from torch.cuda import amp
 from utils.datasets import exif_transpose, letterbox
 from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh, save_one_box
 from utils.plots import colors, plot_one_box
-from utils.torch_utils import time_synchronized
+from utils.torch_utils import time_sync
+LOGGER = logging.getLogger(__name__)
 def autopad(k, p=None):  # kernel, padding
@@ -226,7 +229,7 @@ class AutoShape(nn.Module):
         self.model = model.eval()
     def autoshape(self):
-        print('AutoShape already enabled, skipping... ')  # model already converted to model.autoshape()
+        LOGGER.info('AutoShape already enabled, skipping... ')  # model already converted to model.autoshape()
         return self
     @torch.no_grad()
@@ -240,7 +243,7 @@ class AutoShape(nn.Module):
         #   torch:      = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
         #   multiple:   = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
-        t = [time_synchronized()]
+        t = [time_sync()]
         p = next(self.model.parameters())  # for device and type
         if isinstance(imgs, torch.Tensor):  # torch
             with amp.autocast(enabled=p.device.type != 'cpu'):
@@ -270,19 +273,19 @@ class AutoShape(nn.Module):
             x = np.stack(x, 0) if n > 1 else x[0][None]  # stack
             x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
             x = torch.from_numpy(x).to(p.device).type_as(p) / 255.  # uint8 to fp16/32
-            t.append(time_synchronized())
+            t.append(time_sync())
         with amp.autocast(enabled=p.device.type != 'cpu'):
             # Inference
             y = self.model(x, augment, profile)[0]  # forward
-            t.append(time_synchronized())
+            t.append(time_sync())
             # Post-process
             y = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det)  # NMS
             for i in range(n):
                 scale_coords(shape1, y[i][:, :4], shape0[i])
-            t.append(time_synchronized())
+            t.append(time_sync())
             return Detections(imgs, y, files, t, self.names, x.shape)
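Here `t` holds four `time_sync()` timestamps: start, after pre-processing, after inference, after NMS. A worked sketch of how consecutive differences become the per-image stage times that `Detections.print()` reports below (timestamps and batch size are made up for illustration):

```python
t = [10.000, 10.012, 10.045, 10.051]  # hypothetical time_sync() readings (seconds)
n = 4                                 # hypothetical batch size
speeds = tuple((t[i + 1] - t[i]) * 1000 / n for i in range(3))  # ms per image per stage
print('Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image' % speeds)
# Speed: 3.0ms pre-process, 8.2ms inference, 1.5ms NMS per image
```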
@@ -323,31 +326,33 @@ class Detections:
             im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im  # from np
             if pprint:
-                print(str.rstrip(', '))
+                LOGGER.info(str.rstrip(', '))
             if show:
                 im.show(self.files[i])  # show
             if save:
                 f = self.files[i]
                 im.save(save_dir / f)  # save
-                print(f"{'Saved' * (i == 0)} {f}", end=',' if i < self.n - 1 else f' to {save_dir}\n')
+                if i == self.n - 1:
+                    LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to '{save_dir}'")
             if render:
                 self.imgs[i] = np.asarray(im)
     def print(self):
         self.display(pprint=True)  # print results
-        print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' % self.t)
+        LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
+                    self.t)
     def show(self):
         self.display(show=True)  # show results
-    def save(self, save_dir='runs/hub/exp'):
-        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp', mkdir=True)  # increment save_dir
+    def save(self, save_dir='runs/detect/exp'):
+        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True)  # increment save_dir
         self.display(save=True, save_dir=save_dir)  # save results
-    def crop(self, save_dir='runs/hub/exp'):
-        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp', mkdir=True)  # increment save_dir
+    def crop(self, save_dir='runs/detect/exp'):
+        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True)  # increment save_dir
         self.display(crop=True, save_dir=save_dir)  # crop results
-        print(f'Saved results to {save_dir}\n')
+        LOGGER.info(f'Saved results to {save_dir}\n')
     def render(self):
         self.display(render=True)  # render results
...
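With `save()` and `crop()` now defaulting to `runs/detect/exp`, a typical hub workflow lands its outputs there. A usage sketch (assumes the public `torch.hub` entrypoint for this repository and network access to fetch the weights and sample image):

```python
import torch

model = torch.hub.load('ultralytics/yolov5', 'yolov5s')      # pretrained AutoShape model
results = model('https://ultralytics.com/images/zidane.jpg')  # inference
results.print()  # logs the per-stage speed summary through LOGGER
results.save()   # writes annotated images to runs/detect/exp (was runs/hub/exp)
```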
@@ -5,7 +5,6 @@ Usage:
 """
 import argparse
-import logging
 import sys
 from copy import deepcopy
 from pathlib import Path
@@ -18,7 +17,7 @@ from models.experimental import *
 from utils.autoanchor import check_anchor_order
 from utils.general import make_divisible, check_file, set_logging
 from utils.plots import feature_visualization
-from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
+from utils.torch_utils import time_sync, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
     select_device, copy_attr
 try:
@@ -26,7 +25,7 @@ try:
 except ImportError:
     thop = None
-logger = logging.getLogger(__name__)
+LOGGER = logging.getLogger(__name__)
 class Detect(nn.Module):
@@ -90,15 +89,15 @@ class Model(nn.Module):
         # Define model
         ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
         if nc and nc != self.yaml['nc']:
-            logger.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
+            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
             self.yaml['nc'] = nc  # override yaml value
         if anchors:
-            logger.info(f'Overriding model.yaml anchors with anchors={anchors}')
+            LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
             self.yaml['anchors'] = round(anchors)  # override yaml value
         self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
         self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
         self.inplace = self.yaml.get('inplace', True)
-        # logger.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
+        # LOGGER.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
         # Build strides, anchors
         m = self.model[-1]  # Detect()
@@ -110,12 +109,12 @@ class Model(nn.Module):
             check_anchor_order(m)
             self.stride = m.stride
             self._initialize_biases()  # only run once
-            # logger.info('Strides: %s' % m.stride.tolist())
+            # LOGGER.info('Strides: %s' % m.stride.tolist())
         # Init weights, biases
         initialize_weights(self)
         self.info()
-        logger.info('')
+        LOGGER.info('')
     def forward(self, x, augment=False, profile=False, visualize=False):
         if augment:
@@ -143,13 +142,13 @@ class Model(nn.Module):
             if profile:
                 o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPs
-                t = time_synchronized()
+                t = time_sync()
                 for _ in range(10):
                     _ = m(x)
-                dt.append((time_synchronized() - t) * 100)
+                dt.append((time_sync() - t) * 100)
                 if m == self.model[0]:
-                    logger.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  {'module'}")
-                logger.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f}  {m.type}')
+                    LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  {'module'}")
+                LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f}  {m.type}')
             x = m(x)  # run
             y.append(x if m.i in self.save else None)  # save output
@@ -158,7 +157,7 @@ class Model(nn.Module):
                 feature_visualization(x, m.type, m.i, save_dir=visualize)
         if profile:
-            logger.info('%.1fms total' % sum(dt))
+            LOGGER.info('%.1fms total' % sum(dt))
         return x
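A note on the profiling line above: `thop.profile` returns multiply-accumulate counts, so the `/ 1E9 * 2` converts MACs to GFLOPs (one multiply plus one add per MAC). Illustrative arithmetic only, with a made-up count:

```python
macs = 8.2e9             # hypothetical MACs reported by thop.profile for one module
gflops = macs / 1e9 * 2  # -> 16.4 GFLOPs
```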
     def _descale_pred(self, p, flips, scale, img_size):
@@ -192,16 +191,16 @@ class Model(nn.Module):
         m = self.model[-1]  # Detect() module
         for mi in m.m:  # from
             b = mi.bias.detach().view(m.na, -1).T  # conv.bias(255) to (3,85)
-            logger.info(
+            LOGGER.info(
                 ('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
     # def _print_weights(self):
     #     for m in self.model.modules():
     #         if type(m) is Bottleneck:
-    #             logger.info('%10.3g' % (m.w.detach().sigmoid() * 2))  # shortcut weights
+    #             LOGGER.info('%10.3g' % (m.w.detach().sigmoid() * 2))  # shortcut weights
     def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
-        logger.info('Fusing layers... ')
+        LOGGER.info('Fusing layers... ')
         for m in self.model.modules():
             if type(m) is Conv and hasattr(m, 'bn'):
                 m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
@@ -213,19 +212,19 @@ class Model(nn.Module):
     def nms(self, mode=True):  # add or remove NMS module
         present = type(self.model[-1]) is NMS  # last layer is NMS
         if mode and not present:
-            logger.info('Adding NMS... ')
+            LOGGER.info('Adding NMS... ')
             m = NMS()  # module
             m.f = -1  # from
             m.i = self.model[-1].i + 1  # index
             self.model.add_module(name='%s' % m.i, module=m)  # add
             self.eval()
         elif not mode and present:
-            logger.info('Removing NMS... ')
+            LOGGER.info('Removing NMS... ')
             self.model = self.model[:-1]  # remove
         return self
     def autoshape(self):  # add AutoShape module
-        logger.info('Adding AutoShape... ')
+        LOGGER.info('Adding AutoShape... ')
         m = AutoShape(self)  # wrap model
         copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=())  # copy attributes
         return m
@@ -235,7 +234,7 @@ class Model(nn.Module):
 def parse_model(d, ch):  # model_dict, input_channels(3)
-    logger.info('\n%3s%18s%3s%10s  %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
+    LOGGER.info('\n%3s%18s%3s%10s  %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
     anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
     na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
     no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)
@@ -279,7 +278,7 @@ def parse_model(d, ch):  # model_dict, input_channels(3)
         t = str(m)[8:-2].replace('__main__.', '')  # module type
         np = sum([x.numel() for x in m_.parameters()])  # number params
         m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
-        logger.info('%3s%18s%3s%10.0f  %-40s%-30s' % (i, f, n, np, t, args))  # print
+        LOGGER.info('%3s%18s%3s%10.0f  %-40s%-30s' % (i, f, n, np, t, args))  # print
         save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
         layers.append(m_)
         if i == 0:
@@ -308,5 +307,5 @@ if __name__ == '__main__':
     # Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898)
     # from torch.utils.tensorboard import SummaryWriter
     # tb_writer = SummaryWriter('.')
-    # logger.info("Run 'tensorboard --logdir=models' to view tensorboard at http://localhost:6006/")
+    # LOGGER.info("Run 'tensorboard --logdir=models' to view tensorboard at http://localhost:6006/")
     # tb_writer.add_graph(torch.jit.trace(model, img, strict=False), [])  # add model graph
(diff collapsed)
@@ -22,17 +22,16 @@ from PIL import Image, ExifTags
 from torch.utils.data import Dataset
 from tqdm import tqdm
-from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective, cutout
+from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
 from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
     xyn2xy, segments2boxes, clean_str
 from utils.torch_utils import torch_distributed_zero_first
 # Parameters
-help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
-img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']  # acceptable image suffixes
-vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # acceptable video suffixes
-num_threads = min(8, os.cpu_count())  # number of multiprocessing threads
-logger = logging.getLogger(__name__)
+HELP_URL = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
+IMG_FORMATS = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']  # acceptable image suffixes
+VID_FORMATS = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # acceptable video suffixes
+NUM_THREADS = min(8, os.cpu_count())  # number of multiprocessing threads
 # Get orientation exif tag
 for orientation in ExifTags.TAGS.keys():
@@ -164,8 +163,8 @@ class LoadImages:  # for inference
         else:
             raise Exception(f'ERROR: {p} does not exist')
-        images = [x for x in files if x.split('.')[-1].lower() in img_formats]
-        videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
+        images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
+        videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
         ni, nv = len(images), len(videos)
         self.img_size = img_size
@@ -179,7 +178,7 @@ class LoadImages:  # for inference
         else:
             self.cap = None
         assert self.nf > 0, f'No images or videos found in {p}. ' \
-                            f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}'
+                            f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
     def __iter__(self):
         self.count = 0
@@ -389,11 +388,11 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                     # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                 else:
                     raise Exception(f'{prefix}{p} does not exist')
-            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
+            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS])
             # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats])  # pathlib
             assert self.img_files, f'{prefix}No images found'
         except Exception as e:
-            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')
+            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')
         # Check cache
         self.label_files = img2label_paths(self.img_files)  # labels
@@ -411,7 +410,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
             tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
             if cache['msgs']:
                 logging.info('\n'.join(cache['msgs']))  # display warnings
-        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
+        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {HELP_URL}'
         # Read cache
         [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove items
@@ -460,7 +459,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         if cache_images:
             gb = 0  # Gigabytes of cached images
             self.img_hw0, self.img_hw = [None] * n, [None] * n
-            results = ThreadPool(num_threads).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
+            results = ThreadPool(NUM_THREADS).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
             pbar = tqdm(enumerate(results), total=n)
             for i, x in pbar:
                 self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
@@ -473,7 +472,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         x = {}  # dict
         nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
         desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
-        with Pool(num_threads) as pool:
+        with Pool(NUM_THREADS) as pool:
             pbar = tqdm(pool.imap_unordered(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))),
                         desc=desc, total=len(self.img_files))
             for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
@@ -491,7 +490,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         if msgs:
             logging.info('\n'.join(msgs))
         if nf == 0:
-            logging.info(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
+            logging.info(f'{prefix}WARNING: No labels found in {path}. See {HELP_URL}')
         x['hash'] = get_hash(self.label_files + self.img_files)
         x['results'] = nf, nm, ne, nc, len(self.img_files)
         x['msgs'] = msgs  # warnings
@@ -789,7 +788,7 @@ def extract_boxes(path='../datasets/coco128'):  # from utils.datasets import *;
     files = list(path.rglob('*.*'))
     n = len(files)  # number of files
     for im_file in tqdm(files, total=n):
-        if im_file.suffix[1:] in img_formats:
+        if im_file.suffix[1:] in IMG_FORMATS:
             # image
             im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
             h, w = im.shape[:2]
@@ -825,7 +824,7 @@ def autosplit(path='../datasets/coco128/images', weights=(0.9, 0.1, 0.0), annota
         annotated_only: Only use images with an annotated txt file
     """
     path = Path(path)  # images dir
-    files = sum([list(path.rglob(f"*.{img_ext}")) for img_ext in img_formats], [])  # image files only
+    files = sum([list(path.rglob(f"*.{img_ext}")) for img_ext in IMG_FORMATS], [])  # image files only
     n = len(files)  # number of files
    random.seed(0)  # for reproducibility
     indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split
@@ -850,7 +849,7 @@ def verify_image_label(args):
         im.verify()  # PIL verify
         shape = exif_size(im)  # image size
         assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
-        assert im.format.lower() in img_formats, f'invalid image format {im.format}'
+        assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
         if im.format.lower() in ('jpg', 'jpeg'):
             with open(im_file, 'rb') as f:
                 f.seek(-2, 2)
...
@@ -22,7 +22,7 @@ try:
     import thop  # for FLOPs computation
 except ImportError:
     thop = None
-logger = logging.getLogger(__name__)
+LOGGER = logging.getLogger(__name__)
 @contextmanager
@@ -85,11 +85,11 @@ def select_device(device='', batch_size=None):
     else:
         s += 'CPU\n'
-    logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)  # emoji-safe
+    LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)  # emoji-safe
     return torch.device('cuda:0' if cuda else 'cpu')
-def time_synchronized():
+def time_sync():
     # pytorch-accurate time
     if torch.cuda.is_available():
         torch.cuda.synchronize()
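The hunk above is truncated before the function's final line. For completeness, a sketch of the whole renamed helper as implied by the visible body (the `return time.time()` is an assumption based on the truncation):

```python
import time

import torch

def time_sync():
    # pytorch-accurate time: wait for queued CUDA kernels to finish so the
    # wall-clock reading reflects completed GPU work
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()  # assumed; the diff cuts off before this line
```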
@@ -118,12 +118,12 @@ def profile(x, ops, n=100, device=None):
         flops = 0
         for _ in range(n):
-            t[0] = time_synchronized()
+            t[0] = time_sync()
             y = m(x)
-            t[1] = time_synchronized()
+            t[1] = time_sync()
             try:
                 _ = y.sum().backward()
-                t[2] = time_synchronized()
+                t[2] = time_sync()
             except:  # no backward method
                 t[2] = float('nan')
             dtf += (t[1] - t[0]) * 1000 / n  # ms per op forward
@@ -231,7 +231,7 @@ def model_info(model, verbose=False, img_size=640):
     except (ImportError, Exception):
         fs = ''
-    logger.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
+    LOGGER.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
 def load_classifier(name='resnet101', n=2):
...
@@ -98,7 +98,14 @@ class WandbLogger():
     def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
         # Pre-training routine --
         self.job_type = job_type
-        self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
+        self.wandb, self.wandb_run = wandb, None if not wandb else wandb.run
+        self.val_artifact, self.train_artifact = None, None
+        self.train_artifact_path, self.val_artifact_path = None, None
+        self.result_artifact = None
+        self.val_table, self.result_table = None, None
+        self.data_dict = data_dict
+        self.bbox_media_panel_images = []
+        self.val_table_path_map = None
         # It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
         if isinstance(opt.resume, str):  # checks resume from artifact
             if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
@@ -156,25 +163,27 @@ class WandbLogger():
                 self.weights), config.save_period, config.batch_size, config.bbox_interval, config.epochs, \
                 config.opt['hyp']
             data_dict = dict(self.wandb_run.config.data_dict)  # eliminates the need for config file to resume
-        if 'val_artifact' not in self.__dict__:  # If --upload_dataset is set, use the existing artifact, don't download
+        if self.val_artifact is None:  # If --upload_dataset is set, use the existing artifact, don't download
             self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'),
                                                                                            opt.artifact_alias)
             self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'),
                                                                                        opt.artifact_alias)
-            self.result_artifact, self.result_table, self.val_table, self.weights = None, None, None, None
             if self.train_artifact_path is not None:
                 train_path = Path(self.train_artifact_path) / 'data/images/'
                 data_dict['train'] = str(train_path)
             if self.val_artifact_path is not None:
                 val_path = Path(self.val_artifact_path) / 'data/images/'
                 data_dict['val'] = str(val_path)
-                self.val_table = self.val_artifact.get("val")
-                self.map_val_table_path()
-                wandb.log({"validation dataset": self.val_table})
         if self.val_artifact is not None:
             self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
             self.result_table = wandb.Table(["epoch", "id", "ground truth", "prediction", "avg_confidence"])
+            self.val_table = self.val_artifact.get("val")
+            if self.val_table_path_map is None:
+                self.map_val_table_path()
+            wandb.log({"validation dataset": self.val_table})
         if opt.bbox_interval == -1:
             self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
         return data_dict
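A concrete reading of the `bbox_interval` default above: with `opt.epochs = 300`, `opt.epochs // 10` gives 30, so bounding-box debug images are logged every 30th epoch; runs of 10 epochs or fewer fall back to an interval of 1, logging every epoch.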
@@ -182,7 +191,7 @@ class WandbLogger():
     def download_dataset_artifact(self, path, alias):
         if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX):
             artifact_path = Path(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
-            dataset_artifact = wandb.use_artifact(artifact_path.as_posix().replace("\\","/"))
+            dataset_artifact = wandb.use_artifact(artifact_path.as_posix().replace("\\", "/"))
             assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'"
             datadir = dataset_artifact.download()
             return datadir, dataset_artifact
@@ -246,10 +255,10 @@ class WandbLogger():
         return path
     def map_val_table_path(self):
-        self.val_table_map = {}
+        self.val_table_path_map = {}
         print("Mapping dataset")
         for i, data in enumerate(tqdm(self.val_table.data)):
-            self.val_table_map[data[3]] = data[0]
+            self.val_table_path_map[data[3]] = data[0]
     def create_dataset_table(self, dataset, class_to_id, name='dataset'):
         # TODO: Explore multiprocessing to split this loop in parallel | This is essential for speeding up the logging
@@ -283,7 +292,6 @@ class WandbLogger():
         return artifact
     def log_training_progress(self, predn, path, names):
-        if self.val_table and self.result_table:
         class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()])
         box_data = []
         total_conf = 0
@@ -297,7 +305,7 @@ class WandbLogger():
                              "domain": "pixel"})
                 total_conf = total_conf + conf
         boxes = {"predictions": {"box_data": box_data, "class_labels": names}}  # inference-space
-        id = self.val_table_map[Path(path).name]
+        id = self.val_table_path_map[Path(path).name]
         self.result_table.add_data(self.current_epoch,
                                    id,
                                    self.val_table.data[id][1],
@@ -305,6 +313,22 @@ class WandbLogger():
                                    total_conf / max(1, len(box_data))
                                    )
+    def val_one_image(self, pred, predn, path, names, im):
+        if self.val_table and self.result_table:  # Log Table if Val dataset is uploaded as artifact
+            self.log_training_progress(predn, path, names)
+        else:  # Default to bbox media panel if Val artifact not found
+            log_imgs = min(self.log_imgs, 100)
+            if len(self.bbox_media_panel_images) < log_imgs and self.current_epoch > 0:
+                if self.current_epoch % self.bbox_interval == 0:
+                    box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
+                                 "class_id": int(cls),
+                                 "box_caption": "%s %.3f" % (names[cls], conf),
+                                 "scores": {"class_score": conf},
+                                 "domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
+                    boxes = {"predictions": {"box_data": box_data, "class_labels": names}}  # inference-space
+                    self.bbox_media_panel_images.append(wandb.Image(im, boxes=boxes, caption=path.name))
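The new `val_one_image` hook is the single per-image entry point for validation logging: when the validation set exists as a W&B artifact it appends rows to `result_table`; otherwise it accumulates up to 100 `wandb.Image` panels that `end_epoch` (below) flushes under "Bounding Box Debugger/Images". A hypothetical call site inside a validation loop (variable names here are assumptions for illustration, not code from this PR):

```python
# sketch: inside val.py's per-batch loop, after NMS
for si, pred in enumerate(out):          # out: list of per-image prediction tensors
    path = Path(paths[si])               # source image path for this prediction
    predn = pred.clone()                 # predictions scaled back to native space
    if wandb_logger:
        wandb_logger.val_one_image(pred, predn, path, names, img[si])
```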
     def log(self, log_dict):
         if self.wandb_run:
             for key, value in log_dict.items():
@@ -313,13 +337,16 @@ class WandbLogger():
     def end_epoch(self, best_result=False):
         if self.wandb_run:
             with all_logging_disabled():
+                if self.bbox_media_panel_images:
+                    self.log_dict["Bounding Box Debugger/Images"] = self.bbox_media_panel_images
                 wandb.log(self.log_dict)
                 self.log_dict = {}
+                self.bbox_media_panel_images = []
             if self.result_artifact:
                 self.result_artifact.add(self.result_table, 'result')
                 wandb.log_artifact(self.result_artifact, aliases=['latest', 'last', 'epoch ' + str(self.current_epoch),
                                                                   ('best' if best_result else '')])
                 wandb.log({"evaluation": self.result_table})
                 self.result_table = wandb.Table(["epoch", "id", "ground truth", "prediction", "avg_confidence"])
                 self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
...
(diff collapsed)