Unverified 提交 5e970d45 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Update train.py (#462)

上级 999804fe
...@@ -123,9 +123,12 @@ def train(hyp, tb_writer, opt, device): ...@@ -123,9 +123,12 @@ def train(hyp, tb_writer, opt, device):
# load model # load model
try: try:
exclude = ['anchor'] # exclude keys
ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items() ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items()
if k in model.state_dict() and model.state_dict()[k].shape == v.shape} if k in model.state_dict() and not any(x in k for x in exclude)
and model.state_dict()[k].shape == v.shape}
model.load_state_dict(ckpt['model'], strict=False) model.load_state_dict(ckpt['model'], strict=False)
print('Transferred %g/%g items from %s' % (len(ckpt['model']), len(model.state_dict()), weights))
except KeyError as e: except KeyError as e:
s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \ s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
"Please delete or update %s and try again, or use --weights '' to train from scratch." \ "Please delete or update %s and try again, or use --weights '' to train from scratch." \
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论