提交 c672bef1 authored 作者: Glenn Jocher's avatar Glenn Jocher

model fuse

上级 12b0c046
...@@ -21,6 +21,8 @@ def detect(save_img=False): ...@@ -21,6 +21,8 @@ def detect(save_img=False):
google_utils.attempt_download(weights) google_utils.attempt_download(weights)
model = torch.load(weights, map_location=device)['model'] model = torch.load(weights, map_location=device)['model']
# torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning # torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning
# model.fuse()
model.to(device).eval()
# Second-stage classifier # Second-stage classifier
classify = False classify = False
...@@ -29,12 +31,6 @@ def detect(save_img=False): ...@@ -29,12 +31,6 @@ def detect(save_img=False):
modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']) # load weights modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']) # load weights
modelc.to(device).eval() modelc.to(device).eval()
# Eval mode
model.to(device).eval()
# Fuse Conv2d + BatchNorm2d layers
# model.fuse()
# Half precision # Half precision
half = half and device.type != 'cpu' # half precision only supported on CUDA half = half and device.type != 'cpu' # half precision only supported on CUDA
if half: if half:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论