Unverified commit b412696f, authored by Laughing, committed by GitHub

Fix & speed up segment plot (#10350)

* fix plot && speed up
* fix segment save-txt
* fix channel
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Parent 7f5724ba
diff --git a/segment/predict.py b/segment/predict.py
@@ -46,7 +46,7 @@ from utils.general import (LOGGER, Profile, check_file, check_img_size, check_im
                            increment_path, non_max_suppression, print_args, scale_boxes, scale_segments,
                            strip_optimizer, xyxy2xywh)
 from utils.plots import Annotator, colors, save_one_box
-from utils.segment.general import masks2segments, process_mask
+from utils.segment.general import masks2segments, process_mask, process_mask_native
 from utils.torch_utils import select_device, smart_inference_mode
@@ -151,13 +151,20 @@ def run(
             imc = im0.copy() if save_crop else im0  # for save_crop
             annotator = Annotator(im0, line_width=line_thickness, example=str(names))
             if len(det):
-                masks = process_mask(proto[i], det[:, 6:], det[:, :4], im.shape[2:], upsample=True)  # HWC
-                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()  # rescale boxes to im0 size
+                if retina_masks:
+                    # scale bbox first, then crop masks
+                    det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()  # rescale boxes to im0 size
+                    masks = process_mask_native(proto[i], det[:, 6:], det[:, :4], im0.shape[:2])  # HWC
+                else:
+                    masks = process_mask(proto[i], det[:, 6:], det[:, :4], im.shape[2:], upsample=True)  # HWC
+                    det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()  # rescale boxes to im0 size

                 # Segments
                 if save_txt:
                     segments = reversed(masks2segments(masks))
-                    segments = [scale_segments(im.shape[2:], x, im0.shape, normalize=True) for x in segments]
+                    segments = [
+                        scale_segments(im0.shape if retina_masks else im.shape[2:], x, im0.shape, normalize=True)
+                        for x in segments]

                 # Print results
                 for c in det[:, 5].unique():
@@ -165,9 +172,9 @@ def run(
                     s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                 # Mask plotting
-                annotator.masks(masks,
-                                colors=[colors(x, True) for x in det[:, 5]],
-                                im_gpu=None if retina_masks else im[i])
+                plot_img = torch.as_tensor(im0, dtype=torch.float16).to(device).permute(2, 0, 1).flip(0).contiguous() / 255. \
+                    if retina_masks else im[i]
+                annotator.masks(masks, colors=[colors(x, True) for x in det[:, 5]], im_gpu=plot_img)

                 # Write results
                 for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])):
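To make the new branch concrete, here is a hedged, self-contained sketch (dummy tensors and sizes, not part of the commit) of why the box rescaling moves before the mask step on the retina_masks path: process_mask builds masks at model-input scale, while process_mask_native upsamples them straight to the original image size, so its crop boxes must already be in im0 coordinates.

# Illustrative sketch only -- dummy tensors, simplified versions of the two paths.
import torch
import torch.nn.functional as F

n, c, mh, mw = 3, 32, 160, 160     # n detections, c prototypes at mh x mw
proto = torch.randn(c, mh, mw)     # prototype masks for one image
coeffs = torch.randn(n, c)         # per-detection mask coefficients

# Both paths start the same way: linear-combine prototypes, then sigmoid.
masks = (coeffs @ proto.view(c, -1)).sigmoid().view(n, mh, mw)

# Default path (like process_mask): stay at model-input scale (e.g. 640x640),
# so boxes are rescaled to im0 coordinates only after the masks are built.
default = F.interpolate(masks[None], (640, 640), mode='bilinear', align_corners=False)[0]

# retina_masks path (like process_mask_native): upsample straight to the
# original image size, so boxes must already be in im0 coordinates.
native = F.interpolate(masks[None], (1080, 1920), mode='bilinear', align_corners=False)[0]

print(default.shape, native.shape)  # torch.Size([3, 640, 640]) torch.Size([3, 1080, 1920])

The same ordering explains the save-txt fix: segments traced from retina masks are already in im0 scale, so scale_segments must be told im0.shape rather than im.shape[2:].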
diff --git a/utils/plots.py b/utils/plots.py
@@ -114,7 +114,7 @@ class Annotator:
                         thickness=tf,
                         lineType=cv2.LINE_AA)

-    def masks(self, masks, colors, im_gpu=None, alpha=0.5):
+    def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
         """Plot masks at once.
         Args:
             masks (tensor): predicted masks on cuda, shape: [n, h, w]
@@ -125,37 +125,21 @@ class Annotator:
         if self.pil:
             # convert to numpy first
             self.im = np.asarray(self.im).copy()
-        if im_gpu is None:
-            # Add multiple masks of shape(h,w,n) with colors list([r,g,b], [r,g,b], ...)
-            if len(masks) == 0:
-                return
-            if isinstance(masks, torch.Tensor):
-                masks = torch.as_tensor(masks, dtype=torch.uint8)
-                masks = masks.permute(1, 2, 0).contiguous()
-                masks = masks.cpu().numpy()
-            # masks = np.ascontiguousarray(masks.transpose(1, 2, 0))
-            masks = scale_image(masks.shape[:2], masks, self.im.shape)
-            masks = np.asarray(masks, dtype=np.float32)
-            colors = np.asarray(colors, dtype=np.float32)  # shape(n,3)
-            s = masks.sum(2, keepdims=True).clip(0, 1)  # add all masks together
-            masks = (masks @ colors).clip(0, 255)  # (h,w,n) @ (n,3) = (h,w,3)
-            self.im[:] = masks * alpha + self.im * (1 - s * alpha)
-        else:
-            if len(masks) == 0:
-                self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
-            colors = torch.tensor(colors, device=im_gpu.device, dtype=torch.float32) / 255.0
-            colors = colors[:, None, None]  # shape(n,1,1,3)
-            masks = masks.unsqueeze(3)  # shape(n,h,w,1)
-            masks_color = masks * (colors * alpha)  # shape(n,h,w,3)
-            inv_alph_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
-            mcs = (masks_color * inv_alph_masks).sum(0) * 2  # mask color summand, shape(h,w,3)
-            im_gpu = im_gpu.flip(dims=[0])  # flip channel
-            im_gpu = im_gpu.permute(1, 2, 0).contiguous()  # shape(h,w,3)
-            im_gpu = im_gpu * inv_alph_masks[-1] + mcs
-            im_mask = (im_gpu * 255).byte().cpu().numpy()
-            self.im[:] = scale_image(im_gpu.shape, im_mask, self.im.shape)
+        if len(masks) == 0:
+            self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
+        colors = torch.tensor(colors, device=im_gpu.device, dtype=torch.float32) / 255.0
+        colors = colors[:, None, None]  # shape(n,1,1,3)
+        masks = masks.unsqueeze(3)  # shape(n,h,w,1)
+        masks_color = masks * (colors * alpha)  # shape(n,h,w,3)
+        inv_alph_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
+        mcs = (masks_color * inv_alph_masks).sum(0) * 2  # mask color summand, shape(h,w,3)
+        im_gpu = im_gpu.flip(dims=[0])  # flip channel
+        im_gpu = im_gpu.permute(1, 2, 0).contiguous()  # shape(h,w,3)
+        im_gpu = im_gpu * inv_alph_masks[-1] + mcs
+        im_mask = (im_gpu * 255).byte().cpu().numpy()
+        self.im[:] = im_mask if retina_masks else scale_image(im_gpu.shape, im_mask, self.im.shape)
         if self.pil:
             # convert im back to PIL and update draw
             self.fromarray(self.im)
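For reference, a minimal runnable sketch of the batched compositing that masks() now always performs (random data; shapes follow the shape comments above, and the * 2 brightness factor is copied verbatim from the diff):

# Minimal sketch of the batched alpha compositing in Annotator.masks (dummy data).
import torch

n, h, w, alpha = 2, 4, 6, 0.5
masks = (torch.rand(n, h, w) > 0.5).float().unsqueeze(3)  # shape(n,h,w,1), binary
colors = torch.rand(n, 1, 1, 3)                           # shape(n,1,1,3), values 0..1
im_gpu = torch.rand(h, w, 3)                              # background already HWC, 0..1

masks_color = masks * (colors * alpha)                    # each mask as a colored layer
inv_alph_masks = (1 - masks * alpha).cumprod(0)           # running transparency products
mcs = (masks_color * inv_alph_masks).sum(0) * 2           # all layers fused in one sum

out = im_gpu * inv_alph_masks[-1] + mcs                   # single blend, no per-mask loop
print(out.shape)                                          # torch.Size([4, 6, 3])

Since the whole blend stays as a few tensor ops on the im_gpu device, and the old NumPy-on-CPU branch is deleted, this appears to be where the "speed up" in the commit title comes from.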
diff --git a/utils/segment/general.py b/utils/segment/general.py
@@ -67,6 +67,29 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
     return masks.gt_(0.5)


+def process_mask_native(protos, masks_in, bboxes, dst_shape):
+    """
+    Crop after upsample.
+    protos: [mask_dim, mask_h, mask_w]
+    masks_in: [n, mask_dim], n is number of masks after nms
+    bboxes: [n, 4], n is number of masks after nms
+    dst_shape: input image size, (h, w)
+
+    return: binary masks, [n, h, w]
+    """
+    c, mh, mw = protos.shape  # CHW
+    masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
+    gain = min(mh / dst_shape[0], mw / dst_shape[1])  # gain = old / new
+    pad = (mw - dst_shape[1] * gain) / 2, (mh - dst_shape[0] * gain) / 2  # wh padding
+    top, left = int(pad[1]), int(pad[0])  # y, x
+    bottom, right = int(mh - pad[1]), int(mw - pad[0])
+    masks = masks[:, top:bottom, left:right]
+
+    masks = F.interpolate(masks[None], dst_shape, mode='bilinear', align_corners=False)[0]  # CHW
+    masks = crop_mask(masks, bboxes)  # CHW
+    return masks.gt_(0.5)
+
+
 def scale_image(im1_shape, masks, im0_shape, ratio_pad=None):
     """
     img1_shape: model input shape, [h, w]
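As a sanity check, a self-contained walk-through of the letterbox un-padding arithmetic in process_mask_native, using assumed dummy sizes (160x160 prototypes, a 480x640 original image):

# Hedged sketch of the letterbox un-padding in process_mask_native (dummy sizes).
import torch
import torch.nn.functional as F

mh, mw = 160, 160            # prototype mask resolution
dst_shape = (480, 640)       # original image (h, w)

gain = min(mh / dst_shape[0], mw / dst_shape[1])                      # letterbox scale, old/new
pad = (mw - dst_shape[1] * gain) / 2, (mh - dst_shape[0] * gain) / 2  # (w, h) padding
top, left = int(pad[1]), int(pad[0])
bottom, right = int(mh - pad[1]), int(mw - pad[0])

masks = torch.rand(3, mh, mw)                       # n=3 dummy masks
masks = masks[:, top:bottom, left:right]            # drop the letterbox padding
masks = F.interpolate(masks[None], dst_shape, mode='bilinear',
                      align_corners=False)[0]       # upsample to im0 size
print(masks.shape)                                  # torch.Size([3, 480, 640])

With these numbers gain = 0.25, so there is no horizontal padding and 20 px of vertical padding per side: rows 20..139 of the prototype grid survive the crop before the bilinear upsample to the full image.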