Unverified 提交 b5de52c4 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

torch.cuda.amp bug fix (#2750)

PR https://github.com/ultralytics/yolov5/pull/2725 introduced a very specific bug that only affects multi-GPU trainings. Apparently the cause was using the torch.cuda.amp decorator in the autoShape forward method. I've implemented amp more traditionally in this PR, and the bug is resolved.
上级 fca5e2a4
...@@ -10,6 +10,7 @@ import requests ...@@ -10,6 +10,7 @@ import requests
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image from PIL import Image
from torch.cuda import amp
from utils.datasets import letterbox from utils.datasets import letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh
...@@ -237,7 +238,6 @@ class autoShape(nn.Module): ...@@ -237,7 +238,6 @@ class autoShape(nn.Module):
return self return self
@torch.no_grad() @torch.no_grad()
@torch.cuda.amp.autocast(torch.cuda.is_available())
def forward(self, imgs, size=640, augment=False, profile=False): def forward(self, imgs, size=640, augment=False, profile=False):
# Inference from various sources. For height=640, width=1280, RGB images example inputs are: # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
# filename: imgs = 'data/samples/zidane.jpg' # filename: imgs = 'data/samples/zidane.jpg'
...@@ -251,7 +251,8 @@ class autoShape(nn.Module): ...@@ -251,7 +251,8 @@ class autoShape(nn.Module):
t = [time_synchronized()] t = [time_synchronized()]
p = next(self.model.parameters()) # for device and type p = next(self.model.parameters()) # for device and type
if isinstance(imgs, torch.Tensor): # torch if isinstance(imgs, torch.Tensor): # torch
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference with amp.autocast(enabled=p.device.type != 'cpu'):
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
# Pre-process # Pre-process
n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
...@@ -278,17 +279,18 @@ class autoShape(nn.Module): ...@@ -278,17 +279,18 @@ class autoShape(nn.Module):
x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32 x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
t.append(time_synchronized()) t.append(time_synchronized())
# Inference with amp.autocast(enabled=p.device.type != 'cpu'):
y = self.model(x, augment, profile)[0] # forward # Inference
t.append(time_synchronized()) y = self.model(x, augment, profile)[0] # forward
t.append(time_synchronized())
# Post-process # Post-process
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
for i in range(n): for i in range(n):
scale_coords(shape1, y[i][:, :4], shape0[i]) scale_coords(shape1, y[i][:, :4], shape0[i])
t.append(time_synchronized()) t.append(time_synchronized())
return Detections(imgs, y, files, t, self.names, x.shape) return Detections(imgs, y, files, t, self.names, x.shape)
class Detections: class Detections:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论