Unverified 提交 94557969 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Remove `formats` variable to avoid `pd` conflict (#7993)

* Remove `formats` variable to avoid `pd` conflict * Update export.py
上级 1dcb7749
...@@ -475,9 +475,9 @@ def run( ...@@ -475,9 +475,9 @@ def run(
): ):
t = time.time() t = time.time()
include = [x.lower() for x in include] # to lowercase include = [x.lower() for x in include] # to lowercase
formats = tuple(export_formats()['Argument'][1:]) # --include arguments fmts = tuple(export_formats()['Argument'][1:]) # --include arguments
flags = [x in include for x in formats] flags = [x in include for x in fmts]
assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {formats}' assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags # export booleans jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags # export booleans
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights
...@@ -499,7 +499,7 @@ def run( ...@@ -499,7 +499,7 @@ def run(
im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
# Update model # Update model
if half and not (coreml or xml): if half and not coreml and not xml:
im, model = im.half(), model.half() # to FP16 im, model = im.half(), model.half() # to FP16
model.train() if train else model.eval() # training mode = no Detect() layer grid construction model.train() if train else model.eval() # training mode = no Detect() layer grid construction
for k, m in model.named_modules(): for k, m in model.named_modules():
...@@ -531,7 +531,7 @@ def run( ...@@ -531,7 +531,7 @@ def run(
if any((saved_model, pb, tflite, edgetpu, tfjs)): if any((saved_model, pb, tflite, edgetpu, tfjs)):
if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707 if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow` check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow`
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.' assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
model, f[5] = export_saved_model(model.cpu(), model, f[5] = export_saved_model(model.cpu(),
im, im,
file, file,
......
...@@ -56,9 +56,8 @@ def run( ...@@ -56,9 +56,8 @@ def run(
pt_only=False, # test PyTorch only pt_only=False, # test PyTorch only
): ):
y, t = [], time.time() y, t = [], time.time()
formats = export.export_formats()
device = select_device(device) device = select_device(device)
for i, (name, f, suffix, gpu) in formats.iterrows(): # index, (name, file, suffix, gpu-capable) for i, (name, f, suffix, gpu) in export.export_formats().iterrows(): # index, (name, file, suffix, gpu-capable)
try: try:
assert i != 9, 'Edge TPU not supported' assert i != 9, 'Edge TPU not supported'
assert i != 10, 'TF.js not supported' assert i != 10, 'TF.js not supported'
...@@ -104,9 +103,8 @@ def test( ...@@ -104,9 +103,8 @@ def test(
pt_only=False, # test PyTorch only pt_only=False, # test PyTorch only
): ):
y, t = [], time.time() y, t = [], time.time()
formats = export.export_formats()
device = select_device(device) device = select_device(device)
for i, (name, f, suffix, gpu) in formats.iterrows(): # index, (name, file, suffix, gpu-capable) for i, (name, f, suffix, gpu) in export.export_formats().iterrows(): # index, (name, file, suffix, gpu-capable)
try: try:
w = weights if f == '-' else \ w = weights if f == '-' else \
export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # weights export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # weights
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论