Unverified 提交 8ae81a6c authored 作者: Junjie Zhang's avatar Junjie Zhang 提交者: GitHub

Fix cutout bug (#9452)

* fix cutout bug Signed-off-by: 's avatarJunjie Zhang <46258221+Oswells@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ciSigned-off-by: 's avatarJunjie Zhang <46258221+Oswells@users.noreply.github.com> Co-authored-by: 's avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 e8a9c5ae
...@@ -12,7 +12,7 @@ import torch ...@@ -12,7 +12,7 @@ import torch
import torchvision.transforms as T import torchvision.transforms as T
import torchvision.transforms.functional as TF import torchvision.transforms.functional as TF
from utils.general import LOGGER, check_version, colorstr, resample_segments, segment2box from utils.general import LOGGER, check_version, colorstr, resample_segments, segment2box, xywhn2xyxy
from utils.metrics import bbox_ioa from utils.metrics import bbox_ioa
IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
...@@ -281,7 +281,7 @@ def cutout(im, labels, p=0.5): ...@@ -281,7 +281,7 @@ def cutout(im, labels, p=0.5):
# return unobscured labels # return unobscured labels
if len(labels) and s > 0.03: if len(labels) and s > 0.03:
box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32) box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
ioa = bbox_ioa(box, labels[:, 1:5]) # intersection over area ioa = bbox_ioa(box, xywhn2xyxy(labels[:, 1:5], w, h)) # intersection over area
labels = labels[ioa < 0.60] # remove >60% obscured labels labels = labels[ioa < 0.60] # remove >60% obscured labels
return labels return labels
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论