Unverified commit e96c74b5, authored by Yonghye Kwon, committed by GitHub

Simpler code for DWConvClass (#4310)

* Simpler code for DWConvClass
* Remove the DWConv function
* Replace DWConvClass with DWConv
Parent f409d8e5
...@@ -29,11 +29,6 @@ def autopad(k, p=None): # kernel, padding ...@@ -29,11 +29,6 @@ def autopad(k, p=None): # kernel, padding
return p return p
def DWConv(c1, c2, k=1, s=1, act=True):
# Depth-wise convolution function
return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
class Conv(nn.Module): class Conv(nn.Module):
# Standard convolution # Standard convolution
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
...@@ -49,11 +44,10 @@ class Conv(nn.Module): ...@@ -49,11 +44,10 @@ class Conv(nn.Module):
return self.act(self.conv(x)) return self.act(self.conv(x))
class DWConvClass(Conv): class DWConv(Conv):
# Depth-wise convolution class # Depth-wise convolution class
def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
super().__init__(c1, c2, k, s, act) super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k), groups=math.gcd(c1, c2), bias=False)
class TransformerLayer(nn.Module): class TransformerLayer(nn.Module):
......
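For context, here is a minimal, self-contained sketch of the change above (assuming PyTorch; the Conv shown is a condensed stand-in for the repo's class, which also accepts custom activation modules, so treat it as illustrative rather than verbatim). It shows that passing g=math.gcd(c1, c2) to Conv builds the same grouped nn.Conv2d that the old DWConvClass constructed by hand, so the subclass no longer needs to override self.conv:

```python
# Illustrative sketch only -- condensed versions of Conv/DWConv, not the repo's exact code.
import math

import torch
import torch.nn as nn


def autopad(k, p=None):  # kernel, padding
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # 'same' padding
    return p


class Conv(nn.Module):
    # Standard convolution: Conv2d -> BatchNorm2d -> activation
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):  # used after BatchNorm has been folded into self.conv
        return self.act(self.conv(x))


class DWConv(Conv):
    # Depth-wise convolution: a Conv whose groups equal gcd(c1, c2)
    def __init__(self, c1, c2, k=1, s=1, act=True):
        super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)


m = DWConv(64, 64, k=3)
print(m.conv.groups)  # 64 -> each input channel gets its own filter group
print(m(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 64, 32, 32])
```

With groups equal to the channel count, each filter sees only a single input channel, which is exactly the depth-wise behaviour the removed wrapper produced; the call signature DWConv(c1, c2, k, s) is unchanged, so existing call sites keep working.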
...@@ -202,7 +202,7 @@ class Model(nn.Module): ...@@ -202,7 +202,7 @@ class Model(nn.Module):
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
LOGGER.info('Fusing layers... ') LOGGER.info('Fusing layers... ')
for m in self.model.modules(): for m in self.model.modules():
if isinstance(m, (Conv, DWConvClass)) and hasattr(m, 'bn'): if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
delattr(m, 'bn') # remove batchnorm delattr(m, 'bn') # remove batchnorm
m.forward = m.forward_fuse # update forward m.forward = m.forward_fuse # update forward
......
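The fuse() hunk above only needed the class name swapped; the actual fusion is done by fuse_conv_and_bn, a helper imported from the repo's utilities. As a rough sketch of the algebra such a helper applies (not the repo's verbatim implementation), BatchNorm's affine transform can be folded into the preceding convolution's weights and bias:

```python
# Rough sketch of Conv2d + BatchNorm2d fusion: same algebra as fuse_conv_and_bn,
# but a simplified, hypothetical implementation for illustration only.
import torch
import torch.nn as nn


def fuse_conv_and_bn_sketch(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # y = gamma * (conv(x) - running_mean) / sqrt(running_var + eps) + beta
    # is rewritten as a single convolution with rescaled weights and a new bias.
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      groups=conv.groups, bias=True)
    scale = bn.weight.data / torch.sqrt(bn.running_var + bn.eps)  # gamma / sigma, shape (c2,)
    fused.weight.data = conv.weight.data * scale.reshape(-1, 1, 1, 1)
    b = torch.zeros(conv.out_channels) if conv.bias is None else conv.bias.data
    fused.bias.data = (b - bn.running_mean) * scale + bn.bias.data
    return fused


conv, bn = nn.Conv2d(8, 16, 3, padding=1, bias=False), nn.BatchNorm2d(16)
bn.eval()  # use running statistics, as at inference time
x = torch.randn(1, 8, 32, 32)
print(torch.allclose(bn(conv(x)), fuse_conv_and_bn_sketch(conv, bn)(x), atol=1e-5))  # True
```

After fusion, fuse() deletes the bn attribute and routes calls through forward_fuse, so each Conv/DWConv module runs a single Conv2d plus activation at inference time.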