Unverified 提交 e3ff7806 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Allow PyTorch Hub results to display in notebooks (#9825)

* Allow PyTorch Hub results to display in notebooks * fix CI * fix CI * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI * fix CI * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI * fix CI * fix CI * fix CI * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI * fix CI * fix CI Co-authored-by: 's avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
上级 e42c89d4
......@@ -91,7 +91,7 @@ def run(
# Dataloader
bs = 1 # batch_size
if webcam:
view_img = check_imshow()
view_img = check_imshow(warn=True)
dataset = LoadStreams(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride)
bs = len(dataset)
elif screenshot:
......
......@@ -99,7 +99,7 @@ def run(
# Dataloader
bs = 1 # batch_size
if webcam:
view_img = check_imshow()
view_img = check_imshow(warn=True)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = len(dataset)
elif screenshot:
......
......@@ -18,16 +18,20 @@ import pandas as pd
import requests
import torch
import torch.nn as nn
from IPython.display import display
from PIL import Image
from torch.cuda import amp
from utils import TryExcept
from utils.dataloaders import exif_transpose, letterbox
from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr,
increment_path, make_divisible, non_max_suppression, scale_boxes, xywh2xyxy, xyxy2xywh,
yaml_load)
from utils.general import (LOGGER, ROOT, Profile, check_imshow, check_requirements, check_suffix, check_version,
colorstr, increment_path, make_divisible, non_max_suppression, scale_boxes, xywh2xyxy,
xyxy2xywh, yaml_load)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import copy_attr, smart_inference_mode
CHECK_IMSHOW = check_imshow()
def autopad(k, p=None, d=1): # kernel, padding, dilation
# Pad to 'same' shape outputs
......@@ -756,7 +760,7 @@ class Detections:
im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
if show:
im.show(self.files[i]) # show
im.show(self.files[i]) if CHECK_IMSHOW else display(im)
if save:
f = self.files[i]
im.save(save_dir / f) # save
......@@ -772,6 +776,7 @@ class Detections:
LOGGER.info(f'Saved results to {save_dir}\n')
return crops
@TryExcept('Showing images is not supported in this environment')
def show(self, labels=True):
self._run(show=True, labels=labels) # show results
......
......@@ -102,7 +102,7 @@ def run(
# Dataloader
bs = 1 # batch_size
if webcam:
view_img = check_imshow()
view_img = check_imshow(warn=True)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = len(dataset)
elif screenshot:
......
......@@ -23,7 +23,7 @@ class TryExcept(contextlib.ContextDecorator):
def __exit__(self, exc_type, value, traceback):
if value:
print(emojis(f'{self.msg}{value}'))
print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))
return True
......
......@@ -26,7 +26,7 @@ def check_anchor_order(m):
m.anchors[:] = m.anchors.flip(0)
@TryExcept(f'{PREFIX}ERROR: ')
@TryExcept(f'{PREFIX}ERROR')
def check_anchors(dataset, model, thr=4.0, imgsz=640):
# Check anchor fit to data, recompute if necessary
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
......
......@@ -27,6 +27,7 @@ from typing import Optional
from zipfile import ZipFile
import cv2
import IPython
import numpy as np
import pandas as pd
import pkg_resources as pkg
......@@ -73,6 +74,12 @@ def is_colab():
return 'COLAB_GPU' in os.environ
def is_notebook():
# Is environment a Jupyter notebook? Verified on Colab, Jupyterlab, Kaggle, Paperspace
ipython_type = str(type(IPython.get_ipython()))
return 'colab' in ipython_type or 'zmqshell' in ipython_type
def is_kaggle():
# Is environment a Kaggle Notebook?
return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
......@@ -383,18 +390,20 @@ def check_img_size(imgsz, s=32, floor=0):
return new_size
def check_imshow():
def check_imshow(warn=False):
# Check if environment supports image displays
try:
assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
assert not is_notebook()
assert not is_docker()
assert 'NoneType' not in str(type(IPython.get_ipython())) # SSH terminals, GitHub CI
cv2.imshow('test', np.zeros((1, 1, 3)))
cv2.waitKey(1)
cv2.destroyAllWindows()
cv2.waitKey(1)
return True
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
if warn:
LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
return False
......
......@@ -186,7 +186,7 @@ class ConfusionMatrix:
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
return tp[:-1], fp[:-1] # remove background class
@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure: ')
@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
def plot(self, normalize=True, save_dir='', names=()):
import seaborn as sn
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论