Commit 89655a84, authored by Glenn Jocher

.fuse() gradient introduction bug fix

Parent c4cb7857
@@ -104,8 +104,8 @@ def prune(model, amount=0.3):
 def fuse_conv_and_bn(conv, bn):
-    # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
-    with torch.no_grad():
+    # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
     # init
     fusedconv = nn.Conv2d(conv.in_channels,
                           conv.out_channels,
@@ -113,7 +113,7 @@ def fuse_conv_and_bn(conv, bn):
                           stride=conv.stride,
                           padding=conv.padding,
                           groups=conv.groups,
-                          bias=True).to(conv.weight.device)
+                          bias=True).requires_grad_(False).to(conv.weight.device)
     # prepare filters
     w_conv = conv.weight.clone().view(conv.out_channels, -1)
......
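Note on the arithmetic behind the fusion: the "prepare filters" step flattens each output channel of the convolution into one row, and fusing amounts to rescaling that row and adjusting its bias with the batchnorm statistics. The sketch below is plain Python for a single output channel, not the PyTorch code from this commit. The helper names (`fuse_channel`, `conv_channel`, `batchnorm`) and all numeric values are hypothetical and chosen only for illustration.

```python
import math

def conv_channel(x, w_row, b):
    # One output channel of a convolution at one position: dot product plus bias.
    return sum(w * xi for w, xi in zip(w_row, x)) + b

def batchnorm(y, gamma, beta, mean, var, eps=1e-5):
    # Inference-mode batchnorm using the running mean and variance.
    return gamma * (y - mean) / math.sqrt(var + eps) + beta

def fuse_channel(w_row, b, gamma, beta, mean, var, eps=1e-5):
    # Fold the batchnorm scale and shift into the conv weights and bias.
    scale = gamma / math.sqrt(var + eps)
    w_fused = [w * scale for w in w_row]
    b_fused = (b - mean) * scale + beta
    return w_fused, b_fused

# Hypothetical values for one flattened filter row and its batchnorm parameters.
x = [0.5, -1.0, 2.0]
w_row, b = [0.2, -0.4, 0.1], 0.3
gamma, beta, mean, var = 1.5, -0.2, 0.1, 0.25

reference = batchnorm(conv_channel(x, w_row, b), gamma, beta, mean, var)
w_fused, b_fused = fuse_channel(w_row, b, gamma, beta, mean, var)
fused = conv_channel(x, w_fused, b_fused)

# The fused channel reproduces conv followed by batchnorm.
assert math.isclose(reference, fused, rel_tol=1e-9)
```

Because the fused weights are computed from existing parameters for inference, they do not need to track gradients, which is what `.requires_grad_(False)` on the new `nn.Conv2d` expresses.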