Unverified 提交 33202b7f authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

YOLOv5 + Albumentations integration (#3882)

* Albumentations integration * ToGray p=0.01 * print confirmation * create instance in dataloader init method * improved version handling * transform not defined fix * assert string update * create check_version() * add spaces * update class comment
上级 6a3ee7cf
...@@ -27,4 +27,5 @@ pandas ...@@ -27,4 +27,5 @@ pandas
# extras -------------------------------------- # extras --------------------------------------
# Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172 # Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172
# pycocotools>=2.0 # COCO mAP # pycocotools>=2.0 # COCO mAP
# albumentations>=1.0.0
thop # FLOPs computation thop # FLOPs computation
# YOLOv5 image augmentation functions # YOLOv5 image augmentation functions
import logging
import random import random
import cv2 import cv2
import math import math
import numpy as np import numpy as np
from utils.general import segment2box, resample_segments from utils.general import colorstr, segment2box, resample_segments, check_version
from utils.metrics import bbox_ioa from utils.metrics import bbox_ioa
class Albumentations:
# YOLOv5 Albumentations class (optional, only used if package is installed)
def __init__(self):
self.transform = None
try:
import albumentations as A
check_version(A.__version__, '1.0.0') # version requirement
self.transform = A.Compose([
A.Blur(p=0.1),
A.MedianBlur(p=0.1),
A.ToGray(p=0.01)],
bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
logging.info(colorstr('albumentations: ') + ', '.join(f'{x}' for x in self.transform.transforms))
except ImportError: # package not installed, skip
pass
except Exception as e:
logging.info(colorstr('albumentations: ') + f'{e}')
def __call__(self, im, labels, p=1.0):
if self.transform and random.random() < p:
new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0]) # transformed
im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])])
return im, labels
def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5): def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
# HSV color-space augmentation # HSV color-space augmentation
if hgain or sgain or vgain: if hgain or sgain or vgain:
......
...@@ -22,7 +22,7 @@ from PIL import Image, ExifTags ...@@ -22,7 +22,7 @@ from PIL import Image, ExifTags
from torch.utils.data import Dataset from torch.utils.data import Dataset
from tqdm import tqdm from tqdm import tqdm
from utils.augmentations import augment_hsv, copy_paste, letterbox, mixup, random_perspective from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \ from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
xyn2xy, segments2boxes, clean_str xyn2xy, segments2boxes, clean_str
from utils.torch_utils import torch_distributed_zero_first from utils.torch_utils import torch_distributed_zero_first
...@@ -372,6 +372,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing ...@@ -372,6 +372,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
self.mosaic_border = [-img_size // 2, -img_size // 2] self.mosaic_border = [-img_size // 2, -img_size // 2]
self.stride = stride self.stride = stride
self.path = path self.path = path
self.albumentations = Albumentations() if augment else None
try: try:
f = [] # image files f = [] # image files
...@@ -540,8 +541,6 @@ class LoadImagesAndLabels(Dataset): # for training/testing ...@@ -540,8 +541,6 @@ class LoadImagesAndLabels(Dataset): # for training/testing
labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1]) labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
if self.augment: if self.augment:
# Augment imagespace
if not mosaic:
img, labels = random_perspective(img, labels, img, labels = random_perspective(img, labels,
degrees=hyp['degrees'], degrees=hyp['degrees'],
translate=hyp['translate'], translate=hyp['translate'],
...@@ -549,32 +548,35 @@ class LoadImagesAndLabels(Dataset): # for training/testing ...@@ -549,32 +548,35 @@ class LoadImagesAndLabels(Dataset): # for training/testing
shear=hyp['shear'], shear=hyp['shear'],
perspective=hyp['perspective']) perspective=hyp['perspective'])
# Augment colorspace nl = len(labels) # number of labels
augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v']) if nl:
# Apply cutouts
# if random.random() < 0.9:
# labels = cutout(img, labels)
nL = len(labels) # number of labels
if nL:
labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0]) # xyxy to xywh normalized labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0]) # xyxy to xywh normalized
if self.augment: if self.augment:
# flip up-down # Albumentations
img, labels = self.albumentations(img, labels)
# HSV color-space
augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
# Flip up-down
if random.random() < hyp['flipud']: if random.random() < hyp['flipud']:
img = np.flipud(img) img = np.flipud(img)
if nL: if nl:
labels[:, 2] = 1 - labels[:, 2] labels[:, 2] = 1 - labels[:, 2]
# flip left-right # Flip left-right
if random.random() < hyp['fliplr']: if random.random() < hyp['fliplr']:
img = np.fliplr(img) img = np.fliplr(img)
if nL: if nl:
labels[:, 1] = 1 - labels[:, 1] labels[:, 1] = 1 - labels[:, 1]
labels_out = torch.zeros((nL, 6)) # Cutouts
if nL: # if random.random() < 0.9:
# labels = cutout(img, labels)
labels_out = torch.zeros((nl, 6))
if nl:
labels_out[:, 1:] = torch.from_numpy(labels) labels_out[:, 1:] = torch.from_numpy(labels)
# Convert # Convert
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import contextlib import contextlib
import glob import glob
import logging import logging
import math
import os import os
import platform import platform
import random import random
...@@ -17,6 +16,7 @@ from pathlib import Path ...@@ -17,6 +16,7 @@ from pathlib import Path
from subprocess import check_output from subprocess import check_output
import cv2 import cv2
import math
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import pkg_resources as pkg import pkg_resources as pkg
...@@ -136,13 +136,16 @@ def check_git_status(err_msg=', for updates see https://github.com/ultralytics/y ...@@ -136,13 +136,16 @@ def check_git_status(err_msg=', for updates see https://github.com/ultralytics/y
print(f'{e}{err_msg}') print(f'{e}{err_msg}')
def check_python(minimum='3.6.2', required=True): def check_python(minimum='3.6.2'):
# Check current python version vs. required python version # Check current python version vs. required python version
current = platform.python_version() check_version(platform.python_version(), minimum, name='Python ')
result = pkg.parse_version(current) >= pkg.parse_version(minimum)
if required:
assert result, f'Python {minimum} required by YOLOv5, but Python {current} is currently installed' def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False):
return result # Check version vs. required version
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
result = (current == minimum) if pinned else (current >= minimum)
assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
def check_requirements(requirements='requirements.txt', exclude=()): def check_requirements(requirements='requirements.txt', exclude=()):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论