Unverified 提交 6f028476 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Update albumentations (#9503)

* Add `RandomResizedCrop(ratio)` * Update ratio * Update ratio * Update ratio * Update ratio * Update ratio * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Create augmentations.py Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update augmentations.py Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: 's avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
上级 db684743
...@@ -21,7 +21,7 @@ IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation ...@@ -21,7 +21,7 @@ IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
class Albumentations: class Albumentations:
# YOLOv5 Albumentations class (optional, only used if package is installed) # YOLOv5 Albumentations class (optional, only used if package is installed)
def __init__(self): def __init__(self, size=640):
self.transform = None self.transform = None
prefix = colorstr('albumentations: ') prefix = colorstr('albumentations: ')
try: try:
...@@ -29,6 +29,7 @@ class Albumentations: ...@@ -29,6 +29,7 @@ class Albumentations:
check_version(A.__version__, '1.0.3', hard=True) # version requirement check_version(A.__version__, '1.0.3', hard=True) # version requirement
T = [ T = [
A.RandomResizedCrop(height=size, width=size, scale=(0.8, 1.0), ratio=(0.9, 1.11), p=0.0),
A.Blur(p=0.01), A.Blur(p=0.01),
A.MedianBlur(p=0.01), A.MedianBlur(p=0.01),
A.ToGray(p=0.01), A.ToGray(p=0.01),
...@@ -303,15 +304,17 @@ def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): ...@@ -303,15 +304,17 @@ def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):
return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
def classify_albumentations(augment=True, def classify_albumentations(
size=224, augment=True,
scale=(0.08, 1.0), size=224,
hflip=0.5, scale=(0.08, 1.0),
vflip=0.0, ratio=(0.75, 1.0 / 0.75), # 0.75, 1.33
jitter=0.4, hflip=0.5,
mean=IMAGENET_MEAN, vflip=0.0,
std=IMAGENET_STD, jitter=0.4,
auto_aug=False): mean=IMAGENET_MEAN,
std=IMAGENET_STD,
auto_aug=False):
# YOLOv5 classification Albumentations (optional, only used if package is installed) # YOLOv5 classification Albumentations (optional, only used if package is installed)
prefix = colorstr('albumentations: ') prefix = colorstr('albumentations: ')
try: try:
...@@ -319,7 +322,7 @@ def classify_albumentations(augment=True, ...@@ -319,7 +322,7 @@ def classify_albumentations(augment=True,
from albumentations.pytorch import ToTensorV2 from albumentations.pytorch import ToTensorV2
check_version(A.__version__, '1.0.3', hard=True) # version requirement check_version(A.__version__, '1.0.3', hard=True) # version requirement
if augment: # Resize and crop if augment: # Resize and crop
T = [A.RandomResizedCrop(height=size, width=size, scale=scale)] T = [A.RandomResizedCrop(height=size, width=size, scale=scale, ratio=ratio)]
if auto_aug: if auto_aug:
# TODO: implement AugMix, AutoAug & RandAug in albumentation # TODO: implement AugMix, AutoAug & RandAug in albumentation
LOGGER.info(f'{prefix}auto augmentations are currently not supported') LOGGER.info(f'{prefix}auto augmentations are currently not supported')
...@@ -338,7 +341,7 @@ def classify_albumentations(augment=True, ...@@ -338,7 +341,7 @@ def classify_albumentations(augment=True,
return A.Compose(T) return A.Compose(T)
except ImportError: # package not installed, skip except ImportError: # package not installed, skip
pass LOGGER.warning(f'{prefix}⚠️ not found, install with `pip install albumentations` (recommended)')
except Exception as e: except Exception as e:
LOGGER.info(f'{prefix}{e}') LOGGER.info(f'{prefix}{e}')
......
...@@ -404,7 +404,7 @@ class LoadImagesAndLabels(Dataset): ...@@ -404,7 +404,7 @@ class LoadImagesAndLabels(Dataset):
self.mosaic_border = [-img_size // 2, -img_size // 2] self.mosaic_border = [-img_size // 2, -img_size // 2]
self.stride = stride self.stride = stride
self.path = path self.path = path
self.albumentations = Albumentations() if augment else None self.albumentations = Albumentations(size=img_size) if augment else None
try: try:
f = [] # image files f = [] # image files
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论