blob: e0b36d3d3ee94b00cccd3968d14c63fe19c3c27c [file] [log] [blame]
# ==============================================================================
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Upgrade script to move from pre-release schema to new schema.
Usage examples:
bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.json out.json
bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.bin out.bin
bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.bin out.json
bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.json out.bin
bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.tflite out.tflite
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import contextlib
import json
import os
import shutil
import subprocess
import sys
import tempfile
import tensorflow as tf
from tensorflow.python.platform import resource_loader
# Command-line interface: two positional arguments, the model to read and the
# path where the upgraded model is written.
parser = argparse.ArgumentParser(
    description=("Script to move TFLite models from pre-release schema to "
                 "new schema."))
for _arg_name, _arg_help in (
    ("input",
     "Input TensorFlow lite file in `.json`, `.bin` or `.tflite` format."),
    ("output",
     "Output json or bin TensorFlow lite model compliant with "
     "the new schema. Extension must be `.json`, `.bin` or `.tflite`."),
):
  parser.add_argument(_arg_name, type=str, help=_arg_help)
# Keep the module namespace free of the loop temporaries.
del _arg_name, _arg_help
# Scoped temporary directory, because flatc doesn't allow direct use of
# tempfiles.
@contextlib.contextmanager
def TemporaryDirectoryResource():
  """Provide a fresh temporary directory for the duration of a `with` block.

  The directory is created with `tempfile.mkdtemp()` and removed, together
  with everything inside it, by `shutil.rmtree()` when the block is left,
  whether it finishes normally or raises.

  Yields:
    Path of the newly created temporary directory.
  """
  scratch_dir = tempfile.mkdtemp()
  try:
    yield scratch_dir
  finally:
    shutil.rmtree(scratch_dir)
