Unverified 提交 d07ddc69 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

New TryExcept decorator (#9154)

* New TryExcept decorator * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ciCo-authored-by: 's avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
上级 f0e5a608
...@@ -3,6 +3,33 @@ ...@@ -3,6 +3,33 @@
utils/initialization utils/initialization
""" """
import contextlib
import threading
class TryExcept(contextlib.ContextDecorator):
# YOLOv5 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager
def __init__(self, msg='default message here'):
self.msg = msg
def __enter__(self):
pass
def __exit__(self, exc_type, value, traceback):
if value:
print(f'{self.msg}: {value}')
return True
def threaded(func):
# Multi-threads a target function and returns thread. Usage: @threaded decorator
def wrapper(*args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread
return wrapper
def notebook_init(verbose=True): def notebook_init(verbose=True):
# Check system software and hardware # Check system software and hardware
......
...@@ -15,7 +15,6 @@ import re ...@@ -15,7 +15,6 @@ import re
import shutil import shutil
import signal import signal
import sys import sys
import threading
import time import time
import urllib import urllib
from datetime import datetime from datetime import datetime
...@@ -34,6 +33,7 @@ import torch ...@@ -34,6 +33,7 @@ import torch
import torchvision import torchvision
import yaml import yaml
from utils import TryExcept
from utils.downloads import gsutil_getsize from utils.downloads import gsutil_getsize
from utils.metrics import box_iou, fitness from utils.metrics import box_iou, fitness
...@@ -195,27 +195,6 @@ class WorkingDirectory(contextlib.ContextDecorator): ...@@ -195,27 +195,6 @@ class WorkingDirectory(contextlib.ContextDecorator):
os.chdir(self.cwd) os.chdir(self.cwd)
def try_except(func):
# try-except function. Usage: @try_except decorator
def handler(*args, **kwargs):
try:
func(*args, **kwargs)
except Exception as e:
print(e)
return handler
def threaded(func):
# Multi-threads a target function and returns thread. Usage: @threaded decorator
def wrapper(*args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread
return wrapper
def methods(instance): def methods(instance):
# Get class/instance methods # Get class/instance methods
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")] return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
...@@ -319,7 +298,7 @@ def git_describe(path=ROOT): # path must be a directory ...@@ -319,7 +298,7 @@ def git_describe(path=ROOT): # path must be a directory
return '' return ''
@try_except @TryExcept()
@WorkingDirectory(ROOT) @WorkingDirectory(ROOT)
def check_git_status(repo='ultralytics/yolov5'): def check_git_status(repo='ultralytics/yolov5'):
# YOLOv5 status check, recommend 'git pull' if code is out of date # YOLOv5 status check, recommend 'git pull' if code is out of date
...@@ -364,7 +343,7 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals ...@@ -364,7 +343,7 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals
return result return result
@try_except @TryExcept()
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()): def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
# Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages) # Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages)
prefix = colorstr('red', 'bold', 'requirements:') prefix = colorstr('red', 'bold', 'requirements:')
......
...@@ -11,6 +11,8 @@ import matplotlib.pyplot as plt ...@@ -11,6 +11,8 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
from utils import TryExcept, threaded
def fitness(x): def fitness(x):
# Model fitness as a weighted combination of metrics # Model fitness as a weighted combination of metrics
...@@ -184,36 +186,35 @@ class ConfusionMatrix: ...@@ -184,36 +186,35 @@ class ConfusionMatrix:
# fn = self.matrix.sum(0) - tp # false negatives (missed detections) # fn = self.matrix.sum(0) - tp # false negatives (missed detections)
return tp[:-1], fp[:-1] # remove background class return tp[:-1], fp[:-1] # remove background class
@TryExcept('WARNING: ConfusionMatrix plot failure')
def plot(self, normalize=True, save_dir='', names=()): def plot(self, normalize=True, save_dir='', names=()):
try: import seaborn as sn
import seaborn as sn
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
fig = plt.figure(figsize=(12, 9), tight_layout=True) nc, nn = self.nc, len(names) # number of classes, names
nc, nn = self.nc, len(names) # number of classes, names sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels with warnings.catch_warnings():
with warnings.catch_warnings(): warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered sn.heatmap(array,
sn.heatmap(array, ax=ax,
annot=nc < 30, annot=nc < 30,
annot_kws={ annot_kws={
"size": 8}, "size": 8},
cmap='Blues', cmap='Blues',
fmt='.2f', fmt='.2f',
square=True, square=True,
vmin=0.0, vmin=0.0,
xticklabels=names + ['background FP'] if labels else "auto", xticklabels=names + ['background FP'] if labels else "auto",
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1)) yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
fig.axes[0].set_xlabel('True') ax.set_ylabel('True')
fig.axes[0].set_ylabel('Predicted') ax.set_ylabel('Predicted')
plt.title('Confusion Matrix') ax.set_title('Confusion Matrix')
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250) fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
plt.close() plt.close(fig)
except Exception as e:
print(f'WARNING: ConfusionMatrix plot failure: {e}')
def print(self): def print(self):
for i in range(self.nc + 1): for i in range(self.nc + 1):
...@@ -320,6 +321,7 @@ def wh_iou(wh1, wh2, eps=1e-7): ...@@ -320,6 +321,7 @@ def wh_iou(wh1, wh2, eps=1e-7):
# Plots ---------------------------------------------------------------------------------------------------------------- # Plots ----------------------------------------------------------------------------------------------------------------
@threaded
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()): def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
# Precision-recall curve # Precision-recall curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
...@@ -336,12 +338,13 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()): ...@@ -336,12 +338,13 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
ax.set_ylabel('Precision') ax.set_ylabel('Precision')
ax.set_xlim(0, 1) ax.set_xlim(0, 1)
ax.set_ylim(0, 1) ax.set_ylim(0, 1)
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
plt.title('Precision-Recall Curve') ax.set_title('Precision-Recall Curve')
fig.savefig(save_dir, dpi=250) fig.savefig(save_dir, dpi=250)
plt.close() plt.close(fig)
@threaded
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'): def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
# Metric-confidence curve # Metric-confidence curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
...@@ -358,7 +361,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi ...@@ -358,7 +361,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
ax.set_ylabel(ylabel) ax.set_ylabel(ylabel)
ax.set_xlim(0, 1) ax.set_xlim(0, 1)
ax.set_ylim(0, 1) ax.set_ylim(0, 1)
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
plt.title(f'{ylabel}-Confidence Curve') ax.set_title(f'{ylabel}-Confidence Curve')
fig.savefig(save_dir, dpi=250) fig.savefig(save_dir, dpi=250)
plt.close() plt.close(fig)
...@@ -19,8 +19,9 @@ import seaborn as sn ...@@ -19,8 +19,9 @@ import seaborn as sn
import torch import torch
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from utils import TryExcept, threaded
from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path, from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path,
is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh) is_ascii, xywh2xyxy, xyxy2xywh)
from utils.metrics import fitness from utils.metrics import fitness
# Settings # Settings
...@@ -339,7 +340,7 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_ ...@@ -339,7 +340,7 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_
plt.savefig(f, dpi=300) plt.savefig(f, dpi=300)
@try_except # known issue https://github.com/ultralytics/yolov5/issues/5395 @TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
def plot_labels(labels, names=(), save_dir=Path('')): def plot_labels(labels, names=(), save_dir=Path('')):
# plot dataset labels # plot dataset labels
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ") LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论