Unverified 提交 4db6757e authored 作者: imyhxy's avatar imyhxy 提交者: GitHub

Fixed access 'names' from a DistributedDataParallel module (#11023)

上级 7a972e86
......@@ -44,7 +44,7 @@ from utils.general import (DATASETS_DIR, LOGGER, TQDM_BAR_FORMAT, WorkingDirecto
check_requirements, colorstr, download, increment_path, init_seeds, print_args, yaml_save)
from utils.loggers import GenericLogger
from utils.plots import imshow_cls
from utils.torch_utils import (ModelEMA, model_info, reshape_classifier_output, select_device, smart_DDP,
from utils.torch_utils import (ModelEMA, de_parallel, model_info, reshape_classifier_output, select_device, smart_DDP,
smart_optimizer, smartCrossEntropyLoss, torch_distributed_zero_first)
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
......@@ -260,7 +260,7 @@ def train(opt, device):
# Plot examples
images, labels = (x[:25] for x in next(iter(testloader))) # first 25 images and labels
pred = torch.max(ema.ema(images.to(device)), 1)[1]
file = imshow_cls(images, labels, pred, model.names, verbose=False, f=save_dir / 'test_images.jpg')
file = imshow_cls(images, labels, pred, de_parallel(model).names, verbose=False, f=save_dir / 'test_images.jpg')
# Log results
meta = {'epochs': epochs, 'top1_acc': best_fitness, 'date': datetime.now().isoformat()}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论