Unverified 提交 ada90e39 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Profile() feature addition (#1673)

* Profile() feature addition * cleanup
上级 94a7f55c
# PyTorch utils
import logging
import math
import os
import time
from contextlib import contextmanager
from copy import deepcopy
import math
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torchvision
try:
import thop # for FLOPS computation
except ImportError:
thop = None
logger = logging.getLogger(__name__)
......@@ -66,10 +70,45 @@ def select_device(device='', batch_size=None):
def time_synchronized():
# pytorch-accurate time
torch.cuda.synchronize() if torch.cuda.is_available() else None
return time.time()
def profile(x, ops, n=100, device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')):
# profile a pytorch module or list of modules. Example usage:
# x = torch.randn(16, 3, 640, 640) # input
# m1 = lambda x: x * torch.sigmoid(x)
# m2 = nn.SiLU()
# profile(x, [m1, m2], n=100) # profile speed over 100 iterations
x = x.to(device)
x.requires_grad = True
print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
print(f"\n{'Params':>12s}{'FLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
for m in ops if isinstance(ops, list) else [ops]:
m = m.to(device) if hasattr(m, 'to') else m
dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward
try:
flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPS
except:
flops = 0
for _ in range(n):
t[0] = time_synchronized()
y = m(x)
t[1] = time_synchronized()
_ = y.sum().backward()
t[2] = time_synchronized()
dtf += (t[1] - t[0]) * 1000 / n # ms per op forward
dtb += (t[2] - t[1]) * 1000 / n # ms per op backward
s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters
print(f'{p:12.4g}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')
def is_parallel(model):
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论