Unverified commit f02481c7, authored by Glenn Jocher, committed by GitHub

Update torch_utils.py

Parent fc7c4272
@@ -54,6 +54,11 @@ def time_synchronized():
     return time.time()
 
 
+def is_parallel(model):
+    # Returns True if model is wrapped in DP or DDP
+    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+
+
 def initialize_weights(model):
     for m in model.modules():
         t = type(m)
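Note: the new is_parallel helper centralizes the DP/DDP type check that was previously inlined in ModelEMA. A minimal, self-contained usage sketch (the stand-in model below is illustrative, not from the commit):

import torch.nn as nn

def is_parallel(model):
    # same check as the commit: is the model a DP or DDP wrapper?
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)

model = nn.Linear(10, 2)           # illustrative stand-in model
dp_model = nn.DataParallel(model)  # wrapping adds a .module attribute
print(is_parallel(model), is_parallel(dp_model))  # False True

# typical use: unwrap the container before reading state or copying attributes
raw = dp_model.module if is_parallel(dp_model) else dp_model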
@@ -111,8 +116,8 @@ def model_info(model, verbose=False):
     try:  # FLOPS
         from thop import profile
-        macs, _ = profile(model, inputs=(torch.zeros(1, 3, 480, 640),), verbose=False)
-        fs = ', %.1f GFLOPS' % (macs / 1E9 * 2)
+        flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2
+        fs = ', %.1f GFLOPS' % (flops * 100)  # 640x640 FLOPS
     except:
         fs = ''
@@ -185,7 +190,7 @@ class ModelEMA:
         self.updates += 1
         d = self.decay(self.updates)
         with torch.no_grad():
-            if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
+            if is_parallel(model):
                 msd, esd = model.module.state_dict(), self.ema.module.state_dict()
             else:
                 msd, esd = model.state_dict(), self.ema.state_dict()
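Note: with is_parallel, ModelEMA.update now selects the wrapped module's state_dict when the model is under DP/DDP and, per the surrounding lines, blends each floating-point tensor toward the current weights with the usual exponential-moving-average step ema = d * ema + (1 - d) * weight, where d comes from self.decay(self.updates). A self-contained sketch of that blend under those assumptions (ema_step is not a function in the repo):

import torch

def ema_step(ema_state, model_state, d):
    # ema <- d * ema + (1 - d) * model, applied per floating-point tensor
    with torch.no_grad():
        for k, v in ema_state.items():
            if v.dtype.is_floating_point:
                v *= d
                v += (1. - d) * model_state[k].detach()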
@@ -196,7 +201,8 @@ class ModelEMA:
                     v += (1. - d) * msd[k].detach()
 
     def update_attr(self, model):
-        # Assign attributes (which may change during training)
-        for k in model.__dict__.keys():
-            if not k.startswith('_'):
-                setattr(self.ema, k, getattr(model, k))
+        # Update class attributes
+        ema = self.ema.module if is_parallel(model) else self.ema
+        for k, v in model.__dict__.items():
+            if not k.startswith('_') and k != 'module':
+                setattr(ema, k, v)
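Note: the reworked update_attr mirrors the model's public attributes (e.g. class names or hyperparameters attached to the model) onto the EMA copy, unwrapping a DP/DDP container first and skipping private keys as well as 'module' (presumably to avoid replacing a wrapper's inner module). A minimal sketch of the same pattern; the helper and attribute names here are illustrative, not the commit's:

import torch.nn as nn

def copy_attrs(dst, src):
    # mirror non-private attributes from src onto dst, never touching 'module'
    for k, v in src.__dict__.items():
        if not k.startswith('_') and k != 'module':
            setattr(dst, k, v)

model = nn.Linear(4, 2)
model.names = ['cat', 'dog']  # illustrative attribute on the source model
ema_model = nn.Linear(4, 2)
copy_attrs(ema_model, model)
print(ema_model.names)        # ['cat', 'dog']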