Unverified 提交 9b11f0c5 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

PyTorch Hub models default to CUDA:0 if available (#2472)

* PyTorch Hub models default to CUDA:0 if available * device as string bug fix
上级 2d41e70e
...@@ -12,6 +12,7 @@ import torch ...@@ -12,6 +12,7 @@ import torch
from models.yolo import Model from models.yolo import Model
from utils.general import set_logging from utils.general import set_logging
from utils.google_utils import attempt_download from utils.google_utils import attempt_download
from utils.torch_utils import select_device
dependencies = ['torch', 'yaml'] dependencies = ['torch', 'yaml']
set_logging() set_logging()
...@@ -43,7 +44,8 @@ def create(name, pretrained, channels, classes, autoshape): ...@@ -43,7 +44,8 @@ def create(name, pretrained, channels, classes, autoshape):
model.names = ckpt['model'].names # set class names attribute model.names = ckpt['model'].names # set class names attribute
if autoshape: if autoshape:
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
return model device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available
return model.to(device)
except Exception as e: except Exception as e:
help_url = 'https://github.com/ultralytics/yolov5/issues/36' help_url = 'https://github.com/ultralytics/yolov5/issues/36'
......
...@@ -385,7 +385,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing ...@@ -385,7 +385,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
# Display cache # Display cache
nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupted, total nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupted, total
if exists: if exists:
d = f"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted" d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
tqdm(None, desc=prefix + d, total=n, initial=n) # display cache results tqdm(None, desc=prefix + d, total=n, initial=n) # display cache results
assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}' assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
...@@ -485,7 +485,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing ...@@ -485,7 +485,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
nc += 1 nc += 1
print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}') print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' for images and labels... " \ pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted" f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
if nf == 0: if nf == 0:
......
...@@ -79,7 +79,7 @@ def check_git_status(): ...@@ -79,7 +79,7 @@ def check_git_status():
f"Use 'git pull' to update or 'git clone {url}' to download latest." f"Use 'git pull' to update or 'git clone {url}' to download latest."
else: else:
s = f'up to date with {url} ✅' s = f'up to date with {url} ✅'
print(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) print(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
except Exception as e: except Exception as e:
print(e) print(e)
......
# PyTorch utils # PyTorch utils
import logging import logging
import math import math
import os import os
import platform
import subprocess import subprocess
import time import time
from contextlib import contextmanager from contextlib import contextmanager
...@@ -53,7 +53,7 @@ def git_describe(): ...@@ -53,7 +53,7 @@ def git_describe():
def select_device(device='', batch_size=None): def select_device(device='', batch_size=None):
# device = 'cpu' or '0' or '0,1,2,3' # device = 'cpu' or '0' or '0,1,2,3'
s = f'YOLOv5 {git_describe()} torch {torch.__version__} ' # string s = f'YOLOv5 🚀 {git_describe()} torch {torch.__version__} ' # string
cpu = device.lower() == 'cpu' cpu = device.lower() == 'cpu'
if cpu: if cpu:
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
...@@ -73,7 +73,7 @@ def select_device(device='', batch_size=None): ...@@ -73,7 +73,7 @@ def select_device(device='', batch_size=None):
else: else:
s += 'CPU\n' s += 'CPU\n'
logger.info(s) # skip a line logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
return torch.device('cuda:0' if cuda else 'cpu') return torch.device('cuda:0' if cuda else 'cpu')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论