Unverified 提交 a34b376d authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Link fuse() to AutoShape() for Hub models (#8599)

上级 6e86af3d
......@@ -36,7 +36,6 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
if not verbose:
LOGGER.setLevel(logging.WARNING)
check_requirements(exclude=('tensorboard', 'thop', 'opencv-python'))
name = Path(name)
path = name.with_suffix('.pt') if name.suffix == '' and not name.is_dir() else name # checkpoint path
......@@ -44,7 +43,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
device = select_device(device)
if pretrained and channels == 3 and classes == 80:
model = DetectMultiBackend(path, device=device) # download/load FP32 model
model = DetectMultiBackend(path, device=device, fuse=autoshape) # download/load FP32 model
# model = models.experimental.attempt_load(path, map_location=device) # download/load FP32 model
else:
cfg = list((Path(__file__).parent / 'models').rglob(f'{path.stem}.yaml'))[0] # model.yaml path
......
......@@ -305,7 +305,7 @@ class Concat(nn.Module):
class DetectMultiBackend(nn.Module):
# YOLOv5 MultiBackend class for python inference on various backends
def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False):
def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
# Usage:
# PyTorch: weights = *.pt
# TorchScript: *.torchscript
......@@ -331,7 +331,7 @@ class DetectMultiBackend(nn.Module):
names = yaml.safe_load(f)['names']
if pt: # PyTorch
model = attempt_load(weights if isinstance(weights, list) else w, device=device)
model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
stride = max(int(model.stride.max()), 32) # model stride
names = model.module.names if hasattr(model, 'module') else model.names # get class names
model.half() if fp16 else model.float()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论