Unverified commit c03d5903, authored by Glenn Jocher, committed by GitHub

Add Hub results.pandas() method (#2725)

* Add Hub results.pandas() method

  New method converts results from torch tensors to pandas DataFrames with column names. This PR may partially resolve issue https://github.com/ultralytics/yolov5/issues/2703

  ```python
  results = model(imgs)
  print(results.pandas().xyxy[0])

           xmin        ymin        xmax        ymax  confidence  class    name
  0   57.068970  391.770599  241.383545  905.797852    0.868964      0  person
  1  667.661255  399.303589  810.000000  881.396667    0.851888      0  person
  2  222.878387  414.774231  343.804474  857.825073    0.838376      0  person
  3    4.205386  234.447678  803.739136  750.023376    0.658006      5     bus
  4    0.000000  550.596008   76.681190  878.669922    0.450596      0  person
  ```

* Update comments

  torch example input now shown resized to size=640 and also now a multiple of P6 stride 64 (see https://github.com/ultralytics/yolov5/issues/2722#issuecomment-814785930)

* apply decorators

* PEP8

* Update common.py

* pd.options.display.max_columns = 10

* Update common.py
Parent c8c8da60
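For context, a minimal usage sketch of the API this commit adds (assumes network access; the model name follows the repo's hub examples and the image URL is the one quoted in the code comments):

```python
import torch

# Assumed standard torch.hub entry point for an autoShape-wrapped YOLOv5 model
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
imgs = ['https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg']

results = model(imgs)
print(results.pandas().xyxy[0])  # DataFrame with columns xmin, ymin, xmax, ymax, confidence, class, name
```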
models/common.py

```diff
 # YOLOv5 common modules
 
 import math
+from copy import copy
 from pathlib import Path
 
 import numpy as np
+import pandas as pd
 import requests
 import torch
 import torch.nn as nn
 from PIL import Image
-from torch.cuda import amp
 
 from utils.datasets import letterbox
 from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh
@@ -235,14 +236,16 @@ class autoShape(nn.Module):
             print('autoShape already enabled, skipping... ')  # model already converted to model.autoshape()
             return self
 
+    @torch.no_grad()
+    @torch.cuda.amp.autocast()
     def forward(self, imgs, size=640, augment=False, profile=False):
-        # Inference from various sources. For height=720, width=1280, RGB images example inputs are:
+        # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
         #   filename:   imgs = 'data/samples/zidane.jpg'
         #   URI:             = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
-        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(720,1280,3)
-        #   PIL:             = Image.open('image.jpg')  # HWC x(720,1280,3)
-        #   numpy:           = np.zeros((720,1280,3))  # HWC
-        #   torch:           = torch.zeros(16,3,720,1280)  # BCHW
+        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
+        #   PIL:             = Image.open('image.jpg')  # HWC x(640,1280,3)
+        #   numpy:           = np.zeros((640,1280,3))  # HWC
+        #   torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
         #   multiple:   = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
 
         t = [time_synchronized()]
@@ -275,7 +278,6 @@ class autoShape(nn.Module):
         x = torch.from_numpy(x).to(p.device).type_as(p) / 255.  # uint8 to fp16/32
         t.append(time_synchronized())
 
-        with torch.no_grad(), amp.autocast(enabled=p.device.type != 'cpu'):
         # Inference
         y = self.model(x, augment, profile)[0]  # forward
         t.append(time_synchronized())
```
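To make the updated comments concrete, here is a small illustrative sketch of the input formats they describe. It is a sketch only: the model name, image paths, and URL are example values (the URL is the one quoted in the comments), and it assumes the usual torch.hub entry point.

```python
import cv2
import numpy as np
import torch
from PIL import Image

# Assumed autoShape model from torch.hub; image paths below are placeholders.
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)

results = model('data/samples/zidane.jpg')                                                   # filename
results = model('https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg')  # URI
results = model(cv2.imread('image.jpg')[:, :, ::-1])                                         # OpenCV: HWC, BGR to RGB
results = model(Image.open('image.jpg'))                                                     # PIL: HWC
results = model(np.zeros((640, 1280, 3)))                                                    # numpy: HWC
results = model(torch.zeros(16, 3, 320, 640))                                                # torch: BCHW, 0-1 values
results = model([Image.open('image1.jpg'), Image.open('image2.jpg')])                        # list of images
```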
```diff
@@ -347,17 +349,27 @@ class Detections:
         self.display(render=True)  # render results
         return self.imgs
 
-    def __len__(self):
-        return self.n
+    def pandas(self):
+        # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
+        new = copy(self)  # return copy
+        ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'  # xyxy columns
+        cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name'  # xywh columns
+        for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
+            a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]  # update
+            setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
+        return new
 
     def tolist(self):
         # return a list of Detections objects, i.e. 'for result in results.tolist():'
-        x = [Detections([self.imgs[i]], [self.pred[i]], self.names) for i in range(self.n)]
+        x = [Detections([self.imgs[i]], [self.pred[i]], self.names, self.s) for i in range(self.n)]
         for d in x:
             for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
                 setattr(d, k, getattr(d, k)[0])  # pop out of list
         return x
 
+    def __len__(self):
+        return self.n
+
 
 class Classify(nn.Module):
     # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
```
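A brief sketch of how the new pandas() helper combines with the existing tolist(); it assumes `results` was produced by an autoShape model call as in the commit message example, and uses only standard pandas methods:

```python
# Assumes `results = model(imgs)` from an autoShape model, as in the commit message example.
df = results.pandas().xywhn[0]            # normalized boxes: xcenter, ycenter, width, height, confidence, class, name
people = df[df['name'] == 'person']       # filter detections by class name
records = df.to_dict(orient='records')    # one dict per detection, convenient for JSON export
print(len(people), 'person detections')

for d in results.tolist():                # one Detections object per input image
    print(len(d.xyxy), 'detections')      # after tolist(), xyxy is a single per-image tensor
```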
utils/general.py

```diff
@@ -13,6 +13,7 @@ from pathlib import Path
 import cv2
 import numpy as np
+import pandas as pd
 import torch
 import torchvision
 import yaml
@@ -24,6 +25,7 @@ from utils.torch_utils import init_torch_seeds
 # Settings
 torch.set_printoptions(linewidth=320, precision=5, profile='long')
 np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
+pd.options.display.max_columns = 10
 cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
 os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads
```
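As a side note, a tiny illustration (not part of the commit) of what the new pandas display setting does when a wide DataFrame is printed:

```python
import pandas as pd

pd.options.display.max_columns = 10  # same setting added to the YOLOv5 settings block
wide = pd.DataFrame([list(range(15))], columns=[f'c{i}' for i in range(15)])
print(wide)  # columns beyond the display limit are elided with '...'
```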