Unverified 提交 ca9babb8 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Add ComputeLoss() class (#1950)

上级 f4a78e1b
...@@ -13,7 +13,6 @@ from models.experimental import attempt_load ...@@ -13,7 +13,6 @@ from models.experimental import attempt_load
from utils.datasets import create_dataloader from utils.datasets import create_dataloader
from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, check_requirements, \ from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, check_requirements, \
box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path, colorstr box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path, colorstr
from utils.loss import compute_loss
from utils.metrics import ap_per_class, ConfusionMatrix from utils.metrics import ap_per_class, ConfusionMatrix
from utils.plots import plot_images, output_to_target, plot_study_txt from utils.plots import plot_images, output_to_target, plot_study_txt
from utils.torch_utils import select_device, time_synchronized from utils.torch_utils import select_device, time_synchronized
...@@ -36,7 +35,8 @@ def test(data, ...@@ -36,7 +35,8 @@ def test(data,
save_hybrid=False, # for hybrid auto-labelling save_hybrid=False, # for hybrid auto-labelling
save_conf=False, # save auto-label confidences save_conf=False, # save auto-label confidences
plots=True, plots=True,
log_imgs=0): # number of logged images log_imgs=0, # number of logged images
compute_loss=None):
# Initialize/load model and set device # Initialize/load model and set device
training = model is not None training = model is not None
...@@ -111,8 +111,8 @@ def test(data, ...@@ -111,8 +111,8 @@ def test(data,
t0 += time_synchronized() - t t0 += time_synchronized() - t
# Compute loss # Compute loss
if training: if compute_loss:
loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # box, obj, cls loss += compute_loss([x.float() for x in train_out], targets)[1][:3] # box, obj, cls
# Run NMS # Run NMS
targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels
......
...@@ -29,7 +29,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima ...@@ -29,7 +29,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima
fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \ fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
check_requirements, print_mutation, set_logging, one_cycle, colorstr check_requirements, print_mutation, set_logging, one_cycle, colorstr
from utils.google_utils import attempt_download from utils.google_utils import attempt_download
from utils.loss import compute_loss from utils.loss import ComputeLoss
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first
...@@ -227,6 +227,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): ...@@ -227,6 +227,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls) results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
scheduler.last_epoch = start_epoch - 1 # do not move scheduler.last_epoch = start_epoch - 1 # do not move
scaler = amp.GradScaler(enabled=cuda) scaler = amp.GradScaler(enabled=cuda)
compute_loss = ComputeLoss(model) # init loss class
logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n' logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
f'Using {dataloader.num_workers} dataloader workers\n' f'Using {dataloader.num_workers} dataloader workers\n'
f'Logging results to {save_dir}\n' f'Logging results to {save_dir}\n'
...@@ -286,7 +287,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): ...@@ -286,7 +287,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# Forward # Forward
with amp.autocast(enabled=cuda): with amp.autocast(enabled=cuda):
pred = model(imgs) # forward pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device), model) # loss scaled by batch_size loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
if rank != -1: if rank != -1:
loss *= opt.world_size # gradient averaged between devices in DDP mode loss *= opt.world_size # gradient averaged between devices in DDP mode
if opt.quad: if opt.quad:
...@@ -344,7 +345,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): ...@@ -344,7 +345,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
dataloader=testloader, dataloader=testloader,
save_dir=save_dir, save_dir=save_dir,
plots=plots and final_epoch, plots=plots and final_epoch,
log_imgs=opt.log_imgs if wandb else 0) log_imgs=opt.log_imgs if wandb else 0,
compute_loss=compute_loss)
# Write # Write
with open(results_file, 'a') as f: with open(results_file, 'a') as f:
......
...@@ -85,34 +85,45 @@ class QFocalLoss(nn.Module): ...@@ -85,34 +85,45 @@ class QFocalLoss(nn.Module):
return loss return loss
def compute_loss(p, targets, model): # predictions, targets, model class ComputeLoss:
device = targets.device # Compute losses
lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device) def __init__(self, model, autobalance=False):
tcls, tbox, indices, anchors = build_targets(p, targets, model) # targets super(ComputeLoss, self).__init__()
device = next(model.parameters()).device # get model device
h = model.hyp # hyperparameters h = model.hyp # hyperparameters
# Define criteria # Define criteria
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device)) # weight=model.class_weights) BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device)) BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3 # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
cp, cn = smooth_BCE(eps=0.0) self.cp, self.cn = smooth_BCE(eps=0.0)
# Focal loss # Focal loss
g = h['fl_gamma'] # focal loss gamma g = h['fl_gamma'] # focal loss gamma
if g > 0: if g > 0:
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g) BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
self.balance = {3: [3.67, 1.0, 0.43], 4: [3.78, 1.0, 0.39, 0.22], 5: [3.88, 1.0, 0.37, 0.17, 0.10]}[det.nl]
# self.balance = [1.0] * det.nl
self.ssi = (det.stride == 16).nonzero(as_tuple=False).item() # stride 16 index
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
for k in 'na', 'nc', 'nl', 'anchors':
setattr(self, k, getattr(det, k))
def __call__(self, p, targets): # predictions, targets, model
device = targets.device
lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
tcls, tbox, indices, anchors = self.build_targets(p, targets) # targets
# Losses # Losses
nt = 0 # number of targets
balance = [4.0, 1.0, 0.3, 0.1, 0.03] # P3-P7
for i, pi in enumerate(p): # layer index, layer predictions for i, pi in enumerate(p): # layer index, layer predictions
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
tobj = torch.zeros_like(pi[..., 0], device=device) # target obj tobj = torch.zeros_like(pi[..., 0], device=device) # target obj
n = b.shape[0] # number of targets n = b.shape[0] # number of targets
if n: if n:
nt += n # cumulative targets
ps = pi[b, a, gj, gi] # prediction subset corresponding to targets ps = pi[b, a, gj, gi] # prediction subset corresponding to targets
# Regression # Regression
...@@ -123,33 +134,36 @@ def compute_loss(p, targets, model): # predictions, targets, model ...@@ -123,33 +134,36 @@ def compute_loss(p, targets, model): # predictions, targets, model
lbox += (1.0 - iou).mean() # iou loss lbox += (1.0 - iou).mean() # iou loss
# Objectness # Objectness
tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * iou.detach().clamp(0).type(tobj.dtype) # iou ratio tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype) # iou ratio
# Classification # Classification
if model.nc > 1: # cls loss (only if multiple classes) if self.nc > 1: # cls loss (only if multiple classes)
t = torch.full_like(ps[:, 5:], cn, device=device) # targets t = torch.full_like(ps[:, 5:], self.cn, device=device) # targets
t[range(n), tcls[i]] = cp t[range(n), tcls[i]] = self.cp
lcls += BCEcls(ps[:, 5:], t) # BCE lcls += self.BCEcls(ps[:, 5:], t) # BCE
# Append targets to text file # Append targets to text file
# with open('targets.txt', 'a') as file: # with open('targets.txt', 'a') as file:
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)] # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss obji = self.BCEobj(pi[..., 4], tobj)
lobj += obji * self.balance[i] # obj loss
if self.autobalance:
self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
lbox *= h['box'] if self.autobalance:
lobj *= h['obj'] self.balance = [x / self.balance[self.ssi] for x in self.balance]
lcls *= h['cls'] lbox *= self.hyp['box']
lobj *= self.hyp['obj']
lcls *= self.hyp['cls']
bs = tobj.shape[0] # batch size bs = tobj.shape[0] # batch size
loss = lbox + lobj + lcls loss = lbox + lobj + lcls
return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach() return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()
def build_targets(self, p, targets):
def build_targets(p, targets, model):
# Build targets for compute_loss(), input targets(image,class,x,y,w,h) # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module na, nt = self.na, targets.shape[0] # number of anchors, targets
na, nt = det.na, targets.shape[0] # number of anchors, targets
tcls, tbox, indices, anch = [], [], [], [] tcls, tbox, indices, anch = [], [], [], []
gain = torch.ones(7, device=targets.device) # normalized to gridspace gain gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt) ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
...@@ -161,8 +175,8 @@ def build_targets(p, targets, model): ...@@ -161,8 +175,8 @@ def build_targets(p, targets, model):
# [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
], device=targets.device).float() * g # offsets ], device=targets.device).float() * g # offsets
for i in range(det.nl): for i in range(self.nl):
anchors = det.anchors[i] anchors = self.anchors[i]
gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
# Match targets to anchors # Match targets to anchors
...@@ -170,7 +184,7 @@ def build_targets(p, targets, model): ...@@ -170,7 +184,7 @@ def build_targets(p, targets, model):
if nt: if nt:
# Matches # Matches
r = t[:, :, 4:6] / anchors[:, None] # wh ratio r = t[:, :, 4:6] / anchors[:, None] # wh ratio
j = torch.max(r, 1. / r).max(2)[0] < model.hyp['anchor_t'] # compare j = torch.max(r, 1. / r).max(2)[0] < self.hyp['anchor_t'] # compare
# j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2)) # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
t = t[j] # filter t = t[j] # filter
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论