提交 a5867519 authored 作者: Glenn Jocher's avatar Glenn Jocher

multi-gpu ckpt filesize bug fix #253

上级 5de4e25d
...@@ -287,7 +287,7 @@ def train(hyp): ...@@ -287,7 +287,7 @@ def train(hyp):
scheduler.step() scheduler.step()
# mAP # mAP
ema.update_attr(model) ema.update_attr(model, include=['md', 'nc', 'hyp', 'names', 'stride'])
final_epoch = epoch + 1 == epochs final_epoch = epoch + 1 == epochs
if not opt.notest or final_epoch: # Calculate mAP if not opt.notest or final_epoch: # Calculate mAP
results, maps, times = test.test(opt.data, results, maps, times = test.test(opt.data,
......
...@@ -173,22 +173,23 @@ def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio ...@@ -173,22 +173,23 @@ def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
def copy_attr(a, b, include=(), exclude=()):
# Copy attributes from b to a, options to only include [...] and to exclude [...]
for k, v in b.__dict__.items():
if (len(include) and k not in include) or k.startswith('_') or k in exclude:
continue
else:
setattr(a, k, v)
class ModelEMA: class ModelEMA:
""" Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
Keep a moving average of everything in the model state_dict (parameters and buffers). Keep a moving average of everything in the model state_dict (parameters and buffers).
This is intended to allow functionality like This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
A smoothed version of the weights is necessary for some training schemes to perform well. A smoothed version of the weights is necessary for some training schemes to perform well.
E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
smoothing of weights to match results. Pay attention to the decay constant you are using
relative to your update count per epoch.
To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
disable validation of the EMA weights. Validation will have to be done manually in a separate
process, or after the training stops converging.
This class is sensitive where it is initialized in the sequence of model init, This class is sensitive where it is initialized in the sequence of model init,
GPU assignment and distributed training wrappers. GPU assignment and distributed training wrappers.
I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
""" """
def __init__(self, model, decay=0.9999, updates=0): def __init__(self, model, decay=0.9999, updates=0):
...@@ -211,8 +212,6 @@ class ModelEMA: ...@@ -211,8 +212,6 @@ class ModelEMA:
v *= d v *= d
v += (1. - d) * msd[k].detach() v += (1. - d) * msd[k].detach()
def update_attr(self, model): def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
# Update EMA attributes # Update EMA attributes
for k, v in model.__dict__.items(): copy_attr(self.ema, model, include, exclude)
if not k.startswith('_') and k not in ["process_group", "reducer"]:
setattr(self.ema, k, v)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论