Unverified commit 1c8464e1, authored by Khiem Doan, committed by GitHub

Use pathlib instead of low-level module (#1329)

* Use pathlib instead of low-level module
* Use pathlib instead of low-level module
* Update detect.py
* Update test.py
* reformat

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Parent 19e24824
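The commit replaces `os`-level calls with their `pathlib.Path` equivalents across detect.py, hubconf.py, test.py and train.py. A minimal sketch of the mappings applied in the hunks below (the paths are made-up examples, not the repository's actual directories):

from pathlib import Path

p = Path('runs/example') / 'labels'     # replaces os.path.join('runs/example', 'labels')
p.mkdir(parents=True, exist_ok=True)    # replaces os.makedirs(p, exist_ok=True)
print(p.exists())                       # replaces os.path.exists(p)
print(p.name, p.stem)                   # replaces os.path.basename(p) (see the note on .stem after the train.py diff)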
detect.py
 import argparse
-import os
 import time
 from pathlib import Path
@@ -18,13 +17,14 @@ from utils.torch_utils import select_device, load_classifier, time_synchronized
 def detect(save_img=False):
     save_dir, source, weights, view_img, save_txt, imgsz = \
         Path(opt.save_dir), opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
-    webcam = source.isnumeric() or source.startswith(('rtsp://', 'rtmp://', 'http://')) or source.endswith('.txt')
+    webcam = source.isnumeric() or source.endswith('.txt') or \
+        source.lower().startswith(('rtsp://', 'rtmp://', 'http://'))

     # Directories
     if save_dir == Path('runs/detect'):  # if default
-        os.makedirs('runs/detect', exist_ok=True)  # make base
+        save_dir.mkdir(parents=True, exist_ok=True)  # make base
         save_dir = Path(increment_dir(save_dir / 'exp', opt.name))  # increment run
-    os.makedirs(save_dir / 'labels' if save_txt else save_dir, exist_ok=True)  # make new dir
+    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make new dir

     # Initialize
     set_logging()
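One note on the new `mkdir` line in detect.py above: the parentheses around the conditional expression matter, because a bare ternary binds more loosely than the `.mkdir()` call. A small sketch with hypothetical values:

from pathlib import Path

save_dir, save_txt = Path('runs/detect/exp0'), True   # hypothetical values for illustration
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)
# Without the outer parentheses this would parse as
#   (save_dir / 'labels') if save_txt else (save_dir.mkdir(...))
# so no directory would be created when save_txt is True.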
hubconf.py
@@ -6,7 +6,7 @@ Usage:
 """

 dependencies = ['torch', 'yaml']
-import os
+from pathlib import Path

 import torch
@@ -29,7 +29,7 @@ def create(name, pretrained, channels, classes):
     Returns:
         pytorch model
     """
-    config = os.path.join(os.path.dirname(__file__), 'models', f'{name}.yaml')  # model.yaml path
+    config = Path(__file__).parent / 'models' / f'{name}.yaml'  # model.yaml path
     try:
         model = Model(config, channels, classes)
         if pretrained:
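In hubconf.py the config path is now built with the `/` operator on `Path(__file__).parent`. A quick sketch of the equivalence, assuming the repo layout with a models/ folder next to hubconf.py and using 'yolov5s' as an example model name:

from pathlib import Path

config = Path(__file__).parent / 'models' / 'yolov5s.yaml'
# Points at the same file as os.path.join(os.path.dirname(__file__), 'models', 'yolov5s.yaml');
# config is now a Path rather than a str, which open() and other path-consuming
# functions accept via os.PathLike.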
test.py
@@ -47,9 +47,9 @@ def test(data,

     # Directories
     if save_dir == Path('runs/test'):  # if default
-        os.makedirs('runs/test', exist_ok=True)  # make base
+        save_dir.mkdir(parents=True, exist_ok=True)  # make base
         save_dir = Path(increment_dir(save_dir / 'exp', opt.name))  # increment run
-    os.makedirs(save_dir / 'labels' if save_txt else save_dir, exist_ok=True)  # make new dir
+    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make new dir

     # Load model
     model = attempt_load(weights, map_location=device)  # load FP32 model
train.py
@@ -38,10 +38,10 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     logger.info(f'Hyperparameters {hyp}')
     log_dir = Path(tb_writer.log_dir) if tb_writer else Path(opt.logdir) / 'evolve'  # logging directory
     wdir = log_dir / 'weights'  # weights directory
-    os.makedirs(wdir, exist_ok=True)
+    wdir.mkdir(parents=True, exist_ok=True)
     last = wdir / 'last.pt'
     best = wdir / 'best.pt'
-    results_file = str(log_dir / 'results.txt')
+    results_file = log_dir / 'results.txt'
     epochs, batch_size, total_batch_size, weights, rank = \
         opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
@@ -121,7 +121,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     # Logging
     if wandb and wandb.run is None:
         id = ckpt.get('wandb_id') if 'ckpt' in locals() else None
-        wandb_run = wandb.init(config=opt, resume="allow", project="YOLOv5", name=os.path.basename(log_dir), id=id)
+        wandb_run = wandb.init(config=opt, resume="allow", project="YOLOv5", name=log_dir.stem, id=id)

     # Resume
     start_epoch, best_fitness = 0, 0.0
@@ -371,7 +371,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         n = opt.name if opt.name.isnumeric() else ''
         fresults, flast, fbest = log_dir / f'results{n}.txt', wdir / f'last{n}.pt', wdir / f'best{n}.pt'
         for f1, f2 in zip([wdir / 'last.pt', wdir / 'best.pt', results_file], [flast, fbest, fresults]):
-            if os.path.exists(f1):
+            if f1.exists():
                 os.rename(f1, f2)  # rename
                 if str(f2).endswith('.pt'):  # is *.pt
                     strip_optimizer(f2)  # strip optimizer
@@ -520,7 +520,7 @@ if __name__ == '__main__':
             os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket)  # download evolve.txt if exists

         for _ in range(300):  # generations to evolve
-            if os.path.exists('evolve.txt'):  # if evolve.txt exists: select best hyps and mutate
+            if Path('evolve.txt').exists():  # if evolve.txt exists: select best hyps and mutate
                 # Select parent(s)
                 parent = 'single'  # parent selection method: 'single' or 'weighted'
                 x = np.loadtxt('evolve.txt', ndmin=2)
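A behavioural footnote on the train.py hunk at line 121: `Path.stem` strips the final suffix, so it matches `os.path.basename` only when the directory name contains no dot; `Path.name` would be the exact equivalent. An illustration with made-up run names:

import os
from pathlib import Path

print(os.path.basename('runs/exp0'), Path('runs/exp0').stem)        # 'exp0' 'exp0'      -- identical
print(os.path.basename('runs/exp0.v2'), Path('runs/exp0.v2').stem)  # 'exp0.v2' 'exp0'   -- .stem drops '.v2'

The untouched `os.rename(f1, f2)` and `os.system(...)` calls keep working with `Path` arguments, since `os` functions accept any `os.PathLike` object.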