Unverified 提交 8bc0027a authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Update loss criteria constructor (#1711)

上级 79972410
import argparse
import logging
import math
import os
import random
import time
......@@ -7,7 +8,6 @@ from pathlib import Path
from threading import Thread
from warnings import warn
import math
import numpy as np
import torch.distributed as dist
import torch.nn as nn
......@@ -217,7 +217,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
model.nc = nc # attach number of classes to model
model.hyp = hyp # attach hyperparameters to model
model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
model.names = names
# Start training
......@@ -238,7 +238,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
if opt.image_weights:
# Generate indices
if rank in [-1, 0]:
cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
# Broadcast if DDP
......@@ -330,7 +330,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
if rank in [-1, 0]:
# mAP
if ema:
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
final_epoch = epoch + 1 == epochs
if not opt.notest or final_epoch: # Calculate mAP
results, maps, times = test.test(opt.data,
......
......@@ -1199,7 +1199,7 @@
"\n",
"m1 = lambda x: x * torch.sigmoid(x)\n",
"m2 = torch.nn.SiLU()\n",
"profile(x=torch.randn(16, 3, 640, 640), [m1, m2], n=100)"
"profile(x=torch.randn(16, 3, 640, 640), ops=[m1, m2], n=100)"
],
"execution_count": null,
"outputs": []
......
......@@ -57,8 +57,8 @@ class FocalLoss(nn.Module):
return loss.sum()
else: # 'none'
return loss
class QFocalLoss(nn.Module):
# Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
......@@ -71,7 +71,7 @@ class QFocalLoss(nn.Module):
def forward(self, pred, true):
loss = self.loss_fcn(pred, true)
pred_prob = torch.sigmoid(pred) # prob from logits
alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
modulating_factor = torch.abs(true - pred_prob) ** self.gamma
......@@ -92,8 +92,8 @@ def compute_loss(p, targets, model): # predictions, targets, model
h = model.hyp # hyperparameters
# Define criteria
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['cls_pw']])).to(device)
BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['obj_pw']])).to(device)
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device)) # weight=model.class_weights)
BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
cp, cn = smooth_BCE(eps=0.0)
......@@ -119,7 +119,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
# Regression
pxy = ps[:, :2].sigmoid() * 2. - 0.5
pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
pbox = torch.cat((pxy, pwh), 1).to(device) # predicted box
pbox = torch.cat((pxy, pwh), 1) # predicted box
iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target)
lbox += (1.0 - iou).mean() # iou loss
......
......@@ -81,8 +81,8 @@ def profile(x, ops, n=100, device=None):
# m1 = lambda x: x * torch.sigmoid(x)
# m2 = nn.SiLU()
# profile(x, [m1, m2], n=100) # profile speed over 100 iterations
device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
x = x.to(device)
x.requires_grad = True
print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
......@@ -99,8 +99,11 @@ def profile(x, ops, n=100, device=None):
t[0] = time_synchronized()
y = m(x)
t[1] = time_synchronized()
_ = y.sum().backward()
t[2] = time_synchronized()
try:
_ = y.sum().backward()
t[2] = time_synchronized()
except: # no backward method
t[2] = float('nan')
dtf += (t[1] - t[0]) * 1000 / n # ms per op forward
dtb += (t[2] - t[1]) * 1000 / n # ms per op backward
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论