Unverified 提交 d8f5fcfe authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Improved FLOPS computation (#1398)

* Improved FLOPS computation * update comment
上级 0c26c4e8
...@@ -192,8 +192,8 @@ class Model(nn.Module): ...@@ -192,8 +192,8 @@ class Model(nn.Module):
copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
return m return m
def info(self, verbose=False): # print model information def info(self, verbose=False, img_size=640): # print model information
model_info(self, verbose) model_info(self, verbose, img_size)
def parse_model(d, ch): # model_dict, input_channels(3) def parse_model(d, ch): # model_dict, input_channels(3)
......
...@@ -139,8 +139,8 @@ def fuse_conv_and_bn(conv, bn): ...@@ -139,8 +139,8 @@ def fuse_conv_and_bn(conv, bn):
return fusedconv return fusedconv
def model_info(model, verbose=False): def model_info(model, verbose=False, img_size=640):
# Plots a line-by-line description of a PyTorch model # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
n_p = sum(x.numel() for x in model.parameters()) # number parameters n_p = sum(x.numel() for x in model.parameters()) # number parameters
n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
if verbose: if verbose:
...@@ -152,8 +152,10 @@ def model_info(model, verbose=False): ...@@ -152,8 +152,10 @@ def model_info(model, verbose=False):
try: # FLOPS try: # FLOPS
from thop import profile from thop import profile
flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2 stride = int(model.stride.max())
fs = ', %.1f GFLOPS' % (flops * 100) # 640x640 FLOPS flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, stride, stride),), verbose=False)[0] / 1E9 * 2
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 FLOPS
except ImportError: except ImportError:
fs = '' fs = ''
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论