Unverified 提交 2787ad70 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Add segment line predictions (#9571)

* Add segment line predictions Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: 's avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
上级 ee91dc9b
...@@ -42,9 +42,10 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative ...@@ -42,9 +42,10 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
from models.common import DetectMultiBackend from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2, from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh) increment_path, non_max_suppression, print_args, scale_boxes, scale_segments,
strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box from utils.plots import Annotator, colors, save_one_box
from utils.segment.general import process_mask from utils.segment.general import masks2segments, process_mask
from utils.torch_utils import select_device, smart_inference_mode from utils.torch_utils import select_device, smart_inference_mode
...@@ -145,14 +146,16 @@ def run( ...@@ -145,14 +146,16 @@ def run(
save_path = str(save_dir / p.name) # im.jpg save_path = str(save_dir / p.name) # im.jpg
txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # im.txt txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # im.txt
s += '%gx%g ' % im.shape[2:] # print string s += '%gx%g ' % im.shape[2:] # print string
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
imc = im0.copy() if save_crop else im0 # for save_crop imc = im0.copy() if save_crop else im0 # for save_crop
annotator = Annotator(im0, line_width=line_thickness, example=str(names)) annotator = Annotator(im0, line_width=line_thickness, example=str(names))
if len(det): if len(det):
masks = process_mask(proto[i], det[:, 6:], det[:, :4], im.shape[2:], upsample=True) # HWC masks = process_mask(proto[i], det[:, 6:], det[:, :4], im.shape[2:], upsample=True) # HWC
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round() # rescale boxes to im0 size
# Rescale boxes from img_size to im0 size # Segments
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round() if save_txt:
segments = reversed(masks2segments(masks))
segments = [scale_segments(im.shape[2:], x, im0.shape).round() for x in segments]
# Print results # Print results
for c in det[:, 5].unique(): for c in det[:, 5].unique():
...@@ -165,10 +168,10 @@ def run( ...@@ -165,10 +168,10 @@ def run(
im_gpu=None if retina_masks else im[i]) im_gpu=None if retina_masks else im[i])
# Write results # Write results
for *xyxy, conf, cls in reversed(det[:, :6]): for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])):
if save_txt: # Write to file if save_txt: # Write to file
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh segj = segments[j].reshape(-1) # (n,2) to (n*2)
line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format line = (cls, *segj, conf) if save_conf else (cls, *segj) # label format
with open(f'{txt_path}.txt', 'a') as f: with open(f'{txt_path}.txt', 'a') as f:
f.write(('%g ' * len(line)).rstrip() % line + '\n') f.write(('%g ' * len(line)).rstrip() % line + '\n')
...@@ -176,6 +179,7 @@ def run( ...@@ -176,6 +179,7 @@ def run(
c = int(cls) # integer class c = int(cls) # integer class
label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}') label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
annotator.box_label(xyxy, label, color=colors(c, True)) annotator.box_label(xyxy, label, color=colors(c, True))
annotator.draw.polygon(segments[j], outline=colors(c, True), width=3)
if save_crop: if save_crop:
save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True) save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
......
import cv2 import cv2
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -118,3 +119,16 @@ def masks_iou(mask1, mask2, eps=1e-7): ...@@ -118,3 +119,16 @@ def masks_iou(mask1, mask2, eps=1e-7):
intersection = (mask1 * mask2).sum(1).clamp(0) # (N, ) intersection = (mask1 * mask2).sum(1).clamp(0) # (N, )
union = (mask1.sum(1) + mask2.sum(1))[None] - intersection # (area1 + area2) - intersection union = (mask1.sum(1) + mask2.sum(1))[None] - intersection # (area1 + area2) - intersection
return intersection / (union + eps) return intersection / (union + eps)
def masks2segments(masks, strategy='largest'):
# Convert masks(n,160,160) into segments(n,xy)
segments = []
for x in masks.int().numpy().astype('uint8'):
c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
if strategy == 'concat': # concatenate all segments
c = np.concatenate([x.reshape(-1, 2) for x in c])
elif strategy == 'largest': # select largest segment
c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
segments.append(c.astype('float32'))
return segments
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论