提交 a97c3f94 authored 作者: Glenn Jocher's avatar Glenn Jocher

update common.py Classify()

上级 2efa01db
...@@ -112,4 +112,5 @@ class Classify(nn.Module): ...@@ -112,4 +112,5 @@ class Classify(nn.Module):
self.flat = Flatten() self.flat = Flatten()
def forward(self, x): def forward(self, x):
return self.flat(self.conv(self.aap(x))) # flatten to x(b,c2) z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
return self.flat(self.conv(z)) # flatten to x(b,c2)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论