"""Classify a single Iris flower sample with a TFLite model.

Loads ``Iris.tflite``, prints the interpreter's input/output tensor details,
feeds one feature vector through the model, and prints the raw model output
followed by the predicted class label.
"""

import numpy as np

MODEL_PATH = "Iris.tflite"

# Label for each output unit, by index. NOTE(review): this must match the
# order used when the model was trained — confirm against the training code.
CLASSES = ('Iris-setosa', 'Iris-versicolor', 'Iris-virginica')

# One sample (batch of 1). Presumably (sepal length, sepal width, petal length,
# petal width) — TODO confirm the feature order against the training data.
SAMPLE = ((5.1, 3.5, 1.4, 0.2),)


def prepare_input(sample, input_detail):
    """Convert *sample* into the array the interpreter expects for one input tensor.

    Args:
        sample: Array-like of feature values. Its total number of elements must
            equal the number of elements in the tensor's declared shape.
        input_detail: One entry of ``interpreter.get_input_details()``. Uses
            ``'shape'`` and ``'dtype'``, and ``'quantization'`` (scale,
            zero_point) when present.

    Returns:
        A numpy array with the tensor's shape and dtype.

    Raises:
        ValueError: If *sample* has the wrong number of elements.
    """
    dtype = np.dtype(input_detail['dtype'])
    shape = tuple(int(dim) for dim in input_detail['shape'])
    data = np.asarray(sample, dtype=np.float64)

    expected = int(np.prod(shape))
    if data.size != expected:
        raise ValueError(
            f"input has {data.size} values but the model expects shape {shape}"
        )

    # A quantized (integer) input tensor stores q = real / scale + zero_point.
    # A bare astype() would truncate the float features instead of quantizing.
    if np.issubdtype(dtype, np.integer):
        scale, zero_point = input_detail.get('quantization', (0.0, 0))
        if scale:
            data = np.round(data / scale + zero_point)
            limits = np.iinfo(dtype)
            data = np.clip(data, limits.min, limits.max)

    return data.reshape(shape).astype(dtype)


def decode_predictions(output_data, classes=CLASSES):
    """Map model scores to class labels, one label per sample.

    Args:
        output_data: Array-like whose last axis holds one score per class.
        classes: Labels indexed by output unit.

    Returns:
        A list with the label of the highest-scoring class for each sample.

    Raises:
        ValueError: If the last axis does not have ``len(classes)`` entries.
    """
    scores = np.asarray(output_data)
    if scores.ndim == 0 or scores.shape[-1] != len(classes):
        raise ValueError(
            f"expected last output dimension {len(classes)}, got shape {scores.shape}"
        )
    # argmax per sample (last axis). The original flattened argmax only
    # worked for a batch of one.
    return [classes[int(idx)] for idx in np.argmax(scores, axis=-1).reshape(-1)]


def main(model_path=MODEL_PATH, sample=SAMPLE):
    """Run the model at *model_path* on *sample* and print details and predictions."""
    # Imported here so prepare_input/decode_predictions are usable without
    # tflite_runtime installed.
    from tflite_runtime.interpreter import Interpreter

    # Load the TFLite model
    interpreter = Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    # Get and print input and output details
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    print("Input details:", input_details)
    print("Output details:", output_details)

    # Set input tensor and run inference
    input_data = prepare_input(sample, input_details[0])
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()

    # Get output tensor
    output_data = interpreter.get_tensor(output_details[0]['index'])
    print("Predictions:", output_data)

    for label in decode_predictions(output_data):
        print(label)


if __name__ == "__main__":
    main()