Unverified commit bdd88e1e, authored by Glenn Jocher, committed by GitHub

YOLOv5 Segmentation Dataloader Updates (#2188)

* Update C3 module * Update C3 module * Update C3 module * Update C3 module * update * update * update * update * update * update * update * update * update * updates * updates * updates * updates * updates * updates * updates * updates * updates * updates * update * update * update * update * updates * updates * updates * updates * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update datasets * update * update * update * update attempt_downlaod() * merge * merge * update * update * update * update * update * update * update * update * update * update * parameterize eps * comments * gs-multiple * update * max_nms implemented * Create one_cycle() function * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * GitHub API rate limit fix * update * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * astuple * epochs * update * update * ComputeLoss() * update * update * update * update * update * update * update * update * update * update * update * merge * merge * merge * merge * update * update * update * update * commit=tag == tags[-1] * Update cudnn.benchmark * update * update * update * updates * updates * updates * updates * updates * updates * updates * update * update * update * update * update * mosaic9 * update * update * update * update * update * update * institute cache versioning * only display on existing cache * reverse cache exists booleans
Parent 404749a3
@@ -10,7 +10,7 @@
 # Download/unzip labels
 d='../' # unzip directory
 url=https://github.com/ultralytics/yolov5/releases/download/v1.0/
-f='coco2017labels.zip' # 68 MB
+f='coco2017labels.zip' # or 'coco2017labels-segments.zip', 68 MB
 echo 'Downloading' $url$f ' ...'
 curl -L $url$f -o $f && unzip -q $f -d $d && rm $f & # download, unzip, remove in background
...
@@ -20,7 +20,8 @@ from PIL import Image, ExifTags
 from torch.utils.data import Dataset
 from tqdm import tqdm
 
-from utils.general import xyxy2xywh, xywh2xyxy, xywhn2xyxy, clean_str
+from utils.general import xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, segment2box, segments2boxes, resample_segments, \
+    clean_str
 from utils.torch_utils import torch_distributed_zero_first
 
 # Parameters
@@ -374,21 +375,23 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         self.label_files = img2label_paths(self.img_files)  # labels
         cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')  # cached labels
         if cache_path.is_file():
-            cache = torch.load(cache_path)  # load
-            if cache['hash'] != get_hash(self.label_files + self.img_files) or 'results' not in cache:  # changed
-                cache = self.cache_labels(cache_path, prefix)  # re-cache
+            cache, exists = torch.load(cache_path), True  # load
+            if cache['hash'] != get_hash(self.label_files + self.img_files) or 'version' not in cache:  # changed
+                cache, exists = self.cache_labels(cache_path, prefix), False  # re-cache
         else:
-            cache = self.cache_labels(cache_path, prefix)  # cache
+            cache, exists = self.cache_labels(cache_path, prefix), False  # cache
 
         # Display cache
-        [nf, nm, ne, nc, n] = cache.pop('results')  # found, missing, empty, corrupted, total
-        desc = f"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
-        tqdm(None, desc=prefix + desc, total=n, initial=n)
+        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
+        if exists:
+            d = f"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
+            tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
         assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
 
         # Read cache
         cache.pop('hash')  # remove hash
-        labels, shapes = zip(*cache.values())
+        cache.pop('version')  # remove version
+        labels, shapes, self.segments = zip(*cache.values())
         self.labels = list(labels)
         self.shapes = np.array(shapes, dtype=np.float64)
         self.img_files = list(cache.keys())  # update
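Editor note: for readers following the cache changes above, here is a minimal, self-contained sketch (not code from the repository; the file path and hash value are hypothetical) of the layout the new `*.cache` dict takes and why the `'version'` key forces caches written by older code to be rebuilt.

```python
import numpy as np

# Hypothetical cache mirroring the per-image [labels, shape, segments] layout used above
cache = {
    'images/train/0001.jpg': [
        np.array([[0, 0.5, 0.5, 0.2, 0.3]], dtype=np.float32),  # labels: one (cls, xywh) row
        (640, 480),                                             # image size from exif_size()
        [],                                                     # instance segments (empty for box-only labels)
    ],
    'hash': 123456789,           # get_hash() of the label + image file lists (placeholder value)
    'results': (1, 0, 0, 0, 1),  # nf, nm, ne, nc, total
    'version': 0.1,              # cache version
}

current_hash = 123456789
# Caches written before this change have no 'version' key, so they are re-created:
recache = cache['hash'] != current_hash or 'version' not in cache
print(recache)  # False -> this cache is reused

# Reading the cache then strips the metadata keys and unzips the per-image entries:
nf, nm, ne, nc, n = cache.pop('results')
cache.pop('hash')
cache.pop('version')
labels, shapes, segments = zip(*cache.values())
print(len(labels), shapes[0])  # 1 (640, 480)
```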
@@ -451,6 +454,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                 im = Image.open(im_file)
                 im.verify()  # PIL verify
                 shape = exif_size(im)  # image size
+                segments = []  # instance segments
                 assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
                 assert im.format.lower() in img_formats, f'invalid image format {im.format}'
@@ -458,7 +462,12 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                 if os.path.isfile(lb_file):
                     nf += 1  # label found
                     with open(lb_file, 'r') as f:
-                        l = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32)  # labels
+                        l = [x.split() for x in f.read().strip().splitlines()]
+                        if any([len(x) > 8 for x in l]):  # is segment
+                            classes = np.array([x[0] for x in l], dtype=np.float32)
+                            segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # (cls, xy1...)
+                            l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
+                        l = np.array(l, dtype=np.float32)
                     if len(l):
                         assert l.shape[1] == 5, 'labels require 5 columns each'
                         assert (l >= 0).all(), 'negative labels'
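Editor note: a small standalone illustration of the new label parsing: rows with more than 8 values are treated as polygon (segment) labels and collapsed to `(cls, xywh)` boxes. The label line and the `_seg2box` helper below are hypothetical stand-ins; the repository code calls `segments2boxes()` from utils/general.py, added further down in this diff.

```python
import numpy as np

def _seg2box(seg):
    # local stand-in for segments2boxes(): polygon -> normalized xywh box
    x, y = seg.T
    x1, y1, x2, y2 = x.min(), y.min(), x.max(), y.max()
    return np.array([[(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1]])

line = '0 0.1 0.1 0.9 0.1 0.9 0.9 0.1 0.9'  # cls + 4 xy pairs -> 9 values
parts = line.split()
print(len(parts) > 8)                        # True -> the "is segment" branch is taken

cls = np.array([parts[0]], dtype=np.float32).reshape(-1, 1)
seg = np.array(parts[1:], dtype=np.float32).reshape(-1, 2)  # (n, 2) polygon, kept in `segments`
l = np.concatenate((cls, _seg2box(seg)), 1)                 # (cls, xywh) box, kept in `l`
print(l)  # approx. [[0.  0.5 0.5 0.8 0.8]]
```

The polygon itself is kept alongside the derived box, so segment-aware augmentation can still see the full shape while box-only code paths keep working unchanged.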
@@ -470,7 +479,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                 else:
                     nm += 1  # label missing
                     l = np.zeros((0, 5), dtype=np.float32)
-                x[im_file] = [l, shape]
+                x[im_file] = [l, shape, segments]
             except Exception as e:
                 nc += 1
                 print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
@@ -482,7 +491,8 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
             print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
 
         x['hash'] = get_hash(self.label_files + self.img_files)
-        x['results'] = [nf, nm, ne, nc, i + 1]
+        x['results'] = nf, nm, ne, nc, i + 1
+        x['version'] = 0.1  # cache version
         torch.save(x, path)  # save for next time
         logging.info(f'{prefix}New cache created: {path}')
         return x
@@ -652,7 +662,7 @@ def hist_equalize(img, clahe=True, bgr=False):
 
 def load_mosaic(self, index):
     # loads images in a 4-mosaic
-    labels4 = []
+    labels4, segments4 = [], []
     s = self.img_size
     yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
     indices = [index] + [self.indices[random.randint(0, self.n - 1)] for _ in range(3)]  # 3 additional image indices
@@ -680,19 +690,21 @@ def load_mosaic(self, index):
         padh = y1a - y1b
 
         # Labels
-        labels = self.labels[index].copy()
+        labels, segments = self.labels[index].copy(), self.segments[index].copy()
         if labels.size:
             labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
+            segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
         labels4.append(labels)
+        segments4.extend(segments)
 
     # Concat/clip labels
-    if len(labels4):
-        labels4 = np.concatenate(labels4, 0)
-        np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:])  # use with random_perspective
+    labels4 = np.concatenate(labels4, 0)
+    for x in (labels4[:, 1:], *segments4):
+        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
     # img4, labels4 = replicate(img4, labels4)  # replicate
 
     # Augment
-    img4, labels4 = random_perspective(img4, labels4,
+    img4, labels4 = random_perspective(img4, labels4, segments4,
                                        degrees=self.hyp['degrees'],
                                        translate=self.hyp['translate'],
                                        scale=self.hyp['scale'],
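Editor note: to make the new clip loop in `load_mosaic` concrete, here is a tiny self-contained example (hypothetical values) showing box coordinates and polygon points being clipped in place to the `2*s x 2*s` mosaic canvas in a single pass.

```python
import numpy as np

s = 640  # img_size, so the mosaic canvas is 2*s x 2*s pixels
labels4 = np.array([[0., -10., 5., 700., 1300.]])                  # (cls, x1, y1, x2, y2) in mosaic pixels
segments4 = [np.array([[-10., 5.], [700., 1300.], [600., 600.]])]  # matching polygon, same coordinate frame

for x in (labels4[:, 1:], *segments4):   # one loop clips the box columns and every polygon in place
    np.clip(x, 0, 2 * s, out=x)

print(labels4)       # [[   0.    0.    5.  700. 1280.]]
print(segments4[0])  # [[   0.    5.] [ 700. 1280.] [ 600.  600.]]
```

`labels4[:, 1:]` is a view into `labels4`, so the in-place `np.clip(..., out=x)` updates the original label array as well as each polygon.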
@@ -706,7 +718,7 @@ def load_mosaic(self, index):
 
 def load_mosaic9(self, index):
     # loads images in a 9-mosaic
-    labels9 = []
+    labels9, segments9 = [], []
     s = self.img_size
     indices = [index] + [self.indices[random.randint(0, self.n - 1)] for _ in range(8)]  # 8 additional image indices
     for i, index in enumerate(indices):
@@ -739,30 +751,34 @@ def load_mosaic9(self, index):
         x1, y1, x2, y2 = [max(x, 0) for x in c]  # allocate coords
 
         # Labels
-        labels = self.labels[index].copy()
+        labels, segments = self.labels[index].copy(), self.segments[index].copy()
         if labels.size:
             labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
+            segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
         labels9.append(labels)
+        segments9.extend(segments)
 
         # Image
         img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
         hp, wp = h, w  # height, width previous
 
     # Offset
-    yc, xc = [int(random.uniform(0, s)) for x in self.mosaic_border]  # mosaic center x, y
+    yc, xc = [int(random.uniform(0, s)) for _ in self.mosaic_border]  # mosaic center x, y
    img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]
 
     # Concat/clip labels
-    if len(labels9):
-        labels9 = np.concatenate(labels9, 0)
-        labels9[:, [1, 3]] -= xc
-        labels9[:, [2, 4]] -= yc
-        np.clip(labels9[:, 1:], 0, 2 * s, out=labels9[:, 1:])  # use with random_perspective
+    labels9 = np.concatenate(labels9, 0)
+    labels9[:, [1, 3]] -= xc
+    labels9[:, [2, 4]] -= yc
+    c = np.array([xc, yc])  # centers
+    segments9 = [x - c for x in segments9]
+
+    for x in (labels9[:, 1:], *segments9):
+        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
     # img9, labels9 = replicate(img9, labels9)  # replicate
 
     # Augment
-    img9, labels9 = random_perspective(img9, labels9,
+    img9, labels9 = random_perspective(img9, labels9, segments9,
                                        degrees=self.hyp['degrees'],
                                        translate=self.hyp['translate'],
                                        scale=self.hyp['scale'],
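Editor note: similarly, a short sketch (hypothetical numbers) of the new centring step in `load_mosaic9`: the box x/y columns and every polygon point are shifted by the same crop offset `(xc, yc)` before clipping.

```python
import numpy as np

xc, yc = 320, 200                                    # hypothetical crop offset of the 2s x 2s window
labels9 = np.array([[0., 400., 300., 900., 700.]])   # (cls, x1, y1, x2, y2) in mosaic pixels
segments9 = [np.array([[400., 300.], [900., 700.]])]

labels9[:, [1, 3]] -= xc                 # shift box x coordinates
labels9[:, [2, 4]] -= yc                 # shift box y coordinates
c = np.array([xc, yc])                   # centers
segments9 = [x - c for x in segments9]   # same shift applied to every polygon point

print(labels9)       # [[  0.  80. 100. 580. 500.]]
print(segments9[0])  # [[ 80. 100.] [580. 500.]]
```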
@@ -823,7 +839,8 @@ def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale
     return img, ratio, (dw, dh)
 
 
-def random_perspective(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0, border=(0, 0)):
+def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0,
+                       border=(0, 0)):
     # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
     # targets = [cls, xyxy]
@@ -875,37 +892,38 @@ def random_perspective(img, targets=(), degrees=10, translate=.1, scale=.1, shea
     # Transform label coordinates
     n = len(targets)
     if n:
-        # warp points
-        xy = np.ones((n * 4, 3))
-        xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
-        xy = xy @ M.T  # transform
-        if perspective:
-            xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)  # rescale
-        else:  # affine
-            xy = xy[:, :2].reshape(n, 8)
-
-        # create new boxes
-        x = xy[:, [0, 2, 4, 6]]
-        y = xy[:, [1, 3, 5, 7]]
-        xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
-
-        # # apply angle-based reduction of bounding boxes
-        # radians = a * math.pi / 180
-        # reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
-        # x = (xy[:, 2] + xy[:, 0]) / 2
-        # y = (xy[:, 3] + xy[:, 1]) / 2
-        # w = (xy[:, 2] - xy[:, 0]) * reduction
-        # h = (xy[:, 3] - xy[:, 1]) * reduction
-        # xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T
-
-        # clip boxes
-        xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
-        xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
+        use_segments = any(x.any() for x in segments)
+        new = np.zeros((n, 4))
+        if use_segments:  # warp segments
+            segments = resample_segments(segments)  # upsample
+            for i, segment in enumerate(segments):
+                xy = np.ones((len(segment), 3))
+                xy[:, :2] = segment
+                xy = xy @ M.T  # transform
+                xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]  # perspective rescale or affine
+
+                # clip
+                new[i] = segment2box(xy, width, height)
+
+        else:  # warp boxes
+            xy = np.ones((n * 4, 3))
+            xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
+            xy = xy @ M.T  # transform
+            xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine
+
+            # create new boxes
+            x = xy[:, [0, 2, 4, 6]]
+            y = xy[:, [1, 3, 5, 7]]
+            new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
+
+        # clip
+        new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
+        new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)
 
         # filter candidates
-        i = box_candidates(box1=targets[:, 1:5].T * s, box2=xy.T)
+        i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
         targets = targets[i]
-        targets[:, 1:5] = xy[i]
+        targets[:, 1:5] = new[i]
 
     return img, targets
...
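Editor note: the heart of the `random_perspective` change is warping each polygon through the same 3x3 matrix used for the image, then reducing the warped polygon to a box. Below is a standalone sketch with a hypothetical translation-only matrix; the in-repo code first densifies each polygon with `resample_segments()` and then calls `segment2box()`, whose in-image min/max logic is inlined here.

```python
import numpy as np

width, height = 640, 640
perspective = False
M = np.array([[1., 0., 25.],   # hypothetical transform: translate +25 px in x, +10 px in y
              [0., 1., 10.],
              [0., 0., 1.]])

segment = np.array([[100., 100.], [200., 120.], [180., 230.], [90., 210.]])  # one polygon, pixel xy

xy = np.ones((len(segment), 3))                            # homogeneous coordinates
xy[:, :2] = segment
xy = xy @ M.T                                              # transform
xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]  # perspective rescale or affine

# segment2box() idea: keep only points that land inside the image, then take their min/max box
x, y = xy.T
inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
x, y = x[inside], y[inside]
box = np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros(4)
print(box)  # [115. 110. 225. 240.]
```

Upsampling to 1000 points first means the min/max box follows the visible part of the warped shape rather than only its original vertices, which is tighter than clipping the four box corners.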
@@ -225,7 +225,7 @@ def xywh2xyxy(x):
     return y
 
 
-def xywhn2xyxy(x, w=640, h=640, padw=32, padh=32):
+def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
     # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # top left x
@@ -235,6 +235,40 @@ def xywhn2xyxy(x, w=640, h=640, padw=32, padh=32):
     return y
 
 
+def xyn2xy(x, w=640, h=640, padw=0, padh=0):
+    # Convert normalized segments into pixel segments, shape (n,2)
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[:, 0] = w * x[:, 0] + padw  # top left x
+    y[:, 1] = h * x[:, 1] + padh  # top left y
+    return y
+
+
+def segment2box(segment, width=640, height=640):
+    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
+    x, y = segment.T  # segment xy
+    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
+    x, y, = x[inside], y[inside]
+    return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4))  # cls, xyxy
+
+
+def segments2boxes(segments):
+    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
+    boxes = []
+    for s in segments:
+        x, y = s.T  # segment xy
+        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
+    return xyxy2xywh(np.array(boxes))  # cls, xywh
+
+
+def resample_segments(segments, n=1000):
+    # Up-sample an (n,2) segment
+    for i, s in enumerate(segments):
+        x = np.linspace(0, len(s) - 1, n)
+        xp = np.arange(len(s))
+        segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T  # segment xy
+    return segments
+
+
 def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
     # Rescale coords (xyxy) from img1_shape to img0_shape
     if ratio_pad is None:  # calculate from img0_shape
...
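Editor note: among the new helpers above, `resample_segments()` is the least obvious: it linearly interpolates each polygon over its vertex index so every segment ends up with a fixed number of points (n=1000 by default). A self-contained sketch of the same logic on a small square, using a local `_resample` stand-in and n=7 so the output stays readable:

```python
import numpy as np

def _resample(points, n=7):
    # same logic as resample_segments(): interpolate x and y over the vertex index
    x = np.linspace(0, len(points) - 1, n)
    xp = np.arange(len(points))
    return np.concatenate([np.interp(x, xp, points[:, i]) for i in range(2)]).reshape(2, -1).T

square = np.array([[0., 0.], [4., 0.], [4., 4.], [0., 4.]])
print(_resample(square))
# rows: [0,0], [2,0], [4,0], [4,2], [4,4], [2,4], [0,4]
```

Note that the polygon is treated as an open polyline here: the last vertex is not re-connected to the first before interpolation.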
@@ -105,7 +105,7 @@ class ComputeLoss:
             BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
 
         det = model.module.model[-1] if is_parallel(model) else model.model[-1]  # Detect() module
-        self.balance = {3: [3.67, 1.0, 0.43], 4: [4.0, 1.0, 0.25, 0.06], 5: [4.0, 1.0, 0.25, 0.06, .02]}[det.nl]
+        self.balance = {3: [4.0, 1.0, 0.4], 4: [4.0, 1.0, 0.25, 0.06], 5: [4.0, 1.0, 0.25, 0.06, .02]}[det.nl]
         self.ssi = (det.stride == 16).nonzero(as_tuple=False).item()  # stride 16 index
         self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
         for k in 'na', 'nc', 'nl', 'anchors':
...