Unverified 提交 a0e15046 authored 作者: Jebastin Nadar's avatar Jebastin Nadar 提交者: GitHub

Fix different devices bug when moving model from GPU to CPU (#5110)

* fix different devices bug * extend _apply() instead of to() for a general fix * Only apply if Detect() is last layer Co-authored-by: 's avatarJebastin Nadar <njebastin10@gmail.com> * Indent fix * Add comment to yolo.py * Add comment to common.py Co-authored-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 4a6dfffd
......@@ -289,6 +289,14 @@ class AutoShape(nn.Module):
LOGGER.info('AutoShape already enabled, skipping... ') # model already converted to model.autoshape()
return self
def _apply(self, fn):
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
self = super()._apply(fn)
m = self.model.model[-1] # Detect()
m.stride = fn(m.stride)
m.grid = list(map(fn, m.grid))
return self
@torch.no_grad()
def forward(self, imgs, size=640, augment=False, profile=False):
# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
......
......@@ -232,6 +232,15 @@ class Model(nn.Module):
def info(self, verbose=False, img_size=640): # print model information
model_info(self, verbose, img_size)
def _apply(self, fn):
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
self = super()._apply(fn)
m = self.model[-1] # Detect()
if isinstance(m, Detect):
m.stride = fn(m.stride)
m.grid = list(map(fn, m.grid))
return self
def parse_model(d, ch): # model_dict, input_channels(3)
LOGGER.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论