Unverified 提交 9d75e42f authored 作者: Jebastin Nadar's avatar Jebastin Nadar 提交者: GitHub

Refactor `Detect()` anchors for ONNX <> OpenCV DNN compatibility (#4833)

* refactor anchors and anchor_grid in Detect Layer * fix CI failures by adding compatibility * fix tf failure * fix different devices errors * Cleanup * fix anchors overwriting issue * better refactoring * Remove self.anchor_grid shape check (redundant with self.grid check) Also PEP8 / 120 line width * Convert _make_grid() from static to dynamic method * Remove anchor_grid.to(device) clone() should already clone to same device as self.anchors * fix different devices error Co-authored-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 153873e9
...@@ -295,6 +295,8 @@ class AutoShape(nn.Module): ...@@ -295,6 +295,8 @@ class AutoShape(nn.Module):
m = self.model.model[-1] # Detect() m = self.model.model[-1] # Detect()
m.stride = fn(m.stride) m.stride = fn(m.stride)
m.grid = list(map(fn, m.grid)) m.grid = list(map(fn, m.grid))
if isinstance(m.anchor_grid, list):
m.anchor_grid = list(map(fn, m.anchor_grid))
return self return self
@torch.no_grad() @torch.no_grad()
......
...@@ -102,6 +102,10 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True): ...@@ -102,6 +102,10 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True):
for m in model.modules(): for m in model.modules():
if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]: if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
m.inplace = inplace # pytorch 1.7.0 compatibility m.inplace = inplace # pytorch 1.7.0 compatibility
if type(m) is Detect:
if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility
delattr(m, 'anchor_grid')
setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
elif type(m) is Conv: elif type(m) is Conv:
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
......
...@@ -193,7 +193,7 @@ class TFDetect(keras.layers.Layer): ...@@ -193,7 +193,7 @@ class TFDetect(keras.layers.Layer):
self.na = len(anchors[0]) // 2 # number of anchors self.na = len(anchors[0]) // 2 # number of anchors
self.grid = [tf.zeros(1)] * self.nl # init grid self.grid = [tf.zeros(1)] * self.nl # init grid
self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32) self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
self.anchor_grid = tf.reshape(tf.convert_to_tensor(w.anchor_grid.numpy(), dtype=tf.float32), self.anchor_grid = tf.reshape(self.anchors * tf.reshape(self.stride, [self.nl, 1, 1]),
[self.nl, 1, -1, 1, 2]) [self.nl, 1, -1, 1, 2])
self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)] self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
self.training = False # set to False after building model self.training = False # set to False after building model
......
...@@ -44,9 +44,8 @@ class Detect(nn.Module): ...@@ -44,9 +44,8 @@ class Detect(nn.Module):
self.nl = len(anchors) # number of detection layers self.nl = len(anchors) # number of detection layers
self.na = len(anchors[0]) // 2 # number of anchors self.na = len(anchors[0]) // 2 # number of anchors
self.grid = [torch.zeros(1)] * self.nl # init grid self.grid = [torch.zeros(1)] * self.nl # init grid
a = torch.tensor(anchors).float().view(self.nl, -1, 2) self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid
self.register_buffer('anchors', a) # shape(nl,na,2) self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
self.inplace = inplace # use in-place ops (e.g. slice assignment) self.inplace = inplace # use in-place ops (e.g. slice assignment)
...@@ -59,7 +58,7 @@ class Detect(nn.Module): ...@@ -59,7 +58,7 @@ class Detect(nn.Module):
if not self.training: # inference if not self.training: # inference
if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic: if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
self.grid[i] = self._make_grid(nx, ny).to(x[i].device) self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
y = x[i].sigmoid() y = x[i].sigmoid()
if self.inplace: if self.inplace:
...@@ -67,16 +66,19 @@ class Detect(nn.Module): ...@@ -67,16 +66,19 @@ class Detect(nn.Module):
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953 else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2) # wh wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
y = torch.cat((xy, wh, y[..., 4:]), -1) y = torch.cat((xy, wh, y[..., 4:]), -1)
z.append(y.view(bs, -1, self.no)) z.append(y.view(bs, -1, self.no))
return x if self.training else (torch.cat(z, 1), x) return x if self.training else (torch.cat(z, 1), x)
@staticmethod def _make_grid(self, nx=20, ny=20, i=0):
def _make_grid(nx=20, ny=20): d = self.anchors[i].device
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)]) yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)])
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float() grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float()
anchor_grid = (self.anchors[i].clone() * self.stride[i]) \
.view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float()
return grid, anchor_grid
class Model(nn.Module): class Model(nn.Module):
...@@ -239,6 +241,8 @@ class Model(nn.Module): ...@@ -239,6 +241,8 @@ class Model(nn.Module):
if isinstance(m, Detect): if isinstance(m, Detect):
m.stride = fn(m.stride) m.stride = fn(m.stride)
m.grid = list(map(fn, m.grid)) m.grid = list(map(fn, m.grid))
if isinstance(m.anchor_grid, list):
m.anchor_grid = list(map(fn, m.anchor_grid))
return self return self
......
...@@ -15,13 +15,12 @@ from utils.general import colorstr ...@@ -15,13 +15,12 @@ from utils.general import colorstr
def check_anchor_order(m): def check_anchor_order(m):
# Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
a = m.anchor_grid.prod(-1).view(-1) # anchor area a = m.anchors.prod(-1).view(-1) # anchor area
da = a[-1] - a[0] # delta a da = a[-1] - a[0] # delta a
ds = m.stride[-1] - m.stride[0] # delta s ds = m.stride[-1] - m.stride[0] # delta s
if da.sign() != ds.sign(): # same order if da.sign() != ds.sign(): # same order
print('Reversing anchor order') print('Reversing anchor order')
m.anchors[:] = m.anchors.flip(0) m.anchors[:] = m.anchors.flip(0)
m.anchor_grid[:] = m.anchor_grid.flip(0)
def check_anchors(dataset, model, thr=4.0, imgsz=640): def check_anchors(dataset, model, thr=4.0, imgsz=640):
...@@ -41,12 +40,12 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640): ...@@ -41,12 +40,12 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
bpr = (best > 1. / thr).float().mean() # best possible recall bpr = (best > 1. / thr).float().mean() # best possible recall
return bpr, aat return bpr, aat
anchors = m.anchor_grid.clone().cpu().view(-1, 2) # current anchors anchors = m.anchors.clone() * m.stride.to(m.anchors.device).view(-1, 1, 1) # current anchors
bpr, aat = metric(anchors) bpr, aat = metric(anchors.cpu().view(-1, 2))
print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='') print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
if bpr < 0.98: # threshold to recompute if bpr < 0.98: # threshold to recompute
print('. Attempting to improve anchors, please wait...') print('. Attempting to improve anchors, please wait...')
na = m.anchor_grid.numel() // 2 # number of anchors na = m.anchors.numel() // 2 # number of anchors
try: try:
anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False) anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
except Exception as e: except Exception as e:
...@@ -54,7 +53,6 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640): ...@@ -54,7 +53,6 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
new_bpr = metric(anchors)[0] new_bpr = metric(anchors)[0]
if new_bpr > bpr: # replace anchors if new_bpr > bpr: # replace anchors
anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors) anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
m.anchor_grid[:] = anchors.clone().view_as(m.anchor_grid) # for inference
m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
check_anchor_order(m) check_anchor_order(m)
print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.') print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论