Unverified commit d8f18834, authored by Glenn Jocher, committed by GitHub

Update `profile()` for CUDA Memory allocation (#4239)

* Update profile() * Update profile() * Update profile() * Update profile() * Update profile() * Update profile() * Update profile() * Update profile() * Update profile() * Update profile() * Update profile() * Update profile() * Cleanup
Parent: bceb57b9
@@ -1172,11 +1172,11 @@
   },
   "source": [
     "# Profile\n",
-    "from utils.torch_utils import profile \n",
+    "from utils.torch_utils import profile\n",
     "\n",
     "m1 = lambda x: x * torch.sigmoid(x)\n",
     "m2 = torch.nn.SiLU()\n",
-    "profile(x=torch.randn(16, 3, 640, 640), ops=[m1, m2], n=100)"
+    "results = profile(input=torch.randn(16, 3, 640, 640), ops=[m1, m2], n=100)"
   ],
   "execution_count": null,
   "outputs": []
...
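For reference, the updated notebook cell amounts to the call pattern below. This is a minimal sketch mirroring the diff (it assumes you are running inside the yolov5 repo so that utils.torch_utils is importable); the comment on the result layout reflects the list built in the second hunk, not an additional API guarantee.

# sketch of the new usage: profile() now takes `input=` and returns per-op results
import torch
from utils.torch_utils import profile

m1 = lambda x: x * torch.sigmoid(x)  # functional SiLU
m2 = torch.nn.SiLU()                 # module SiLU
results = profile(input=torch.randn(16, 3, 640, 640), ops=[m1, m2], n=100)

# each entry is [params, GFLOPs, GPU_mem_GB, forward_ms, backward_ms, in_shape, out_shape],
# or None if the op raised during profiling (see the function body below)
for r in results:
    print(r)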
@@ -98,42 +98,56 @@ def time_sync():
     return time.time()
 
 
-def profile(x, ops, n=100, device=None):
-    # profile a pytorch module or list of modules. Example usage:
-    #     x = torch.randn(16, 3, 640, 640)  # input
+def profile(input, ops, n=10, device=None):
+    # YOLOv5 speed/memory/FLOPs profiler
+    #
+    # Usage:
+    #     input = torch.randn(16, 3, 640, 640)
     #     m1 = lambda x: x * torch.sigmoid(x)
     #     m2 = nn.SiLU()
-    #     profile(x, [m1, m2], n=100)  # profile speed over 100 iterations
+    #     profile(input, [m1, m2], n=100)  # profile over 100 iterations
 
+    results = []
     device = device or select_device()
-    x = x.to(device)
-    x.requires_grad = True
-    print(f"{'Params':>12s}{'GFLOPs':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
+    print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
+          f"{'input':>24s}{'output':>24s}")
 
-    for m in ops if isinstance(ops, list) else [ops]:
-        m = m.to(device) if hasattr(m, 'to') else m  # device
-        m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m  # type
-        dtf, dtb, t = 0., 0., [0., 0., 0.]  # dt forward, backward
-        try:
-            flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPs
-        except:
-            flops = 0
-        for _ in range(n):
-            t[0] = time_sync()
-            y = m(x)
-            t[1] = time_sync()
-            try:
-                _ = y.sum().backward()
-                t[2] = time_sync()
-            except:  # no backward method
-                t[2] = float('nan')
-            dtf += (t[1] - t[0]) * 1000 / n  # ms per op forward
-            dtb += (t[2] - t[1]) * 1000 / n  # ms per op backward
-
-        s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
-        s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
-        p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0  # parameters
-        print(f'{p:12}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')
+    for x in input if isinstance(input, list) else [input]:
+        x = x.to(device)
+        x.requires_grad = True
+        for m in ops if isinstance(ops, list) else [ops]:
+            m = m.to(device) if hasattr(m, 'to') else m  # device
+            m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
+            tf, tb, t = 0., 0., [0., 0., 0.]  # dt forward, backward
+            try:
+                flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPs
+            except:
+                flops = 0
+
+            try:
+                for _ in range(n):
+                    t[0] = time_sync()
+                    y = m(x)
+                    t[1] = time_sync()
+                    try:
+                        _ = (sum([yi.sum() for yi in y]) if isinstance(y, list) else y).sum().backward()
+                        t[2] = time_sync()
+                    except Exception as e:  # no backward method
+                        print(e)
+                        t[2] = float('nan')
+                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
+                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
+                mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0  # (GB)
+                s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
+                s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
+                p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0  # parameters
+                print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
+                results.append([p, flops, mem, tf, tb, s_in, s_out])
+            except Exception as e:
+                print(e)
+                results.append(None)
+            torch.cuda.empty_cache()
+    return results
 
 
 def is_parallel(model):
...
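The change this commit is named for is the new GPU_mem column: after the timed iterations for each op, torch.cuda.memory_reserved() is read and converted to GB, and torch.cuda.empty_cache() is called between ops so one op's cached allocations do not inflate the next op's number. Below is a minimal standalone sketch of that measurement idea; it is an illustration under the assumption that a CUDA device may or may not be present, not code from the repo.

# sketch: reading cache-allocator usage around an op, as the updated profile() does
import torch

def reserved_gb():
    # bytes currently held by the CUDA caching allocator (upper bound on the op's usage)
    return torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(16, 3, 640, 640, device=device, requires_grad=True)
op = torch.nn.SiLU().to(device)

y = op(x)
y.sum().backward()
print(f'reserved after forward+backward: {reserved_gb():.3f} GB')

torch.cuda.empty_cache()  # release cached blocks between ops, as profile() does
print(f'reserved after empty_cache:      {reserved_gb():.3f} GB')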