Unverified 提交 fdcb92a9 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Update train.py `import val as validate` (#9037)

* Update train.py `import val as validate` Signed-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ciSigned-off-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: 's avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
上级 4a8ab3bc
...@@ -36,7 +36,7 @@ if str(ROOT) not in sys.path: ...@@ -36,7 +36,7 @@ if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
import val # for end-of-epoch mAP import val as validate # for end-of-epoch mAP
from models.experimental import attempt_load from models.experimental import attempt_load
from models.yolo import Model from models.yolo import Model
from utils.autoanchor import check_anchors from utils.autoanchor import check_anchors
...@@ -347,7 +347,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio ...@@ -347,7 +347,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights']) ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
final_epoch = (epoch + 1 == epochs) or stopper.possible_stop final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
if not noval or final_epoch: # Calculate mAP if not noval or final_epoch: # Calculate mAP
results, maps, _ = val.run(data_dict, results, maps, _ = validate.run(data_dict,
batch_size=batch_size // WORLD_SIZE * 2, batch_size=batch_size // WORLD_SIZE * 2,
imgsz=imgsz, imgsz=imgsz,
half=amp, half=amp,
...@@ -407,12 +407,12 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio ...@@ -407,12 +407,12 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
strip_optimizer(f) # strip optimizers strip_optimizer(f) # strip optimizers
if f is best: if f is best:
LOGGER.info(f'\nValidating {f}...') LOGGER.info(f'\nValidating {f}...')
results, _, _ = val.run( results, _, _ = validate.run(
data_dict, data_dict,
batch_size=batch_size // WORLD_SIZE * 2, batch_size=batch_size // WORLD_SIZE * 2,
imgsz=imgsz, imgsz=imgsz,
model=attempt_load(f, device).half(), model=attempt_load(f, device).half(),
iou_thres=0.65 if is_coco else 0.60, # best pycocotools results at 0.65 iou_thres=0.65 if is_coco else 0.60, # best pycocotools at iou 0.65
single_cls=single_cls, single_cls=single_cls,
dataloader=val_loader, dataloader=val_loader,
save_dir=save_dir, save_dir=save_dir,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论