提交 5ba1de0c authored 作者: Glenn Jocher's avatar Glenn Jocher

update experimental.py with Ensemble() module

上级 38f5c1ad
......@@ -107,3 +107,15 @@ class MixConv2d(nn.Module):
def forward(self, x):
return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
class Ensemble(nn.ModuleList):
# Ensemble of models
def __init__(self):
super(Ensemble, self).__init__()
def forward(self, x, augment=False):
y = []
for module in self:
y.append(module(x, augment)[0])
return torch.cat(y, 1), None # ensembled inference output, train output
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论