Unverified 提交 c215878f authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

YOLOv5 Apple Metal Performance Shader (MPS) support (#7878)

* Apple Metal Performance Shader (MPS) device support Following https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/ Should work with Apple M1 devices with PyTorch nightly installed with command `--device mps`. Usage examples: ```bash python train.py --device mps python detect.py --device mps python val.py --device mps ``` * Update device strategy to fix MPS issue
上级 27911dc8
...@@ -486,7 +486,7 @@ def run( ...@@ -486,7 +486,7 @@ def run(
if half: if half:
assert device.type != 'cpu' or coreml or xml, '--half only compatible with GPU export, i.e. use --device 0' assert device.type != 'cpu' or coreml or xml, '--half only compatible with GPU export, i.e. use --device 0'
assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both' assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'
model = attempt_load(weights, map_location=device, inplace=True, fuse=True) # load FP32 model model = attempt_load(weights, device=device, inplace=True, fuse=True) # load FP32 model
nc, names = model.nc, model.names # number of classes, class names nc, names = model.nc, model.names # number of classes, class names
# Checks # Checks
......
...@@ -331,7 +331,7 @@ class DetectMultiBackend(nn.Module): ...@@ -331,7 +331,7 @@ class DetectMultiBackend(nn.Module):
names = yaml.safe_load(f)['names'] names = yaml.safe_load(f)['names']
if pt: # PyTorch if pt: # PyTorch
model = attempt_load(weights if isinstance(weights, list) else w, map_location=device) model = attempt_load(weights if isinstance(weights, list) else w, device=device)
stride = max(int(model.stride.max()), 32) # model stride stride = max(int(model.stride.max()), 32) # model stride
names = model.module.names if hasattr(model, 'module') else model.names # get class names names = model.module.names if hasattr(model, 'module') else model.names # get class names
model.half() if fp16 else model.float() model.half() if fp16 else model.float()
......
...@@ -71,14 +71,14 @@ class Ensemble(nn.ModuleList): ...@@ -71,14 +71,14 @@ class Ensemble(nn.ModuleList):
return y, None # inference, train output return y, None # inference, train output
def attempt_load(weights, map_location=None, inplace=True, fuse=True): def attempt_load(weights, device=None, inplace=True, fuse=True):
from models.yolo import Detect, Model from models.yolo import Detect, Model
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
model = Ensemble() model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]: for w in weights if isinstance(weights, list) else [weights]:
ckpt = torch.load(attempt_download(w), map_location=map_location) # load ckpt = torch.load(attempt_download(w))
ckpt = (ckpt.get('ema') or ckpt['model']).float() # FP32 model ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
model.append(ckpt.fuse().eval() if fuse else ckpt.eval()) # fused or un-fused model in eval mode model.append(ckpt.fuse().eval() if fuse else ckpt.eval()) # fused or un-fused model in eval mode
# Compatibility updates # Compatibility updates
......
...@@ -536,7 +536,7 @@ def run( ...@@ -536,7 +536,7 @@ def run(
): ):
# PyTorch model # PyTorch model
im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
model = attempt_load(weights, map_location=torch.device('cpu'), inplace=True, fuse=False) model = attempt_load(weights, device=torch.device('cpu'), inplace=True, fuse=False)
_ = model(im) # inference _ = model(im) # inference
model.info() model.info()
......
...@@ -54,7 +54,8 @@ def select_device(device='', batch_size=0, newline=True): ...@@ -54,7 +54,8 @@ def select_device(device='', batch_size=0, newline=True):
s = f'YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} ' s = f'YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
device = str(device).strip().lower().replace('cuda:', '').replace('none', '') # to string, 'cuda:0' to '0' device = str(device).strip().lower().replace('cuda:', '').replace('none', '') # to string, 'cuda:0' to '0'
cpu = device == 'cpu' cpu = device == 'cpu'
if cpu: mps = device == 'mps' # Apple Metal Performance Shaders (MPS)
if cpu or mps:
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
elif device: # non-cpu device requested elif device: # non-cpu device requested
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available() os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
...@@ -71,13 +72,15 @@ def select_device(device='', batch_size=0, newline=True): ...@@ -71,13 +72,15 @@ def select_device(device='', batch_size=0, newline=True):
for i, d in enumerate(devices): for i, d in enumerate(devices):
p = torch.cuda.get_device_properties(i) p = torch.cuda.get_device_properties(i)
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
elif mps:
s += 'MPS\n'
else: else:
s += 'CPU\n' s += 'CPU\n'
if not newline: if not newline:
s = s.rstrip() s = s.rstrip()
LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
return torch.device('cuda:0' if cuda else 'cpu') return torch.device('cuda:0' if cuda else 'mps' if mps else 'cpu')
def time_sync(): def time_sync():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论