提交 8fe299f1 authored 作者: Glenn Jocher's avatar Glenn Jocher

model fuse

上级 c672bef1
...@@ -90,7 +90,7 @@ def fuse_conv_and_bn(conv, bn): ...@@ -90,7 +90,7 @@ def fuse_conv_and_bn(conv, bn):
if conv.bias is not None: if conv.bias is not None:
b_conv = conv.bias b_conv = conv.bias
else: else:
b_conv = torch.zeros(conv.weight.size(0)) b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device)
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论