Unverified commit 9728e2b8, authored by Glenn Jocher, committed by GitHub

--image_weights bug fix (#1524)

Parent e9a0ae6f
train.py
@@ -181,8 +181,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
 
     # Trainloader
-    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
-                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect,
-                                            rank=rank, world_size=opt.world_size, workers=opt.workers)
+    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
+                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
+                                            world_size=opt.world_size, workers=opt.workers,
+                                            image_weights=opt.image_weights)
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
     nb = len(dataloader)  # number of batches
     assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
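The train.py hunk above only forwards the new keyword; the option becomes useful once the training loop recomputes per-image weights and overwrites dataset.indices between epochs. Below is a minimal, self-contained sketch of that idea. All names are illustrative: the repository has its own helper along these lines (labels_to_image_weights in utils/general.py), and the stand-in here is not that implementation.

import random
import numpy as np

def image_weights_sketch(labels, nc, class_weights):
    # Hypothetical helper: weight each image by the class weights of the
    # objects it contains (rare classes -> heavier images).
    counts = np.array([np.bincount(l[:, 0].astype(int), minlength=nc) for l in labels])
    return (counts * class_weights).sum(1)

# Toy labels: one array of [class, x, y, w, h] rows per image.
labels = [np.array([[0, 0.5, 0.5, 0.1, 0.1]]),
          np.array([[1, 0.5, 0.5, 0.1, 0.1]])]
class_weights = np.array([1.0, 4.0])  # e.g. class 1 is under-represented
iw = image_weights_sketch(labels, nc=2, class_weights=class_weights)
new_indices = random.choices(range(len(labels)), weights=iw, k=len(labels))
# A training loop owning the dataset would then assign: dataset.indices = new_indices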
utils/datasets.py
@@ -55,7 +55,7 @@ def exif_size(img):
 
 
 def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
-                      rank=-1, world_size=1, workers=8):
+                      rank=-1, world_size=1, workers=8, image_weights=False):
     # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
     with torch_distributed_zero_first(rank):
         dataset = LoadImagesAndLabels(path, imgsz, batch_size,
@@ -66,7 +66,8 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
                                       single_cls=opt.single_cls,
                                       stride=int(stride),
                                       pad=pad,
-                                      rank=rank)
+                                      rank=rank,
+                                      image_weights=image_weights)
 
     batch_size = min(batch_size, len(dataset))
     nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
@@ -392,6 +393,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         nb = bi[-1] + 1  # number of batches
         self.batch = bi  # batch index of image
         self.n = n
+        self.indices = range(n)
 
         # Rectangular Training
         if self.rect:
@@ -485,8 +487,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
     #     return self
 
     def __getitem__(self, index):
-        if self.image_weights:
-            index = self.indices[index]
+        index = self.indices[index]  # linear, shuffled, or image_weights
 
         hyp = self.hyp
         mosaic = self.mosaic and random.random() < hyp['mosaic']
@@ -497,7 +498,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
 
             # MixUp https://arxiv.org/pdf/1710.09412.pdf
             if random.random() < hyp['mixup']:
-                img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
+                img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
                 r = np.random.beta(8.0, 8.0)  # mixup ratio, alpha=beta=8.0
                 img = (img * r + img2 * (1 - r)).astype(np.uint8)
                 labels = np.concatenate((labels, labels2), 0)
@@ -619,7 +620,7 @@ def load_mosaic(self, index):
     labels4 = []
     s = self.img_size
     yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
-    indices = [index] + [random.randint(0, len(self.labels) - 1) for _ in range(3)]  # 3 additional image indices
+    indices = [index] + [self.indices[random.randint(0, self.n - 1)] for _ in range(3)]  # 3 additional image indices
     for i, index in enumerate(indices):
         # Load image
         img, _, (h, w) = load_image(self, index)
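Taken together, the utils/datasets.py changes give every LoadImagesAndLabels instance a default identity index map (self.indices = range(n)) and route the index lookup in __getitem__ and the three extra mosaic tiles through it unconditionally (the MixUp line only switches from len(self.labels) to the equivalent self.n). Judging from the diff, the earlier gate `if self.image_weights:` never fired because create_dataloader did not forward the flag, so the --image_weights option had no effect on which images were drawn. A minimal, self-contained sketch of the indirection follows; the class and variable names are illustrative, not from the repository.

import random

class TinyDataset:
    # Toy stand-in for LoadImagesAndLabels, showing only the index indirection.
    def __init__(self, n):
        self.n = n
        self.indices = range(n)  # default: identity mapping (linear sampling)

    def __getitem__(self, index):
        index = self.indices[index]  # linear, shuffled, or image_weights
        return index  # stand-in for loading the image/labels at `index`

ds = TinyDataset(5)
print([ds[i] for i in range(5)])  # [0, 1, 2, 3, 4] -> default linear order

# Hypothetical per-image weights; a training loop could refresh this once per epoch.
weights = [0.1, 0.1, 0.1, 0.1, 5.0]
ds.indices = random.choices(range(ds.n), weights=weights, k=ds.n)
print([ds[i] for i in range(5)])  # mostly 4 now -> weighted resampling takes effect

Because the default mapping is the identity, ordinary training is unchanged; once the owner of the dataset overwrites dataset.indices, __getitem__ and the additional mosaic images follow the weighted choice.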