Unverified 提交 96a84468 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Update labels_to_image_weights() (#1545)

上级 97a5227a
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
import glob import glob
import logging import logging
import math
import os import os
import platform import platform
import random import random
...@@ -12,7 +11,7 @@ import time ...@@ -12,7 +11,7 @@ import time
from pathlib import Path from pathlib import Path
import cv2 import cv2
import matplotlib import math
import numpy as np import numpy as np
import torch import torch
import torchvision import torchvision
...@@ -22,13 +21,10 @@ from utils.google_utils import gsutil_getsize ...@@ -22,13 +21,10 @@ from utils.google_utils import gsutil_getsize
from utils.metrics import fitness from utils.metrics import fitness
from utils.torch_utils import init_torch_seeds from utils.torch_utils import init_torch_seeds
# Set printoptions # Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long') torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5 np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
matplotlib.rc('font', **{'size': 11}) cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
# Prevent OpenCV from multithreading (to use PyTorch DataLoader)
cv2.setNumThreads(0)
def set_logging(rank=-1): def set_logging(rank=-1):
...@@ -121,9 +117,8 @@ def labels_to_class_weights(labels, nc=80): ...@@ -121,9 +117,8 @@ def labels_to_class_weights(labels, nc=80):
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)): def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
# Produces image weights based on class mAPs # Produces image weights based on class_weights and image contents
n = len(labels) class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
class_counts = np.array([np.bincount(labels[i][:, 0].astype(np.int), minlength=nc) for i in range(n)])
image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1) image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
# index = random.choices(range(n), weights=image_weights, k=1) # weight image sample # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
return image_weights return image_weights
......
...@@ -20,6 +20,7 @@ from utils.general import xywh2xyxy, xyxy2xywh ...@@ -20,6 +20,7 @@ from utils.general import xywh2xyxy, xyxy2xywh
from utils.metrics import fitness from utils.metrics import fitness
# Settings # Settings
matplotlib.rc('font', **{'size': 11})
matplotlib.use('Agg') # for writing to files only matplotlib.use('Agg') # for writing to files only
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论