Unverified 提交 e77c77f5 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Add check_requirements() (#1853)

* Add check_requirements() * add import * parameterize filename * add to detect, test
上级 135ec5c5
...@@ -9,8 +9,8 @@ from numpy import random ...@@ -9,8 +9,8 @@ from numpy import random
from models.experimental import attempt_load from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, \ from utils.general import check_img_size, check_requirements, non_max_suppression, apply_classifier, scale_coords, \
strip_optimizer, set_logging, increment_path xyxy2xywh, strip_optimizer, set_logging, increment_path
from utils.plots import plot_one_box from utils.plots import plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized from utils.torch_utils import select_device, load_classifier, time_synchronized
...@@ -162,6 +162,7 @@ if __name__ == '__main__': ...@@ -162,6 +162,7 @@ if __name__ == '__main__':
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
opt = parser.parse_args() opt = parser.parse_args()
print(opt) print(opt)
check_requirements()
with torch.no_grad(): with torch.no_grad():
if opt.update: # update all models (to fix SourceChangeWarning) if opt.update: # update all models (to fix SourceChangeWarning)
......
...@@ -11,8 +11,8 @@ from tqdm import tqdm ...@@ -11,8 +11,8 @@ from tqdm import tqdm
from models.experimental import attempt_load from models.experimental import attempt_load
from utils.datasets import create_dataloader from utils.datasets import create_dataloader
from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, box_iou, \ from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, check_requirements, \
non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path
from utils.loss import compute_loss from utils.loss import compute_loss
from utils.metrics import ap_per_class, ConfusionMatrix from utils.metrics import ap_per_class, ConfusionMatrix
from utils.plots import plot_images, output_to_target, plot_study_txt from utils.plots import plot_images, output_to_target, plot_study_txt
...@@ -302,6 +302,7 @@ if __name__ == '__main__': ...@@ -302,6 +302,7 @@ if __name__ == '__main__':
opt.save_json |= opt.data.endswith('coco.yaml') opt.save_json |= opt.data.endswith('coco.yaml')
opt.data = check_file(opt.data) # check file opt.data = check_file(opt.data) # check file
print(opt) print(opt)
check_requirements()
if opt.task in ['val', 'test']: # run normally if opt.task in ['val', 'test']: # run normally
test(opt.data, test(opt.data,
......
...@@ -28,7 +28,7 @@ from utils.autoanchor import check_anchors ...@@ -28,7 +28,7 @@ from utils.autoanchor import check_anchors
from utils.datasets import create_dataloader from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \ from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \ fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
print_mutation, set_logging, one_cycle check_requirements, print_mutation, set_logging, one_cycle
from utils.google_utils import attempt_download from utils.google_utils import attempt_download
from utils.loss import compute_loss from utils.loss import compute_loss
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
...@@ -472,6 +472,7 @@ if __name__ == '__main__': ...@@ -472,6 +472,7 @@ if __name__ == '__main__':
set_logging(opt.global_rank) set_logging(opt.global_rank)
if opt.global_rank in [-1, 0]: if opt.global_rank in [-1, 0]:
check_git_status() check_git_status()
check_requirements()
# Resume # Resume
if opt.resume: # resume an interrupted run if opt.resume: # resume an interrupted run
......
...@@ -53,6 +53,14 @@ def check_git_status(): ...@@ -53,6 +53,14 @@ def check_git_status():
print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n') print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')
def check_requirements(file='requirements.txt'):
# Check installed dependencies meet requirements
import pkg_resources
requirements = pkg_resources.parse_requirements(Path(file).open())
requirements = [x.name + ''.join(*x.specs) if len(x.specs) else x.name for x in requirements]
pkg_resources.require(requirements) # DistributionNotFound or VersionConflict exception if requirements not met
def check_img_size(img_size, s=32): def check_img_size(img_size, s=32):
# Verify img_size is a multiple of stride s # Verify img_size is a multiple of stride s
new_size = make_divisible(img_size, int(s)) # ceil gs-multiple new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论