Unverified commit 54f49fa5, authored by paradigm, committed by GitHub

Add TFLite Metadata to TFLite and Edge TPU models (#9903)

* added embedded metadata to tflite models * added try block for inference * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactored tflite metadata into separate function * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Create tmp file in /tmp * Update export.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update export.py * Update export.py * Update export.py * Update export.py * Update common.py * Update export.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update common.py Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Parent commit: fba61e55
...@@ -45,6 +45,7 @@ TensorFlow.js: ...@@ -45,6 +45,7 @@ TensorFlow.js:
""" """
import argparse import argparse
import contextlib
import json import json
import os import os
import platform import platform
...@@ -453,6 +454,39 @@ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')): ...@@ -453,6 +454,39 @@ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
return f, None return f, None
def add_tflite_metadata(file, metadata, num_outputs):
    """Embed metadata into a *.tflite model file in place.

    Writes ``str(metadata)`` (e.g. a dict with 'stride' and 'names') into the
    model as an associated metadata file, per
    https://www.tensorflow.org/lite/models/convert/metadata

    Args:
        file: Path to the .tflite model to modify in place.
        metadata: Object whose str() representation is stored as the metadata payload.
        num_outputs: Number of output tensors to declare in the subgraph metadata.

    Silently does nothing if the optional 'tflite_support' package is not installed.
    """
    with contextlib.suppress(ImportError):
        # check_requirements('tflite_support')
        from tflite_support import flatbuffers
        from tflite_support import metadata as _metadata
        from tflite_support import metadata_schema_py_generated as _metadata_fb

        # Portable temp location — the original hard-coded '/tmp/meta.txt',
        # which does not exist on Windows.
        import tempfile
        tmp_file = Path(tempfile.gettempdir()) / 'meta.txt'
        tmp_file.write_text(str(metadata))
        try:
            # Build the ModelMetadata flatbuffer referencing the payload file
            model_meta = _metadata_fb.ModelMetadataT()
            label_file = _metadata_fb.AssociatedFileT()
            label_file.name = tmp_file.name
            model_meta.associatedFiles = [label_file]

            # Minimal subgraph description: one input, num_outputs outputs
            subgraph = _metadata_fb.SubGraphMetadataT()
            subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
            subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputs
            model_meta.subgraphMetadata = [subgraph]

            b = flatbuffers.Builder(0)
            b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
            metadata_buf = b.Output()

            # Write the metadata buffer and associated file into the model
            populator = _metadata.MetadataPopulator.with_model_file(file)
            populator.load_metadata_buffer(metadata_buf)
            populator.load_associated_files([str(tmp_file)])
            populator.populate()
        finally:
            # Always remove the temp file, even if population fails (original leaked it on error)
            tmp_file.unlink()
@smart_inference_mode() @smart_inference_mode()
def run( def run(
data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
...@@ -550,8 +584,9 @@ def run( ...@@ -550,8 +584,9 @@ def run(
f[6], _ = export_pb(s_model, file) f[6], _ = export_pb(s_model, file)
if tflite or edgetpu: if tflite or edgetpu:
f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms) f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
if edgetpu: if edgetpu:
f[8], _ = export_edgetpu(file) f[8], _ = export_edgetpu(file)
add_tflite_metadata(f[8] or f[7], metadata, num_outputs=len(s_model.outputs))
if tfjs: if tfjs:
f[9], _ = export_tfjs(file) f[9], _ = export_tfjs(file)
if paddle: # PaddlePaddle if paddle: # PaddlePaddle
......
...@@ -3,10 +3,13 @@ ...@@ -3,10 +3,13 @@
Common modules Common modules
""" """
import ast
import contextlib
import json import json
import math import math
import platform import platform
import warnings import warnings
import zipfile
from collections import OrderedDict, namedtuple from collections import OrderedDict, namedtuple
from copy import copy from copy import copy
from pathlib import Path from pathlib import Path
...@@ -462,6 +465,12 @@ class DetectMultiBackend(nn.Module): ...@@ -462,6 +465,12 @@ class DetectMultiBackend(nn.Module):
interpreter.allocate_tensors() # allocate interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs output_details = interpreter.get_output_details() # outputs
# load metadata
with contextlib.suppress(zipfile.BadZipFile):
with zipfile.ZipFile(w, "r") as model:
meta_file = model.namelist()[0]
meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
stride, names = int(meta['stride']), meta['names']
elif tfjs: # TF.js elif tfjs: # TF.js
raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported') raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported')
elif paddle: # PaddlePaddle elif paddle: # PaddlePaddle
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论