Unverified 提交 4e8c81a3 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Add `yolov5s-ghost.yaml` (#4412)

* Add yolov5s-ghost.yaml * Finish C3Ghost * Add C3Ghost to list * Add C3Ghost to number of repeats if statement * Fixes * Cleanup
上级 e0863473
...@@ -149,6 +149,14 @@ class C3SPP(C3): ...@@ -149,6 +149,14 @@ class C3SPP(C3):
self.m = SPP(c_, c_, k) self.m = SPP(c_, c_, k)
class C3Ghost(C3):
# C3 module with GhostBottleneck()
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
super().__init__(c1, c2, n, shortcut, g, e)
c_ = int(c2 * e) # hidden channels
self.m = nn.Sequential(*[GhostBottleneck(c_, c_) for _ in range(n)])
class SPP(nn.Module): class SPP(nn.Module):
# Spatial pyramid pooling layer used in YOLOv3-SPP # Spatial pyramid pooling layer used in YOLOv3-SPP
def __init__(self, c1, c2, k=(5, 9, 13)): def __init__(self, c1, c2, k=(5, 9, 13)):
...@@ -177,6 +185,34 @@ class Focus(nn.Module): ...@@ -177,6 +185,34 @@ class Focus(nn.Module):
# return self.conv(self.contract(x)) # return self.conv(self.contract(x))
class GhostConv(nn.Module):
# Ghost Convolution https://github.com/huawei-noah/ghostnet
def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
super().__init__()
c_ = c2 // 2 # hidden channels
self.cv1 = Conv(c1, c_, k, s, None, g, act)
self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)
def forward(self, x):
y = self.cv1(x)
return torch.cat([y, self.cv2(y)], 1)
class GhostBottleneck(nn.Module):
# Ghost Bottleneck https://github.com/huawei-noah/ghostnet
def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
super().__init__()
c_ = c2 // 2
self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw
DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()
def forward(self, x):
return self.conv(x) + self.shortcut(x)
class Contract(nn.Module): class Contract(nn.Module):
# Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40) # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
def __init__(self, gain=2): def __init__(self, gain=2):
......
...@@ -43,34 +43,6 @@ class Sum(nn.Module): ...@@ -43,34 +43,6 @@ class Sum(nn.Module):
return y return y
class GhostConv(nn.Module):
# Ghost Convolution https://github.com/huawei-noah/ghostnet
def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
super().__init__()
c_ = c2 // 2 # hidden channels
self.cv1 = Conv(c1, c_, k, s, None, g, act)
self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)
def forward(self, x):
y = self.cv1(x)
return torch.cat([y, self.cv2(y)], 1)
class GhostBottleneck(nn.Module):
# Ghost Bottleneck https://github.com/huawei-noah/ghostnet
def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
super().__init__()
c_ = c2 // 2
self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw
DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()
def forward(self, x):
return self.conv(x) + self.shortcut(x)
class MixConv2d(nn.Module): class MixConv2d(nn.Module):
# Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595 # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
......
# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
# YOLOv5 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Focus, [64, 3]], # 0-P1/2
[-1, 1, GhostConv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3Ghost, [128]],
[-1, 1, GhostConv, [256, 3, 2]], # 3-P3/8
[-1, 9, C3Ghost, [256]],
[-1, 1, GhostConv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3Ghost, [512]],
[-1, 1, GhostConv, [1024, 3, 2]], # 7-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 3, C3Ghost, [1024, False]], # 9
]
# YOLOv5 head
head:
[[-1, 1, GhostConv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3Ghost, [512, False]], # 13
[-1, 1, GhostConv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3Ghost, [256, False]], # 17 (P3/8-small)
[-1, 1, GhostConv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, C3Ghost, [512, False]], # 20 (P4/16-medium)
[-1, 1, GhostConv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, C3Ghost, [1024, False]], # 23 (P5/32-large)
[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]
...@@ -236,13 +236,13 @@ def parse_model(d, ch): # model_dict, input_channels(3) ...@@ -236,13 +236,13 @@ def parse_model(d, ch): # model_dict, input_channels(3)
n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
C3, C3TR, C3SPP]: C3, C3TR, C3SPP, C3Ghost]:
c1, c2 = ch[f], args[0] c1, c2 = ch[f], args[0]
if c2 != no: # if not output if c2 != no: # if not output
c2 = make_divisible(c2 * gw, 8) c2 = make_divisible(c2 * gw, 8)
args = [c1, c2, *args[1:]] args = [c1, c2, *args[1:]]
if m in [BottleneckCSP, C3, C3TR]: if m in [BottleneckCSP, C3, C3TR, C3Ghost]:
args.insert(2, n) # number of repeats args.insert(2, n) # number of repeats
n = 1 n = 1
elif m is nn.BatchNorm2d: elif m is nn.BatchNorm2d:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论