提交 5c854fab authored 作者: glennjocher's avatar glennjocher

requires grad after reset params

上级 e08d568d
......@@ -114,13 +114,13 @@ def train(opt, device):
LOGGER.warning("WARNING: pass YOLOv5 classifier model with '-cls' suffix, i.e. '--model yolov5s-cls.pt'")
model = ClassificationModel(model=model, nc=nc, cutoff=opt.cutoff or 10) # convert to classification model
reshape_classifier_output(model, nc) # update class count
for p in model.parameters():
p.requires_grad = True # for training
for m in model.modules():
if not pretrained and hasattr(m, 'reset_parameters'):
m.reset_parameters()
if isinstance(m, torch.nn.Dropout) and opt.dropout is not None:
m.p = opt.dropout # set dropout
for p in model.parameters():
p.requires_grad = True # for training
model = model.to(device)
names = trainloader.dataset.classes # class names
model.names = names # attach class names
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论