Unverified 提交 5bdb28ed authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Default PyTorch Hub to `autocast(False)` (#5926)

上级 c77a5a84
...@@ -443,6 +443,7 @@ class AutoShape(nn.Module): ...@@ -443,6 +443,7 @@ class AutoShape(nn.Module):
multi_label = False # NMS multiple labels per box multi_label = False # NMS multiple labels per box
classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
max_det = 1000 # maximum number of detections per image max_det = 1000 # maximum number of detections per image
amp = False # Automatic Mixed Precision (AMP) inference
def __init__(self, model): def __init__(self, model):
super().__init__() super().__init__()
...@@ -476,8 +477,9 @@ class AutoShape(nn.Module): ...@@ -476,8 +477,9 @@ class AutoShape(nn.Module):
t = [time_sync()] t = [time_sync()]
p = next(self.model.parameters()) if self.pt else torch.zeros(1) # for device and type p = next(self.model.parameters()) if self.pt else torch.zeros(1) # for device and type
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
if isinstance(imgs, torch.Tensor): # torch if isinstance(imgs, torch.Tensor): # torch
with amp.autocast(enabled=p.device.type != 'cpu'): with amp.autocast(enabled=autocast):
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
# Pre-process # Pre-process
...@@ -506,7 +508,7 @@ class AutoShape(nn.Module): ...@@ -506,7 +508,7 @@ class AutoShape(nn.Module):
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32 x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
t.append(time_sync()) t.append(time_sync())
with amp.autocast(enabled=p.device.type != 'cpu'): with amp.autocast(enabled=autocast):
# Inference # Inference
y = self.model(x, augment, profile) # forward y = self.model(x, augment, profile) # forward
t.append(time_sync()) t.append(time_sync())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论