Unverified commit e1dc8943, authored by bilzard, committed by GitHub

Enable AdamW optimizer (#6152)

Parent d95978a5
@@ -22,7 +22,7 @@ import torch.nn as nn
 import yaml
 from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import SGD, Adam, lr_scheduler
+from torch.optim import SGD, Adam, AdamW, lr_scheduler
 from tqdm import tqdm
 
 FILE = Path(__file__).resolve()
@@ -155,8 +155,10 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
             g1.append(v.weight)
 
-    if opt.adam:
+    if opt.optimizer == 'Adam':
         optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
+    elif opt.optimizer == 'AdamW':
+        optimizer = AdamW(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
     else:
         optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
 
@@ -460,7 +462,7 @@ def parse_opt(known=False):
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
     parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
-    parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
+    parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer')
     parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
     parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
     parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
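For reference, a minimal standalone sketch of the dispatch this commit introduces. The build_optimizer helper, its defaults, and the stand-in model are illustrative only, not part of the patch; the actual code in train.py operates on the g0 parameter group and the hyp dict shown in the hunk above.

import torch.nn as nn
from torch.optim import SGD, Adam, AdamW

def build_optimizer(params, name='SGD', lr0=0.01, momentum=0.937):
    # Mirror the three-way selection from train.py: Adam and AdamW reuse
    # the SGD momentum hyperparameter as beta1; SGD remains the default.
    if name == 'Adam':
        return Adam(params, lr=lr0, betas=(momentum, 0.999))
    elif name == 'AdamW':
        return AdamW(params, lr=lr0, betas=(momentum, 0.999))
    return SGD(params, lr=lr0, momentum=momentum, nesterov=True)

model = nn.Linear(10, 2)  # stand-in module in place of the detection model
optimizer = build_optimizer(model.parameters(), name='AdamW')

Usage note: the boolean --adam flag is removed by this change, so runs that previously passed --adam now select the optimizer by name, e.g. python train.py --optimizer AdamW. SGD stays the default, so existing SGD commands are unaffected.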