提交 08e97a2f authored 作者: Glenn Jocher's avatar Glenn Jocher

Update hyperparameters to add lrf, anchors

上级 9776e709
# Hyperparameters for VOC fine-tuning
# python train.py --batch 64 --cfg '' --weights yolov5m.pt --data voc.yaml --img 512 --epochs 50
# Hyperparameters for VOC finetuning
# python train.py --batch 64 --weights yolov5m.pt --data voc.yaml --img 512 --epochs 50
# See tutorials for hyperparameter evolution https://github.com/ultralytics/yolov5#tutorials
lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
momentum: 0.94 # SGD momentum/Adam beta1
weight_decay: 0.0005 # optimizer weight decay 5e-4
giou: 0.05 # GIoU loss gain
cls: 0.4 # cls loss gain
cls_pw: 1.0 # cls BCELoss positive_weight
obj: 0.5 # obj loss gain (scale with pixels)
obj_pw: 1.0 # obj BCELoss positive_weight
iou_t: 0.20 # IoU training threshold
anchor_t: 4.0 # anchor-multiple threshold
fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
hsv_v: 0.4 # image HSV-Value augmentation (fraction)
degrees: 1.0 # image rotation (+/- deg)
translate: 0.1 # image translation (+/- fraction)
scale: 0.6 # image scale (+/- gain)
shear: 1.0 # image shear (+/- deg)
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
flipud: 0.01 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability)
mixup: 0.2 # image mixup (probability)
# Hyperparameter Evolution Results
# Generations: 51
# P R mAP.5 mAP.5:.95 box obj cls
# Metrics: 0.625 0.926 0.89 0.677 0.0111 0.00849 0.00124
lr0: 0.00447
lrf: 0.114
momentum: 0.873
weight_decay: 0.00047
giou: 0.0306
cls: 0.211
cls_pw: 0.546
obj: 0.421
obj_pw: 0.972
iou_t: 0.2
anchor_t: 2.26
# anchors: 5.07
fl_gamma: 0.0
hsv_h: 0.0154
hsv_s: 0.9
hsv_v: 0.619
degrees: 0.404
translate: 0.206
scale: 0.86
shear: 0.795
perspective: 0.0
flipud: 0.00756
fliplr: 0.5
mixup: 0.153
......@@ -4,15 +4,17 @@
lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
lrf: 0.2 # final OneCycleLR learning rate (lr0 * lrf)
momentum: 0.937 # SGD momentum/Adam beta1
weight_decay: 0.0005 # optimizer weight decay 5e-4
giou: 0.05 # GIoU loss gain
giou: 0.05 # box loss gain
cls: 0.5 # cls loss gain
cls_pw: 1.0 # cls BCELoss positive_weight
obj: 1.0 # obj loss gain (scale with pixels)
obj_pw: 1.0 # obj BCELoss positive_weight
iou_t: 0.20 # IoU training threshold
anchor_t: 4.0 # anchor-multiple threshold
# anchors: 0 # anchors per output grid (0 to ignore)
fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
......
......@@ -53,7 +53,7 @@ def train(hyp, opt, device, tb_writer=None):
cuda = device.type != 'cpu'
init_seeds(2 + rank)
with open(opt.data) as f:
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
data_dict = yaml.load(f, Loader=yaml.FullLoader) # data dict
with torch_distributed_zero_first(rank):
check_dataset(data_dict) # check
train_path = data_dict['train']
......@@ -67,6 +67,8 @@ def train(hyp, opt, device, tb_writer=None):
with torch_distributed_zero_first(rank):
attempt_download(weights) # download if not found locally
ckpt = torch.load(weights, map_location=device) # load checkpoint
# if hyp['anchors']:
# ckpt['model'].yaml['anchors'] = round(hyp['anchors']) # force autoanchor
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device) # create
exclude = ['anchor'] if opt.cfg else [] # exclude keys
state_dict = ckpt['model'].float().state_dict() # to FP32
......@@ -111,7 +113,7 @@ def train(hyp, opt, device, tb_writer=None):
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
# https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2 # cosine
lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - hyp['lrf']) + hyp['lrf'] # cosine
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
# plot_lr_scheduler(optimizer, scheduler, epochs)
......@@ -459,6 +461,7 @@ if __name__ == '__main__':
else:
# Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
'momentum': (0.1, 0.6, 0.98), # SGD momentum/Adam beta1
'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
'giou': (1, 0.02, 0.2), # GIoU loss gain
......@@ -468,6 +471,7 @@ if __name__ == '__main__':
'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
'iou_t': (0, 0.1, 0.7), # IoU training threshold
'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
# 'anchors': (1, 2.0, 10.0), # anchors per output grid (0 to ignore)
'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
......@@ -476,9 +480,9 @@ if __name__ == '__main__':
'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
'scale': (1, 0.0, 0.9), # image scale (+/- gain)
'shear': (1, 0.0, 10.0), # image shear (+/- deg)
'perspective': (1, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
'flipud': (0, 0.0, 1.0), # image flip up-down (probability)
'fliplr': (1, 0.0, 1.0), # image flip left-right (probability)
'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
'mixup': (1, 0.0, 1.0)} # image mixup (probability)
assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论