Commit aff02819 (unverified), authored by bilzard, committed by GitHub

Load checkpoint on CPU instead of on GPU (#6516)

* Load checkpoint on CPU instead of on GPU
* refactor: simplify code
* Cleanup
* Update train.py

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Parent 8fcdf3b6
@@ -120,7 +120,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     if pretrained:
         with torch_distributed_zero_first(LOCAL_RANK):
             weights = attempt_download(weights)  # download if not found locally
-        ckpt = torch.load(weights, map_location=device)  # load checkpoint
+        ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
         model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
         exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
         csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
......
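
The one-line change above can be illustrated with a small standalone sketch. It is not the code from train.py: the `load_checkpoint_on_cpu` helper, the `nn.Linear` stand-in model, and the `toy.pt` file name are all assumptions made up for illustration. It shows that with `map_location='cpu'` the deserialized tensors stay in host memory, and only the model that receives a copy of the weights is moved to the training device. With `map_location=device`, the checkpoint's tensors would presumably be allocated on the GPU and stay there for as long as `ckpt` is referenced.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_checkpoint_on_cpu(path, device):
    """Deserialize a checkpoint into host memory, then move only the model to `device`.

    Hypothetical helper for illustration; nn.Linear stands in for the real Model(cfg, ...).
    """
    ckpt = torch.load(path, map_location='cpu')  # tensors land on CPU wherever they were saved
    model = nn.Linear(4, 2)  # stand-in model (assumption, not the YOLOv5 Model class)
    model.load_state_dict(ckpt['model'])  # copies values into the model's own parameters
    return model.to(device), ckpt


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'toy.pt')  # hypothetical checkpoint file
    torch.save({'model': nn.Linear(4, 2).state_dict(), 'epoch': 3}, path)
    model, ckpt = load_checkpoint_on_cpu(path, device)

# The checkpoint tensors remain on the CPU; only the model parameters sit on `device`.
ckpt_devices = {t.device.type for t in ckpt['model'].values()}
model_devices = {p.device.type for p in model.parameters()}
print(ckpt_devices, model_devices)
```

Because `load_state_dict` copies values rather than sharing storage, the checkpoint dictionary can be dropped once the weights are loaded without affecting the model.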