TensorFlow Lite

TensorFlow Lite is a framework for deploying models on mobile and embedded devices.


Convert Model to TFLite

 1import tensorflow as tf
 2
 3# Load Keras model
 4model = tf.keras.models.load_model('my_model.h5')
 5
 6# Convert to TFLite
 7converter = tf.lite.TFLiteConverter.from_keras_model(model)
 8tflite_model = converter.convert()
 9
10# Save
11with open('model.tflite', 'wb') as f:
12    f.write(tflite_model)

Optimization

 1# Post-training quantization (Dynamic range)
 2converter = tf.lite.TFLiteConverter.from_keras_model(model)
 3converter.optimizations = [tf.lite.Optimize.DEFAULT]
 4tflite_quant_model = converter.convert()
 5
 6# Float16 quantization
 7converter.optimizations = [tf.lite.Optimize.DEFAULT]
 8converter.target_spec.supported_types = [tf.float16]
 9tflite_fp16_model = converter.convert()
10
11# Integer quantization
12def representative_dataset():
13    for data in dataset.take(100):
14        yield [tf.dtypes.cast(data, tf.float32)]
15
16converter.optimizations = [tf.lite.Optimize.DEFAULT]
17converter.representative_dataset = representative_dataset
18converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
19converter.inference_input_type = tf.int8
20converter.inference_output_type = tf.int8
21tflite_int8_model = converter.convert()

Run Inference (Python)

 1import numpy as np
 2
 3# Load model
 4interpreter = tf.lite.Interpreter(model_path='model.tflite')
 5interpreter.allocate_tensors()
 6
 7# Get input and output details
 8input_details = interpreter.get_input_details()
 9output_details = interpreter.get_output_details()
10
11# Prepare input
12input_shape = input_details[0]['shape']
13input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
14
15# Run inference
16interpreter.set_tensor(input_details[0]['index'], input_data)
17interpreter.invoke()
18
19# Get output
20output_data = interpreter.get_tensor(output_details[0]['index'])
21print(output_data)

Android Integration

 1// build.gradle
 2dependencies {
 3    implementation 'org.tensorflow:tensorflow-lite:2.13.0'
 4}
 5
 6// Kotlin code
 7import org.tensorflow.lite.Interpreter
 8import java.nio.MappedByteBuffer
 9
10class TFLiteModel(private val modelBuffer: MappedByteBuffer) {
11    private val interpreter = Interpreter(modelBuffer)
12    
13    fun predict(input: FloatArray): FloatArray {
14        val output = Array(1) { FloatArray(10) }
15        interpreter.run(input, output)
16        return output[0]
17    }
18    
19    fun close() {
20        interpreter.close()
21    }
22}

iOS Integration

 1import TensorFlowLite
 2
 3class TFLiteModel {
 4    private var interpreter: Interpreter
 5    
 6    init(modelPath: String) throws {
 7        interpreter = try Interpreter(modelPath: modelPath)
 8        try interpreter.allocateTensors()
 9    }
10    
11    func predict(input: [Float]) throws -> [Float] {
12        let inputData = Data(copyingBufferOf: input)
13        try interpreter.copy(inputData, toInputAt: 0)
14        try interpreter.invoke()
15        
16        let outputTensor = try interpreter.output(at: 0)
17        let results = [Float](unsafeData: outputTensor.data) ?? []
18        return results
19    }
20}

Benchmark

# The TFLite benchmark tool is a native `benchmark_model` binary, not a Python
# module in the tensorflow pip package. Build it from a TensorFlow source checkout.
# NOTE(review): add an Android config (e.g. --config=android_arm64) and run via adb
# to benchmark on-device rather than on the host.
bazel build -c opt //tensorflow/lite/tools/benchmark:benchmark_model

# Benchmark model with 4 CPU threads
bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model \
    --graph=model.tflite \
    --num_threads=4

Related Snippets