Unverified 提交 c6c88dc6 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Copy-Paste augmentation for YOLOv5 (#3845)

* Copy-paste augmentation initial commit
* if any segments
* Add obscuration rejection
* Add copy_paste hyperparameter
* Update comments
上级 25d1f293
...@@ -36,3 +36,4 @@ flipud: 0.00856 ...@@ -36,3 +36,4 @@ flipud: 0.00856
fliplr: 0.5 fliplr: 0.5
mosaic: 1.0 mosaic: 1.0
mixup: 0.243 mixup: 0.243
copy_paste: 0.0
...@@ -26,3 +26,4 @@ flipud: 0.0 ...@@ -26,3 +26,4 @@ flipud: 0.0
fliplr: 0.5 fliplr: 0.5
mosaic: 1.0 mosaic: 1.0
mixup: 0.0 mixup: 0.0
copy_paste: 0.0
...@@ -31,3 +31,4 @@ flipud: 0.0 # image flip up-down (probability) ...@@ -31,3 +31,4 @@ flipud: 0.0 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability) fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability) mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability) mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability)
...@@ -31,3 +31,4 @@ flipud: 0.0 # image flip up-down (probability) ...@@ -31,3 +31,4 @@ flipud: 0.0 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability) fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability) mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability) mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability)
...@@ -6,7 +6,6 @@ Usage: ...@@ -6,7 +6,6 @@ Usage:
import argparse import argparse
import logging import logging
import math
import os import os
import random import random
import sys import sys
...@@ -16,6 +15,7 @@ from copy import deepcopy ...@@ -16,6 +15,7 @@ from copy import deepcopy
from pathlib import Path from pathlib import Path
from threading import Thread from threading import Thread
import math
import numpy as np import numpy as np
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
...@@ -591,7 +591,8 @@ def main(opt): ...@@ -591,7 +591,8 @@ def main(opt):
'flipud': (1, 0.0, 1.0), # image flip up-down (probability) 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
'fliplr': (0, 0.0, 1.0), # image flip left-right (probability) 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
'mosaic': (1, 0.0, 1.0), # image mosaic (probability) 'mosaic': (1, 0.0, 1.0), # image mosaic (probability)
'mixup': (1, 0.0, 1.0)} # image mixup (probability) 'mixup': (1, 0.0, 1.0), # image mixup (probability)
'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
with open(opt.hyp) as f: with open(opt.hyp) as f:
hyp = yaml.safe_load(f) # load hyps dict hyp = yaml.safe_load(f) # load hyps dict
......
...@@ -25,6 +25,7 @@ from tqdm import tqdm ...@@ -25,6 +25,7 @@ from tqdm import tqdm
from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \ from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
xyn2xy, segment2box, segments2boxes, resample_segments, clean_str xyn2xy, segment2box, segments2boxes, resample_segments, clean_str
from utils.metrics import bbox_ioa
from utils.torch_utils import torch_distributed_zero_first from utils.torch_utils import torch_distributed_zero_first
# Parameters # Parameters
...@@ -683,6 +684,7 @@ def load_mosaic(self, index): ...@@ -683,6 +684,7 @@ def load_mosaic(self, index):
# img4, labels4 = replicate(img4, labels4) # replicate # img4, labels4 = replicate(img4, labels4) # replicate
# Augment # Augment
img4, labels4, segments4 = copy_paste(img4, labels4, segments4, probability=self.hyp['copy_paste'])
img4, labels4 = random_perspective(img4, labels4, segments4, img4, labels4 = random_perspective(img4, labels4, segments4,
degrees=self.hyp['degrees'], degrees=self.hyp['degrees'],
translate=self.hyp['translate'], translate=self.hyp['translate'],
...@@ -907,6 +909,30 @@ def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, s ...@@ -907,6 +909,30 @@ def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, s
return img, targets return img, targets
def copy_paste(img, labels, segments, probability=0.5):
    # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
    #
    # A random subset of the segments is mirrored left-right and pasted back into the same image.
    #   img:         image array, modified in place
    #   labels:      nx5 np.array (cls, x1, y1, x2, y2); a new array with the pasted boxes appended is returned
    #   segments:    list of kx2 arrays (x in column 0, y in column 1); pasted segments are appended in place
    #   probability: fraction of the n segments to attempt to paste; 0 disables the augmentation
    # Returns (img, labels, segments).
    n = len(segments)
    if probability and n:
        w = img.shape[1]  # image width, used to mirror x-coordinates
        im_new = np.zeros(img.shape, np.uint8)  # mask of the source objects accepted for pasting

        # Clamp the sample size to [0, n]: random.sample raises ValueError outside that range,
        # which would otherwise happen for probability > 1 or probability < 0.
        k = min(n, max(0, round(probability * n)))
        pasted = False
        for j in random.sample(range(n), k=k):
            label, segment = labels[j], segments[j]
            box = w - label[3], label[2], w - label[1], label[4]  # left-right mirrored xyxy box
            # Checked against every current label, including boxes pasted earlier in this loop
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            if (ioa < 0.30).all():  # allow 30% obscuration of existing labels
                labels = np.concatenate((labels, [[label[0], *box]]), 0)
                segments.append(np.concatenate((w - segment[:, 0:1], segment[:, 1:2]), 1))
                cv2.drawContours(im_new, [segment.astype(np.int32)], -1, (255, 255, 255), cv2.FILLED)
                pasted = True

        if pasted:  # nothing to copy when every candidate was rejected
            flipped = cv2.flip(img, 1)  # augment segments (flip left-right)
            # Build the paste mask from the drawn contours rather than from pixel values, so that
            # zero-valued (black) pixels inside an object are pasted as well.
            i = cv2.flip(im_new, 1).astype(bool)  # pixels to replace
            img[i] = flipped[i]  # cv2.imwrite('debug.jpg', img)  # debug

    return img, labels, segments
def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n) def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
# Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
w1, h1 = box1[2] - box1[0], box1[3] - box1[1] w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
...@@ -919,24 +945,6 @@ def cutout(image, labels): ...@@ -919,24 +945,6 @@ def cutout(image, labels):
# Applies image cutout augmentation https://arxiv.org/abs/1708.04552 # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
h, w = image.shape[:2] h, w = image.shape[:2]
def bbox_ioa(box1, box2):
# Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
# One value is produced per row of box2.
box2 = box2.transpose()  # nx4 -> 4xn so that each coordinate can be read as its own row
# Get the coordinates of bounding boxes
b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
# Intersection area (overlap width * overlap height; clip(0) zeroes the extent of non-overlapping boxes)
inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
(np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)
# box2 area (1e-16 keeps the division below defined for zero-area boxes)
box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16
# Intersection over box2 area
return inter_area / box2_area
# create random masks # create random masks
scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction
for s in scales: for s in scales:
......
# Model validation metrics # Model validation metrics
import math
import warnings import warnings
from pathlib import Path from pathlib import Path
import math
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
...@@ -253,6 +253,30 @@ def box_iou(box1, box2): ...@@ -253,6 +253,30 @@ def box_iou(box1, box2):
return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter) return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
def bbox_ioa(box1, box2, eps=1E-7):
    """ Returns the intersection over box2 area given box1, box2. Boxes are x1y1x2y2
    box1:       array-like of shape(4), e.g. np.array, list or tuple
    box2:       array-like of shape(nx4)
    eps:        small constant added to the box2 area so zero-area boxes do not divide by zero
    returns:    np.array of shape(n)
    """

    # Accept any array-like (callers pass tuples/lists as well as np.arrays), then go nx4 -> 4xn
    box2 = np.asarray(box2).transpose()

    # Get the coordinates of bounding boxes
    b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
    b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]

    # Intersection area: overlap width * overlap height, each clipped at 0 for disjoint boxes
    inter_w = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0)
    inter_h = (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)
    inter_area = inter_w * inter_h

    # box2 area
    box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps

    # Intersection over box2 area
    return inter_area / box2_area
def wh_iou(wh1, wh2): def wh_iou(wh1, wh2):
# Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2 # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
wh1 = wh1[:, None] # [N,1,2] wh1 = wh1[:, None] # [N,1,2]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论