Unverified 提交 13530402 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub
上级 53711bac
......@@ -10,6 +10,7 @@ import torch
import yaml
from tqdm import tqdm
from utils import TryExcept
from utils.general import LOGGER, colorstr
PREFIX = colorstr('AutoAnchor: ')
......@@ -25,6 +26,7 @@ def check_anchor_order(m):
m.anchors[:] = m.anchors.flip(0)
@TryExcept(f'{PREFIX}ERROR:')
def check_anchors(dataset, model, thr=4.0, imgsz=640):
# Check anchor fit to data, recompute if necessary
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
......@@ -49,10 +51,7 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
else:
LOGGER.info(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...')
na = m.anchors.numel() // 2 # number of anchors
try:
anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
except Exception as e:
LOGGER.info(f'{PREFIX}ERROR: {e}')
new_bpr = metric(anchors)[0]
if new_bpr > bpr: # replace anchors
anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
......@@ -124,7 +123,7 @@ def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen
i = (wh0 < 3.0).any(1).sum()
if i:
LOGGER.info(f'{PREFIX}WARNING: Extremely small objects found: {i} of {len(wh0)} labels are < 3 pixels in size')
wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
wh = wh0[(wh0 >= 2.0).any(1)].astype(np.float32) # filter > 2 pixels
# wh = wh * (npr.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1
# Kmeans init
......@@ -167,4 +166,4 @@ def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen
if verbose:
print_results(k, verbose)
return print_results(k)
return print_results(k).astype(np.float32)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论