Unverified 提交 525f4f86 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Add --optimize argument (#3093)

上级 57b0d3a6
...@@ -30,6 +30,7 @@ if __name__ == '__main__': ...@@ -30,6 +30,7 @@ if __name__ == '__main__':
parser.add_argument('--half', action='store_true', help='FP16 half-precision export') parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True') parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
parser.add_argument('--train', action='store_true', help='model.train() mode') parser.add_argument('--train', action='store_true', help='model.train() mode')
parser.add_argument('--optimize', action='store_true', help='optimize TorchScript for mobile') # TorchScript-only
parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') # ONNX-only parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') # ONNX-only
parser.add_argument('--simplify', action='store_true', help='simplify ONNX model') # ONNX-only parser.add_argument('--simplify', action='store_true', help='simplify ONNX model') # ONNX-only
opt = parser.parse_args() opt = parser.parse_args()
...@@ -78,7 +79,7 @@ if __name__ == '__main__': ...@@ -78,7 +79,7 @@ if __name__ == '__main__':
print(f'\n{prefix} starting export with torch {torch.__version__}...') print(f'\n{prefix} starting export with torch {torch.__version__}...')
f = opt.weights.replace('.pt', '.torchscript.pt') # filename f = opt.weights.replace('.pt', '.torchscript.pt') # filename
ts = torch.jit.trace(model, img, strict=False) ts = torch.jit.trace(model, img, strict=False)
optimize_for_mobile(ts).save(f) # https://pytorch.org/tutorials/recipes/script_optimized.html (optimize_for_mobile(ts) if opt.optimize else ts).save(f)
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e: except Exception as e:
print(f'{prefix} export failure: {e}') print(f'{prefix} export failure: {e}')
...@@ -123,7 +124,6 @@ if __name__ == '__main__': ...@@ -123,7 +124,6 @@ if __name__ == '__main__':
import coremltools as ct import coremltools as ct
print(f'{prefix} starting export with coremltools {ct.__version__}...') print(f'{prefix} starting export with coremltools {ct.__version__}...')
# convert model from torchscript and apply pixel scaling as per detect.py
model = ct.convert(ts, inputs=[ct.ImageType(name='image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])]) model = ct.convert(ts, inputs=[ct.ImageType(name='image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
f = opt.weights.replace('.pt', '.mlmodel') # filename f = opt.weights.replace('.pt', '.mlmodel') # filename
model.save(f) model.save(f)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论