Unverified 提交 43b2817f authored 作者: Kalen Michael's avatar Kalen Michael 提交者: GitHub

Feature/fix export on url (#4823)

* added callbacks * added back callback to main * added save_dir to callback output * merged in upstream * removed ghost code * added url check * Add url2file() * Update file-only Co-authored-by: 's avatarGlenn Jocher <glenn.jocher@ultralytics.com>
上级 0dc725e3
...@@ -41,7 +41,7 @@ from models.experimental import attempt_load ...@@ -41,7 +41,7 @@ from models.experimental import attempt_load
from models.yolo import Detect from models.yolo import Detect
from utils.activations import SiLU from utils.activations import SiLU
from utils.datasets import LoadImages from utils.datasets import LoadImages
from utils.general import colorstr, check_dataset, check_img_size, check_requirements, file_size, set_logging from utils.general import colorstr, check_dataset, check_img_size, check_requirements, file_size, set_logging, url2file
from utils.torch_utils import select_device from utils.torch_utils import select_device
...@@ -244,7 +244,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' ...@@ -244,7 +244,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
include = [x.lower() for x in include] include = [x.lower() for x in include]
tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'tfjs')) # TensorFlow exports tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'tfjs')) # TensorFlow exports
imgsz *= 2 if len(imgsz) == 1 else 1 # expand imgsz *= 2 if len(imgsz) == 1 else 1 # expand
file = Path(weights) file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)
# Load PyTorch model # Load PyTorch model
device = select_device(device) device = select_device(device)
......
...@@ -360,6 +360,13 @@ def check_dataset(data, autodownload=True): ...@@ -360,6 +360,13 @@ def check_dataset(data, autodownload=True):
return data # dictionary return data # dictionary
def url2file(url):
# Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
file = Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
return file
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1): def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
# Multi-threaded file download and unzip function, used in data.yaml for autodownload # Multi-threaded file download and unzip function, used in data.yaml for autodownload
def download_one(url, dir): def download_one(url, dir):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论