Unverified 提交 840b7232 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Attach transforms to model (#9028)

* Attach transforms to model Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> * Update val.py Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> * Update train.py Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 4bc5520e
...@@ -122,16 +122,16 @@ def train(opt, device): ...@@ -122,16 +122,16 @@ def train(opt, device):
for p in model.parameters(): for p in model.parameters():
p.requires_grad = True # for training p.requires_grad = True # for training
model = model.to(device) model = model.to(device)
names = trainloader.dataset.classes # class names
model.names = names # attach class names
# Info # Info
if RANK in {-1, 0}: if RANK in {-1, 0}:
model.names = trainloader.dataset.classes # attach class names
model.transforms = testloader.dataset.torch_transforms # attach inference transforms
model_info(model) model_info(model)
if opt.verbose: if opt.verbose:
LOGGER.info(model) LOGGER.info(model)
images, labels = next(iter(trainloader)) images, labels = next(iter(trainloader))
file = imshow_cls(images[:25], labels[:25], names=names, f=save_dir / 'train_images.jpg') file = imshow_cls(images[:25], labels[:25], names=model.names, f=save_dir / 'train_images.jpg')
logger.log_images(file, name='Train Examples') logger.log_images(file, name='Train Examples')
logger.log_graph(model, imgsz) # log model logger.log_graph(model, imgsz) # log model
...@@ -254,8 +254,8 @@ def train(opt, device): ...@@ -254,8 +254,8 @@ def train(opt, device):
# Plot examples # Plot examples
images, labels = (x[:25] for x in next(iter(testloader))) # first 25 images and labels images, labels = (x[:25] for x in next(iter(testloader))) # first 25 images and labels
pred = torch.max(ema.ema((images.half() if cuda else images.float()).to(device)), 1)[1] pred = torch.max(ema.ema(images.to(device)), 1)[1]
file = imshow_cls(images, labels, pred, names, verbose=False, f=save_dir / 'test_images.jpg') file = imshow_cls(images, labels, pred, model.names, verbose=False, f=save_dir / 'test_images.jpg')
# Log results # Log results
meta = {"epochs": epochs, "top1_acc": best_fitness, "date": datetime.now().isoformat()} meta = {"epochs": epochs, "top1_acc": best_fitness, "date": datetime.now().isoformat()}
......
...@@ -39,7 +39,7 @@ def run( ...@@ -39,7 +39,7 @@ def run(
project=ROOT / 'runs/val-cls', # save to project/name project=ROOT / 'runs/val-cls', # save to project/name
name='exp', # save to project/name name='exp', # save to project/name
exist_ok=False, # existing project/name ok, do not increment exist_ok=False, # existing project/name ok, do not increment
half=True, # use FP16 half-precision inference half=False, # use FP16 half-precision inference
dnn=False, # use OpenCV DNN for ONNX inference dnn=False, # use OpenCV DNN for ONNX inference
model=None, model=None,
dataloader=None, dataloader=None,
...@@ -124,7 +124,6 @@ def run( ...@@ -124,7 +124,6 @@ def run(
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t) LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}") LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
model.float() # for training
return top1, top5, loss return top1, top5, loss
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论