提交 2f77cf33 authored 作者: Glenn Jocher's avatar Glenn Jocher

.fuse() additional error checking

上级 89655a84
...@@ -160,7 +160,7 @@ class Model(nn.Module): ...@@ -160,7 +160,7 @@ class Model(nn.Module):
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
print('Fusing layers... ') print('Fusing layers... ')
for m in self.model.modules(): for m in self.model.modules():
if type(m) is Conv: if type(m) is Conv and hasattr(Conv, 'bn'):
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
delattr(m, 'bn') # remove batchnorm delattr(m, 'bn') # remove batchnorm
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论