Unverified 提交 6ab58958 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Add colorstr() (#1887)

* Add colorful() * update * newline fix * add git description * --always * update loss scaling * update loss scaling 2 * rename to colorstr()
上级 3e25f1e9
...@@ -216,8 +216,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): ...@@ -216,8 +216,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
# Model parameters # Model parameters
hyp['cls'] *= nc / 80. # scale hyp['cls'] to class count hyp['box'] *= 3. / nl # scale to layers
hyp['obj'] *= imgsz ** 2 / 640. ** 2 * 3. / nl # scale hyp['obj'] to image size and output layers hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
model.nc = nc # attach number of classes to model model.nc = nc # attach number of classes to model
model.hyp = hyp # attach hyperparameters to model model.hyp = hyp # attach hyperparameters to model
model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou) model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
......
...@@ -6,6 +6,8 @@ import yaml ...@@ -6,6 +6,8 @@ import yaml
from scipy.cluster.vq import kmeans from scipy.cluster.vq import kmeans
from tqdm import tqdm from tqdm import tqdm
from utils.general import colorstr
def check_anchor_order(m): def check_anchor_order(m):
# Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
...@@ -20,7 +22,8 @@ def check_anchor_order(m): ...@@ -20,7 +22,8 @@ def check_anchor_order(m):
def check_anchors(dataset, model, thr=4.0, imgsz=640): def check_anchors(dataset, model, thr=4.0, imgsz=640):
# Check anchor fit to data, recompute if necessary # Check anchor fit to data, recompute if necessary
print('\nAnalyzing anchors... ', end='') prefix = colorstr('blue', 'bold', 'autoanchor') + ': '
print(f'\n{prefix}Analyzing anchors... ', end='')
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect() m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True) shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale
...@@ -35,7 +38,7 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640): ...@@ -35,7 +38,7 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
return bpr, aat return bpr, aat
bpr, aat = metric(m.anchor_grid.clone().cpu().view(-1, 2)) bpr, aat = metric(m.anchor_grid.clone().cpu().view(-1, 2))
print('anchors/target = %.2f, Best Possible Recall (BPR) = %.4f' % (aat, bpr), end='') print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
if bpr < 0.98: # threshold to recompute if bpr < 0.98: # threshold to recompute
print('. Attempting to improve anchors, please wait...') print('. Attempting to improve anchors, please wait...')
na = m.anchor_grid.numel() // 2 # number of anchors na = m.anchor_grid.numel() // 2 # number of anchors
...@@ -46,9 +49,9 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640): ...@@ -46,9 +49,9 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference
m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
check_anchor_order(m) check_anchor_order(m)
print('New anchors saved to model. Update model *.yaml to use these anchors in the future.') print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
else: else:
print('Original anchors better than new anchors. Proceeding with original anchors.') print(f'{prefix}Original anchors better than new anchors. Proceeding with original anchors.')
print('') # newline print('') # newline
...@@ -70,6 +73,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10 ...@@ -70,6 +73,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
from utils.autoanchor import *; _ = kmean_anchors() from utils.autoanchor import *; _ = kmean_anchors()
""" """
thr = 1. / thr thr = 1. / thr
prefix = colorstr('blue', 'bold', 'autoanchor') + ': '
def metric(k, wh): # compute metrics def metric(k, wh): # compute metrics
r = wh[:, None] / k[None] r = wh[:, None] / k[None]
...@@ -85,9 +89,9 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10 ...@@ -85,9 +89,9 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
k = k[np.argsort(k.prod(1))] # sort small to large k = k[np.argsort(k.prod(1))] # sort small to large
x, best = metric(k, wh0) x, best = metric(k, wh0)
bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
print('thr=%.2f: %.4f best possible recall, %.2f anchors past thr' % (thr, bpr, aat)) print(f'{prefix}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr')
print('n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thr=%.3f-mean: ' % print(f'{prefix}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, '
(n, img_size, x.mean(), best.mean(), x[x > thr].mean()), end='') f'past_thr={x[x > thr].mean():.3f}-mean: ', end='')
for i, x in enumerate(k): for i, x in enumerate(k):
print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg
return k return k
...@@ -107,13 +111,12 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10 ...@@ -107,13 +111,12 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
# Filter # Filter
i = (wh0 < 3.0).any(1).sum() i = (wh0 < 3.0).any(1).sum()
if i: if i:
print('WARNING: Extremely small objects found. ' print(f'{prefix}WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.')
'%g of %g labels are < 3 pixels in width or height.' % (i, len(wh0)))
wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
# wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1 # wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1
# Kmeans calculation # Kmeans calculation
print('Running kmeans for %g anchors on %g points...' % (n, len(wh))) print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...')
s = wh.std(0) # sigmas for whitening s = wh.std(0) # sigmas for whitening
k, dist = kmeans(wh / s, n, iter=30) # points, mean distance k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
k *= s k *= s
...@@ -136,7 +139,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10 ...@@ -136,7 +139,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
# Evolve # Evolve
npr = np.random npr = np.random
f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
pbar = tqdm(range(gen), desc='Evolving anchors with Genetic Algorithm') # progress bar pbar = tqdm(range(gen), desc=f'{prefix}Evolving anchors with Genetic Algorithm:') # progress bar
for _ in pbar: for _ in pbar:
v = np.ones(sh) v = np.ones(sh)
while (v == 1).all(): # mutate until a change occurs (prevent duplicates) while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
...@@ -145,7 +148,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10 ...@@ -145,7 +148,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
fg = anchor_fitness(kg) fg = anchor_fitness(kg)
if fg > f: if fg > f:
f, k = fg, kg.copy() f, k = fg, kg.copy()
pbar.desc = 'Evolving anchors with Genetic Algorithm: fitness = %.4f' % f pbar.desc = f'{prefix}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
if verbose: if verbose:
print_results(k) print_results(k)
......
...@@ -47,7 +47,7 @@ def get_latest_run(search_dir='.'): ...@@ -47,7 +47,7 @@ def get_latest_run(search_dir='.'):
def check_git_status(): def check_git_status():
# Suggest 'git pull' if repo is out of date # Suggest 'git pull' if repo is out of date
if platform.system() in ['Linux', 'Darwin'] and not os.path.isfile('/.dockerenv'): if Path('.git').exists() and platform.system() in ['Linux', 'Darwin'] and not Path('/.dockerenv').is_file():
s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8') s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
if 'Your branch is behind' in s: if 'Your branch is behind' in s:
print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n') print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')
...@@ -115,6 +115,32 @@ def one_cycle(y1=0.0, y2=1.0, steps=100): ...@@ -115,6 +115,32 @@ def one_cycle(y1=0.0, y2=1.0, steps=100):
return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1 return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
def colorstr(*input):
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
*prefix, str = input # color arguments, string
colors = {'black': '\033[30m', # basic colors
'red': '\033[31m',
'green': '\033[32m',
'yellow': '\033[33m',
'blue': '\033[34m',
'magenta': '\033[35m',
'cyan': '\033[36m',
'white': '\033[37m',
'bright_black': '\033[90m', # bright colors
'bright_red': '\033[91m',
'bright_green': '\033[92m',
'bright_yellow': '\033[93m',
'bright_blue': '\033[94m',
'bright_magenta': '\033[95m',
'bright_cyan': '\033[96m',
'bright_white': '\033[97m',
'end': '\033[0m', # misc
'bold': '\033[1m',
'undelrine': '\033[4m'}
return ''.join(colors[x] for x in prefix) + str + colors['end']
def labels_to_class_weights(labels, nc=80): def labels_to_class_weights(labels, nc=80):
# Get class weights (inverse frequency) from training labels # Get class weights (inverse frequency) from training labels
if labels[0] is None: # no labels loaded if labels[0] is None: # no labels loaded
......
...@@ -105,7 +105,6 @@ def compute_loss(p, targets, model): # predictions, targets, model ...@@ -105,7 +105,6 @@ def compute_loss(p, targets, model): # predictions, targets, model
# Losses # Losses
nt = 0 # number of targets nt = 0 # number of targets
no = len(p) # number of outputs
balance = [4.0, 1.0, 0.3, 0.1, 0.03] # P3-P7 balance = [4.0, 1.0, 0.3, 0.1, 0.03] # P3-P7
for i, pi in enumerate(p): # layer index, layer predictions for i, pi in enumerate(p): # layer index, layer predictions
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
...@@ -138,10 +137,9 @@ def compute_loss(p, targets, model): # predictions, targets, model ...@@ -138,10 +137,9 @@ def compute_loss(p, targets, model): # predictions, targets, model
lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss
s = 3 / no # output count scaling lbox *= h['box']
lbox *= h['box'] * s
lobj *= h['obj'] lobj *= h['obj']
lcls *= h['cls'] * s lcls *= h['cls']
bs = tobj.shape[0] # batch size bs = tobj.shape[0] # batch size
loss = lbox + lobj + lcls loss = lbox + lobj + lcls
......
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
import logging import logging
import math import math
import os import os
import subprocess
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from pathlib import Path
import torch import torch
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
...@@ -41,9 +43,17 @@ def init_torch_seeds(seed=0): ...@@ -41,9 +43,17 @@ def init_torch_seeds(seed=0):
cudnn.benchmark, cudnn.deterministic = True, False cudnn.benchmark, cudnn.deterministic = True, False
def git_describe():
# return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
if Path('.git').exists():
return subprocess.check_output('git describe --tags --long --always', shell=True).decode('utf-8')[:-1]
else:
return ''
def select_device(device='', batch_size=None): def select_device(device='', batch_size=None):
# device = 'cpu' or '0' or '0,1,2,3' # device = 'cpu' or '0' or '0,1,2,3'
s = f'Using torch {torch.__version__} ' # string s = f'YOLOv5 {git_describe()} torch {torch.__version__} ' # string
cpu = device.lower() == 'cpu' cpu = device.lower() == 'cpu'
if cpu: if cpu:
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
...@@ -61,9 +71,9 @@ def select_device(device='', batch_size=None): ...@@ -61,9 +71,9 @@ def select_device(device='', batch_size=None):
p = torch.cuda.get_device_properties(i) p = torch.cuda.get_device_properties(i)
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
else: else:
s += 'CPU' s += 'CPU\n'
logger.info(f'{s}\n') # skip a line logger.info(s) # skip a line
return torch.device('cuda:0' if cuda else 'cpu') return torch.device('cuda:0' if cuda else 'cpu')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论