class Converter(object):
  """Converts TensorFlow flatbuffer models from old to new version of schema.

  This can convert between any version to the latest version. It uses
  an incremental upgrade strategy to go from version to version.

  Usage:
    converter = Converter()
    converter.Convert("a.tflite", "a.json")
    converter.Convert("b.json", "b.tflite")
  """

  def __init__(self):
    """Locates the flatc binary and the schema files, and builds the tables.

    Sets `self._flatc_path`, `self._schemas`, `self._new_version`,
    `self._new_schema` and `self._upgrade_dispatch`.
    """
    # TODO(aselle): make this work in the open source version with better
    # path.
    paths_to_try = [
        "../../../../flatbuffers/flatc",  # not bazel
        "../../../../external/flatbuffers/flatc"  # bazel
    ]
    # Use the first candidate that exists. If none does, the last candidate is
    # kept as-is, so JSON-only conversions (which never run flatc) still work.
    for p in paths_to_try:
      self._flatc_path = resource_loader.get_path_to_datafile(p)
      if os.path.exists(self._flatc_path):
        break

    def FindSchema(base_name):
      """Resolves a schema file name through the resource loader."""
      return resource_loader.get_path_to_datafile(base_name)

    # Supported schemas for upgrade. Each entry is
    # (version, schema path, read with --raw-binary, upgrade function).
    self._schemas = [
        (0, FindSchema("schema_v0.fbs"), True, self._Upgrade0To1),
        (1, FindSchema("schema_v1.fbs"), True, self._Upgrade1To2),
        (2, FindSchema("schema_v2.fbs"), True, self._Upgrade2To3),
        (3, FindSchema("schema_v3.fbs"), False, None)  # Non-callable by design.
    ]
    # Ensure schemas are sorted, and extract latest version and upgrade
    # dispatch function table. Sorting on the version alone keeps the
    # comparison away from the other tuple members (paths, bound methods).
    self._schemas.sort(key=lambda entry: entry[0])
    self._new_version, self._new_schema = self._schemas[-1][:2]
    self._upgrade_dispatch = {
        version: dispatch for version, _, _, dispatch in self._schemas
    }

  def _Read(self, input_file, schema, raw_binary=False):
    """Read a tflite model assuming the given flatbuffer schema.

    If `input_file` is in bin, then we must use flatc to convert the schema
    from binary to json.

    Args:
      input_file: a binary (flatbuffer) or json file to read from. Extension
        must be `.tflite`, `.bin`, or `.json` for FlatBuffer Binary or
        FlatBuffer JSON.
      schema: which schema to use for reading
      raw_binary: whether to assume raw_binary (versions previous to v3)
        that lacked file_identifier require this.

    Raises:
      RuntimeError: When flatc exits with a non-zero status or does not
        produce the expected json file.
      ValueError: When the extension is not json, bin or tflite.

    Returns:
      A dictionary representing the read tflite model.
    """
    raw_binary = ["--raw-binary"] if raw_binary else []
    basename = os.path.basename(input_file)
    basename_no_extension, extension = os.path.splitext(basename)

    if extension == ".json":
      # JSON needs no conversion; close the handle as soon as it is parsed.
      with open(input_file) as fp:
        return json.load(fp)
    if extension not in (".bin", ".tflite"):
      raise ValueError("Invalid extension on input file %r" % input_file)

    # A temporary directory is only needed when flatc has to run.
    with TemporaryDirectoryResource() as tempdir:
      # Convert to json using flatc
      returncode = subprocess.call([
          self._flatc_path,
          "-t",
          "--strict-json",
          "--defaults-json",
      ] + raw_binary + ["-o", tempdir, schema, "--", input_file])
      if returncode != 0:
        raise RuntimeError("flatc failed to convert from binary to json.")
      json_file = os.path.join(tempdir, basename_no_extension + ".json")
      if not os.path.exists(json_file):
        raise RuntimeError("Could not find %r" % json_file)
      # Parse before leaving the `with`, since the directory is deleted then.
      with open(json_file) as fp:
        return json.load(fp)

  def _Write(self, data, output_file):
    """Output a json or bin version of the flatbuffer model.

    Args:
      data: Dict representing the TensorFlow Lite model to write.
      output_file: filename to write the converted flatbuffer to. (json,
        tflite, or bin extension is required).

    Raises:
      ValueError: When the extension is not json, tflite or bin
      RuntimeError: When flatc fails to convert json data to binary.
    """
    _, extension = os.path.splitext(output_file)

    if extension == ".json":
      # The `with` guarantees the file is flushed and closed before returning.
      with open(output_file, "w") as fp:
        json.dump(data, fp, sort_keys=True, indent=2)
      return
    if extension not in (".tflite", ".bin"):
      raise ValueError("Invalid extension on output file %r" % output_file)

    # Binary output: dump json into a temporary directory, let flatc compile
    # it against the newest schema, then copy the result to its destination.
    with TemporaryDirectoryResource() as tempdir:
      input_json = os.path.join(tempdir, "temp.json")
      with open(input_json, "w") as fp:
        json.dump(data, fp, sort_keys=True, indent=2)
      returncode = subprocess.call([
          self._flatc_path, "-b", "--defaults-json", "--strict-json", "-o",
          tempdir, self._new_schema, input_json
      ])
      if returncode != 0:
        raise RuntimeError("flatc failed to convert upgraded json to binary.")
      shutil.copy(os.path.join(tempdir, "temp.tflite"), output_file)

  def _Upgrade0To1(self, data):
    """Upgrade data from Version 0 to Version 1.

    Changes: Added subgraphs (which contains a subset of formally global
    entries).

    Args:
      data: Dictionary representing the TensorFlow lite data to be upgraded.
        This will be modified in-place to be an upgraded version.
    """
    subgraph = {}
    # Move the formerly top-level entries into a single subgraph.
    for key_to_promote in ["tensors", "operators", "inputs", "outputs"]:
      subgraph[key_to_promote] = data.pop(key_to_promote)
    data["subgraphs"] = [subgraph]

  def _Upgrade1To2(self, data):
    """Upgrade data from Version 1 to Version 2.

    Changes: Rename operators to Conform to NN API.

    Args:
      data: Dictionary representing the TensorFlow lite data to be upgraded.
        This will be modified in-place to be an upgraded version.

    Raises:
      ValueError: Throws when model builtins are numeric rather than symbols.
    """

    def RemapOperator(opcode_name):
      """Go from old schema op name to new schema op name.

      Args:
        opcode_name: String representing the ops (see :schema.fbs).

      Returns:
        Converted opcode_name from V1 to V2.
      """
      old_name_to_new_name = {
          "CONVOLUTION": "CONV_2D",
          "DEPTHWISE_CONVOLUTION": "DEPTHWISE_CONV_2D",
          "AVERAGE_POOL": "AVERAGE_POOL_2D",
          "MAX_POOL": "MAX_POOL_2D",
          "L2_POOL": "L2_POOL_2D",
          "SIGMOID": "LOGISTIC",
          "L2NORM": "L2_NORMALIZATION",
          "LOCAL_RESPONSE_NORM": "LOCAL_RESPONSE_NORMALIZATION",
          "Basic_RNN": "RNN",
      }
      # Names without an entry are unchanged between V1 and V2.
      return old_name_to_new_name.get(opcode_name, opcode_name)

    def RemapOperatorType(operator_type):
      """Remap operator structs from old names to new names.

      Args:
        operator_type: String representing the builtin operator data type
          string. (see :schema.fbs).

      Returns:
        Upgraded builtin operator data type as a string.
      """
      old_to_new = {
          "PoolOptions": "Pool2DOptions",
          "DepthwiseConvolutionOptions": "DepthwiseConv2DOptions",
          "ConvolutionOptions": "Conv2DOptions",
          "LocalResponseNormOptions": "LocalResponseNormalizationOptions",
          "BasicRNNOptions": "RNNOptions",
      }
      return old_to_new.get(operator_type, operator_type)

    for subgraph in data["subgraphs"]:
      for ops in subgraph["operators"]:
        ops["builtin_options_type"] = RemapOperatorType(
            ops["builtin_options_type"])

    # Upgrade the operator codes
    for operator_code in data["operator_codes"]:
      # Check if builtin_code is the appropriate string type
      # use type("") instead of str or unicode. for py2and3
      if not isinstance(operator_code["builtin_code"], type(u"")):
        raise ValueError("builtin_code %r is non-string. this usually means "
                         "your model has consistency problems." %
                         (operator_code["builtin_code"]))
      operator_code["builtin_code"] = (RemapOperator(
          operator_code["builtin_code"]))

  def _Upgrade2To3(self, data):
    """Upgrade data from Version 2 to Version 3.

    Changed actual read-only tensor data to be in a buffers table instead
    of inline with the tensor.

    Args:
      data: Dictionary representing the TensorFlow lite data to be upgraded.
        This will be modified in-place to be an upgraded version.
    """
    buffers = [{"data": []}]  # Start with 1 empty buffer
    for subgraph in data["subgraphs"]:
      if "tensors" not in subgraph:
        continue
      for tensor in subgraph["tensors"]:
        # Take the inline payload (if any) out of the tensor. Tensors with a
        # missing or empty payload all share the empty buffer at index 0.
        inline_data = tensor.pop("data_buffer", None)
        if inline_data:
          tensor[u"buffer"] = len(buffers)
          buffers.append({"data": inline_data})
        else:
          tensor["buffer"] = 0
    data["buffers"] = buffers

  def _PerformUpgrade(self, data):
    """Manipulate the `data` (parsed JSON) based on changes in format.

    This incrementally will upgrade from version to version within data.

    Args:
      data: Dictionary representing the TensorFlow data. This will be upgraded
        in place.
    """
    # Apply one upgrade step at a time until the newest version is reached.
    while data["version"] < self._new_version:
      self._upgrade_dispatch[data["version"]](data)
      data["version"] += 1

  def Convert(self, input_file, output_file):
    """Perform schema conversion from input_file to output_file.

    Args:
      input_file: Filename of TensorFlow Lite data to convert from. Must
        be `.json`, `.bin` or `.tflite` extension files for JSON or Binary
        forms of the TensorFlow FlatBuffer schema.
      output_file: Filename to write to. Extension also must be `.json`,
        `.bin` or `.tflite`.

    Raises:
      RuntimeError: Generated when none of the upgrader supported schemas
        match the `input_file` data.
    """
    # Read data in each schema (since they are incompatible). Version is
    # always present. Use the read data that matches the version of the
    # schema.
    for version, schema, raw_binary, _ in self._schemas:
      try:
        data_candidate = self._Read(input_file, schema, raw_binary)
      except RuntimeError:
        continue  # Skip and hope another schema works
      if "version" not in data_candidate:  # Assume version 1 if not present.
        data_candidate["version"] = 1
      elif data_candidate["version"] == 0:  # Version 0 doesn't exist in wild.
        data_candidate["version"] = 1

      if data_candidate["version"] == version:
        self._PerformUpgrade(data_candidate)
        self._Write(data_candidate, output_file)
        return
    raise RuntimeError("No schema that the converter understands worked with "
                       "the data file you provided.")
def main(argv):
  """Upgrade the model named by the parsed command-line flags.

  Builds a `Converter` and converts `FLAGS.input` into `FLAGS.output`. `FLAGS`
  is the module-level namespace assigned in the `__main__` guard.

  Args:
    argv: Argument list handed over by the caller; it is discarded unused.
  """
  del argv  # Everything needed is read from FLAGS.
  upgrader = Converter()
  upgrader.Convert(FLAGS.input, FLAGS.output)
if __name__ == "__main__":
  # Arguments this script declares land in FLAGS; anything the parser does not
  # recognise is passed on to tf.app.run after the program name.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(argv=[sys.argv[0]] + unparsed, main=main)