提交 69ff781c authored 作者: Glenn Jocher's avatar Glenn Jocher

opt.img_weights bug fix (#885)

上级 987c2268
...@@ -216,18 +216,15 @@ def train(hyp, opt, device, tb_writer=None): ...@@ -216,18 +216,15 @@ def train(hyp, opt, device, tb_writer=None):
model.train() model.train()
# Update image weights (optional) # Update image weights (optional)
if dataset.image_weights: if opt.img_weights:
# Generate indices # Generate indices
if rank in [-1, 0]: if rank in [-1, 0]:
w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w) iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
dataset.indices = random.choices(range(dataset.n), weights=image_weights, dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
k=dataset.n) # rand weighted idx
# Broadcast if DDP # Broadcast if DDP
if rank != -1: if rank != -1:
indices = torch.zeros([dataset.n], dtype=torch.int) indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
if rank == 0:
indices[:] = torch.tensor(dataset.indices, dtype=torch.int)
dist.broadcast(indices, 0) dist.broadcast(indices, 0)
if rank != 0: if rank != 0:
dataset.indices = indices.cpu().numpy() dataset.indices = indices.cpu().numpy()
...@@ -388,7 +385,8 @@ if __name__ == '__main__': ...@@ -388,7 +385,8 @@ if __name__ == '__main__':
parser.add_argument('--hyp', type=str, default='', help='hyperparameters path, i.e. data/hyp.scratch.yaml') parser.add_argument('--hyp', type=str, default='', help='hyperparameters path, i.e. data/hyp.scratch.yaml')
parser.add_argument('--epochs', type=int, default=300) parser.add_argument('--epochs', type=int, default=300)
parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs') parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes') parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
parser.add_argument('--img-weights', action='store_true', help='use weighted image selection for training')
parser.add_argument('--rect', action='store_true', help='rectangular training') parser.add_argument('--rect', action='store_true', help='rectangular training')
parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training') parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint') parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论