Unverified 提交 b6fdd2e5 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Create `dataset_stats()` for HUB

上级 ac8691e2
...@@ -17,12 +17,13 @@ import cv2 ...@@ -17,12 +17,13 @@ import cv2
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import yaml
from PIL import Image, ExifTags from PIL import Image, ExifTags
from torch.utils.data import Dataset from torch.utils.data import Dataset
from tqdm import tqdm from tqdm import tqdm
from utils.general import check_requirements, xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, segment2box, segments2boxes, \ from utils.general import check_requirements, check_file, check_dataset, xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, \
resample_segments, clean_str segment2box, segments2boxes, resample_segments, clean_str
from utils.torch_utils import torch_distributed_zero_first from utils.torch_utils import torch_distributed_zero_first
# Parameters # Parameters
...@@ -1083,3 +1084,34 @@ def verify_image_label(params): ...@@ -1083,3 +1084,34 @@ def verify_image_label(params):
nc = 1 nc = 1
logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}') logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
return [None] * 4 + [nm, nf, ne, nc] return [None] * 4 + [nm, nf, ne, nc]
def dataset_stats(path='data/coco128.yaml', verbose=False):
""" Return dataset statistics dictionary with images and instances counts per split per class
Usage: from utils.datasets import *; dataset_stats('data/coco128.yaml')
Arguments
path: Path to data.yaml
verbose: Print stats dictionary
"""
path = check_file(Path(path))
with open(path) as f:
data = yaml.safe_load(f) # data dict
check_dataset(data) # download dataset if missing
nc = data['nc'] # number of classes
stats = {'nc': nc, 'names': data['names']} # statistics dictionary
for split in 'train', 'val', 'test':
if split not in data:
stats[split] = None # i.e. no test set
continue
x = []
dataset = LoadImagesAndLabels(data[split], augment=False, rect=True) # load dataset
for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
x.append(np.bincount(label[:, 0].astype(int), minlength=nc))
x = np.array(x) # shape(128x80)
stats[split] = {'instances': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
'images': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
'per_class': (x > 0).sum(0).tolist()}}
if verbose:
print(yaml.dump([stats], sort_keys=False, default_flow_style=False))
return stats
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论