Unverified 提交 72cad398 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Squeezenet reshape outputs fix (#10222)

上级 40bb8030
...@@ -82,7 +82,7 @@ def reshape_classifier_output(model, n=1000): ...@@ -82,7 +82,7 @@ def reshape_classifier_output(model, n=1000):
elif nn.Conv2d in types: elif nn.Conv2d in types:
i = types.index(nn.Conv2d) # nn.Conv2d index i = types.index(nn.Conv2d) # nn.Conv2d index
if m[i].out_channels != n: if m[i].out_channels != n:
m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias) m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
@contextmanager @contextmanager
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论