Unverified 提交 9c803f2f authored 作者: Phat Tran's avatar Phat Tran 提交者: GitHub

Add --label-smoothing eps argument to train.py (default 0.0) (#2344)

* Add label smoothing option * Correct data type * add_log * Remove log * Add log * Update loss.py remove comment (too versbose) Co-authored-by: 's avatarphattran <phat.tranhoang@cyberlogitec.com> Co-authored-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 fd167997
...@@ -224,6 +224,7 @@ def train(hyp, opt, device, tb_writer=None): ...@@ -224,6 +224,7 @@ def train(hyp, opt, device, tb_writer=None):
hyp['box'] *= 3. / nl # scale to layers hyp['box'] *= 3. / nl # scale to layers
hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
hyp['label_smoothing'] = opt.label_smoothing
model.nc = nc # attach number of classes to model model.nc = nc # attach number of classes to model
model.hyp = hyp # attach hyperparameters to model model.hyp = hyp # attach hyperparameters to model
model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou) model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
...@@ -481,6 +482,7 @@ if __name__ == '__main__': ...@@ -481,6 +482,7 @@ if __name__ == '__main__':
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
parser.add_argument('--quad', action='store_true', help='quad dataloader') parser.add_argument('--quad', action='store_true', help='quad dataloader')
parser.add_argument('--linear-lr', action='store_true', help='linear LR') parser.add_argument('--linear-lr', action='store_true', help='linear LR')
parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table') parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B') parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch') parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
......
...@@ -97,7 +97,7 @@ class ComputeLoss: ...@@ -97,7 +97,7 @@ class ComputeLoss:
BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device)) BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3 # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
self.cp, self.cn = smooth_BCE(eps=0.0) self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0)) # positive, negative BCE targets
# Focal loss # Focal loss
g = h['fl_gamma'] # focal loss gamma g = h['fl_gamma'] # focal loss gamma
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论