Unverified 提交 e3ff7806 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Allow PyTorch Hub results to display in notebooks (#9825)

* Allow PyTorch Hub results to display in notebooks * fix CI * fix CI * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI * fix CI * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI * fix CI * fix CI * fix CI * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI * fix CI * fix CI Co-authored-by: 's avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
上级 e42c89d4
...@@ -91,7 +91,7 @@ def run( ...@@ -91,7 +91,7 @@ def run(
# Dataloader # Dataloader
bs = 1 # batch_size bs = 1 # batch_size
if webcam: if webcam:
view_img = check_imshow() view_img = check_imshow(warn=True)
dataset = LoadStreams(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride) dataset = LoadStreams(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride)
bs = len(dataset) bs = len(dataset)
elif screenshot: elif screenshot:
......
...@@ -99,7 +99,7 @@ def run( ...@@ -99,7 +99,7 @@ def run(
# Dataloader # Dataloader
bs = 1 # batch_size bs = 1 # batch_size
if webcam: if webcam:
view_img = check_imshow() view_img = check_imshow(warn=True)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = len(dataset) bs = len(dataset)
elif screenshot: elif screenshot:
......
...@@ -18,16 +18,20 @@ import pandas as pd ...@@ -18,16 +18,20 @@ import pandas as pd
import requests import requests
import torch import torch
import torch.nn as nn import torch.nn as nn
from IPython.display import display
from PIL import Image from PIL import Image
from torch.cuda import amp from torch.cuda import amp
from utils import TryExcept
from utils.dataloaders import exif_transpose, letterbox from utils.dataloaders import exif_transpose, letterbox
from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr, from utils.general import (LOGGER, ROOT, Profile, check_imshow, check_requirements, check_suffix, check_version,
increment_path, make_divisible, non_max_suppression, scale_boxes, xywh2xyxy, xyxy2xywh, colorstr, increment_path, make_divisible, non_max_suppression, scale_boxes, xywh2xyxy,
yaml_load) xyxy2xywh, yaml_load)
from utils.plots import Annotator, colors, save_one_box from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import copy_attr, smart_inference_mode from utils.torch_utils import copy_attr, smart_inference_mode
CHECK_IMSHOW = check_imshow()
def autopad(k, p=None, d=1): # kernel, padding, dilation def autopad(k, p=None, d=1): # kernel, padding, dilation
# Pad to 'same' shape outputs # Pad to 'same' shape outputs
...@@ -756,7 +760,7 @@ class Detections: ...@@ -756,7 +760,7 @@ class Detections:
im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
if show: if show:
im.show(self.files[i]) # show im.show(self.files[i]) if CHECK_IMSHOW else display(im)
if save: if save:
f = self.files[i] f = self.files[i]
im.save(save_dir / f) # save im.save(save_dir / f) # save
...@@ -772,6 +776,7 @@ class Detections: ...@@ -772,6 +776,7 @@ class Detections:
LOGGER.info(f'Saved results to {save_dir}\n') LOGGER.info(f'Saved results to {save_dir}\n')
return crops return crops
@TryExcept('Showing images is not supported in this environment')
def show(self, labels=True): def show(self, labels=True):
self._run(show=True, labels=labels) # show results self._run(show=True, labels=labels) # show results
......
...@@ -102,7 +102,7 @@ def run( ...@@ -102,7 +102,7 @@ def run(
# Dataloader # Dataloader
bs = 1 # batch_size bs = 1 # batch_size
if webcam: if webcam:
view_img = check_imshow() view_img = check_imshow(warn=True)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = len(dataset) bs = len(dataset)
elif screenshot: elif screenshot:
......
...@@ -23,7 +23,7 @@ class TryExcept(contextlib.ContextDecorator): ...@@ -23,7 +23,7 @@ class TryExcept(contextlib.ContextDecorator):
def __exit__(self, exc_type, value, traceback): def __exit__(self, exc_type, value, traceback):
if value: if value:
print(emojis(f'{self.msg}{value}')) print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))
return True return True
......
...@@ -26,7 +26,7 @@ def check_anchor_order(m): ...@@ -26,7 +26,7 @@ def check_anchor_order(m):
m.anchors[:] = m.anchors.flip(0) m.anchors[:] = m.anchors.flip(0)
@TryExcept(f'{PREFIX}ERROR: ') @TryExcept(f'{PREFIX}ERROR')
def check_anchors(dataset, model, thr=4.0, imgsz=640): def check_anchors(dataset, model, thr=4.0, imgsz=640):
# Check anchor fit to data, recompute if necessary # Check anchor fit to data, recompute if necessary
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect() m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
......
...@@ -27,6 +27,7 @@ from typing import Optional ...@@ -27,6 +27,7 @@ from typing import Optional
from zipfile import ZipFile from zipfile import ZipFile
import cv2 import cv2
import IPython
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import pkg_resources as pkg import pkg_resources as pkg
...@@ -73,6 +74,12 @@ def is_colab(): ...@@ -73,6 +74,12 @@ def is_colab():
return 'COLAB_GPU' in os.environ return 'COLAB_GPU' in os.environ
def is_notebook():
# Is environment a Jupyter notebook? Verified on Colab, Jupyterlab, Kaggle, Paperspace
ipython_type = str(type(IPython.get_ipython()))
return 'colab' in ipython_type or 'zmqshell' in ipython_type
def is_kaggle(): def is_kaggle():
# Is environment a Kaggle Notebook? # Is environment a Kaggle Notebook?
return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com' return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
...@@ -383,18 +390,20 @@ def check_img_size(imgsz, s=32, floor=0): ...@@ -383,18 +390,20 @@ def check_img_size(imgsz, s=32, floor=0):
return new_size return new_size
def check_imshow(): def check_imshow(warn=False):
# Check if environment supports image displays # Check if environment supports image displays
try: try:
assert not is_docker(), 'cv2.imshow() is disabled in Docker environments' assert not is_notebook()
assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments' assert not is_docker()
assert 'NoneType' not in str(type(IPython.get_ipython())) # SSH terminals, GitHub CI
cv2.imshow('test', np.zeros((1, 1, 3))) cv2.imshow('test', np.zeros((1, 1, 3)))
cv2.waitKey(1) cv2.waitKey(1)
cv2.destroyAllWindows() cv2.destroyAllWindows()
cv2.waitKey(1) cv2.waitKey(1)
return True return True
except Exception as e: except Exception as e:
LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}') if warn:
LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
return False return False
......
...@@ -186,7 +186,7 @@ class ConfusionMatrix: ...@@ -186,7 +186,7 @@ class ConfusionMatrix:
# fn = self.matrix.sum(0) - tp # false negatives (missed detections) # fn = self.matrix.sum(0) - tp # false negatives (missed detections)
return tp[:-1], fp[:-1] # remove background class return tp[:-1], fp[:-1] # remove background class
@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure: ') @TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
def plot(self, normalize=True, save_dir='', names=()): def plot(self, normalize=True, save_dir='', names=()):
import seaborn as sn import seaborn as sn
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论