
Export YOLOv7 to TensorFlow Lite

The conversion from PyTorch to TensorFlow Lite goes through two intermediate formats:

PyTorch model -> ONNX model -> TensorFlow model -> TensorFlow Lite model
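Each arrow is a separate tool, and each stage reads the file the previous one wrote, so the file names have to line up. The sketch below lists the artifact each stage is expected to produce. It assumes, as export.py in the YOLOv7 repository does, that the ONNX file takes the weights file name with .pt swapped for .onnx; the tfmodel folder and the .tflite name are the ones used later in this post, and `expected_artifacts` is a hypothetical helper for illustration only.

```python
from pathlib import PurePosixPath


def expected_artifacts(weights: str, tf_dir: str = "tfmodel",
                       tflite_name: str = "yolov7_model.tflite") -> list[str]:
    """Return the file or folder each conversion stage is expected to produce.

    Hypothetical helper: assumes the ONNX name is the weights name with the
    .pt suffix replaced by .onnx.
    """
    if not weights.endswith(".pt"):
        raise ValueError(f"expected a .pt weights file, got {weights!r}")
    onnx_file = weights[: -len(".pt")] + ".onnx"
    tflite_file = str(PurePosixPath(tf_dir) / tflite_name)
    return [weights, onnx_file, tf_dir + "/", tflite_file]


stages = ["PyTorch", "ONNX", "TensorFlow", "TensorFlow Lite"]
for stage, artifact in zip(stages, expected_artifacts("yolov7.pt")):
    print(f"{stage:16} {artifact}")
```

Passing yolov7-tiny.pt instead would yield yolov7-tiny.onnx, so the weights given to export.py and the file given to onnx-tf must refer to the same model.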

Start by creating a new virtual environment:

python3 -m venv .venv
source .venv/bin/activate
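To confirm the virtual environment is actually active before installing anything, compare `sys.prefix` with `sys.base_prefix`: inside an environment created with `python3 -m venv` they differ. A minimal check:

```python
import sys


def in_virtualenv() -> bool:
    """True when running inside a venv (sys.prefix then points at the environment)."""
    return sys.prefix != sys.base_prefix


print("virtual environment active:", in_virtualenv())
print("interpreter prefix:", sys.prefix)
```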

Install basic requirements:

pip install onnx onnxruntime onnxsim onnx-tf

Clone the YOLOv7 repository and download the official YOLOv7 PyTorch weights:

git clone https://github.com/WongKinYiu/yolov7.git
cd yolov7
wget https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7.pt

From inside the yolov7 folder, install the repository's requirements (this pulls in PyTorch and the remaining libraries):

pip install -r requirements.txt

PyTorch model to ONNX

Run the export script from the yolov7 folder:

python export.py --weights yolov7.pt --grid --end2end --simplify \
        --topk-all 100 --iou-thres 0.65 --conf-thres 0.35 --img-size 640 640 --max-wh 640
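The --conf-thres, --iou-thres and --topk-all flags configure the non-maximum suppression (NMS) step that --end2end bakes into the exported graph. As far as I can tell from the export code, passing --max-wh selects the ONNX Runtime NMS path, which shifts every box by cls_id * max_wh so that boxes of different classes never overlap and are therefore suppressed independently. The pure-Python sketch below illustrates that behaviour with plain greedy NMS; it is not the exported graph itself, and `iou` and `class_aware_nms` are hypothetical helpers.

```python
def iou(a, b):
    """Intersection over union of two (x0, y0, x1, y1) boxes."""
    ix0, iy0 = max(a[0], b[0]), max(a[1], b[1])
    ix1, iy1 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix1 - ix0) * max(0.0, iy1 - iy0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def class_aware_nms(detections, conf_thres=0.35, iou_thres=0.65, topk=100, max_wh=640):
    """Greedy NMS over (box, cls_id, score) tuples; returns indices of kept detections.

    Boxes are shifted by cls_id * max_wh, so boxes of different classes cannot
    overlap and a single NMS pass behaves like per-class NMS.
    """
    # Drop low-confidence detections, then visit the rest from highest score down.
    candidates = [i for i, (_, _, score) in enumerate(detections) if score >= conf_thres]
    candidates.sort(key=lambda i: detections[i][2], reverse=True)
    shifted = {}
    for i in candidates:
        box, cls_id, _ = detections[i]
        off = cls_id * max_wh
        shifted[i] = (box[0] + off, box[1] + off, box[2] + off, box[3] + off)
    keep = []
    for i in candidates:
        if len(keep) == topk:
            break
        # Keep a box only if it does not overlap any kept box by more than iou_thres.
        if all(iou(shifted[i], shifted[j]) <= iou_thres for j in keep):
            keep.append(i)
    return keep


dets = [
    ((100, 100, 300, 300), 0, 0.90),   # person
    ((110, 105, 305, 310), 0, 0.80),   # same person, IoU ~0.86 -> suppressed
    ((110, 105, 305, 310), 16, 0.70),  # dog at the same spot -> kept (different class)
    ((400, 400, 500, 500), 2, 0.20),   # car below conf_thres -> dropped
]
print(class_aware_nms(dets))  # [0, 2]
```

Without the class offset, the dog box would have been suppressed by the person box, which is why the trick matters for a single NMS pass over all classes.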

In my run, the export reported that onnx_graphsurgeon was missing, but that did not stop the ONNX export.
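To see up front which of the packages used in this post are importable (onnx_graphsurgeon included), `importlib.util.find_spec` can be queried without actually importing anything. The module list is an assumption built from the import names of the packages mentioned here (onnx-tf, for example, imports as onnx_tf).

```python
import importlib.util


def check_modules(names):
    """Map each top-level module name to True if it can be found, without importing it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


modules = ["onnx", "onnxruntime", "onnxsim", "onnx_tf", "onnx_graphsurgeon",
           "tensorflow", "tensorflow_probability"]
for name, found in check_modules(modules).items():
    print(f"{name:24} {'ok' if found else 'missing'}")
```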

Install the required TensorFlow libraries:

pip install tensorflow tensorflow_probability

Convert the ONNX model to TensorFlow:

mkdir tfmodel
onnx-tf convert -i yolov7.onnx -o tfmodel/

The converted model is saved as a TensorFlow SavedModel in the newly created tfmodel folder.
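A SavedModel folder normally contains a saved_model.pb file and a variables/ subfolder, so a quick structural check can catch a failed conversion before moving on. This only inspects the folder layout, not whether the graph loads; `looks_like_saved_model` is a hypothetical helper.

```python
from pathlib import Path


def looks_like_saved_model(folder) -> bool:
    """Rough check that `folder` has the usual SavedModel layout."""
    root = Path(folder)
    return (root / "saved_model.pb").is_file() and (root / "variables").is_dir()


print("tfmodel looks like a SavedModel:", looks_like_saved_model("tfmodel"))
```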

Convert the model from TensorFlow to TensorFlow Lite

Create a small Python script named tf_model_to_tf_lite.py with the following contents:

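import tensorflow as tf

# Load the SavedModel produced by onnx-tf and convert it to TensorFlow Lite.
converter = tf.lite.TFLiteConverter.from_saved_model('tfmodel/')
tflite_model = converter.convert()

# Write the serialized model next to the SavedModel.
with open('tfmodel/yolov7_model.tflite', 'wb') as f:
    f.write(tflite_model)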

Run the script:

python3 tf_model_to_tf_lite.py

The script writes the TensorFlow Lite model to tfmodel/yolov7_model.tflite.
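TensorFlow Lite models are FlatBuffers whose schema declares the file identifier TFL3, stored at byte offset 4 of the file. Checking those bytes is a cheap way to confirm that the script wrote a real .tflite file rather than an empty or truncated one; it does not validate the model itself. `has_tflite_identifier` is a hypothetical helper.

```python
from pathlib import Path


def has_tflite_identifier(path) -> bool:
    """True if the file carries the TFL3 FlatBuffer identifier at bytes 4-8."""
    with open(path, "rb") as f:
        header = f.read(8)
    return len(header) == 8 and header[4:8] == b"TFL3"


model = Path("tfmodel/yolov7_model.tflite")
if model.exists():
    size_mb = model.stat().st_size / 1e6
    print(f"{model}: {size_mb:.1f} MB, TFL3 identifier: {has_tflite_identifier(model)}")
else:
    print(f"{model} not found; run tf_model_to_tf_lite.py first")
```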

Run inference

import cv2
import random
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="tfmodel/yolov7_model.tflite")

# Name of the classes according to class indices.
names = [
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "dining table",
    "toilet",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
]

# Creating random colors for bounding box visualization.
colors = {name: [random.randint(0, 255) for _ in range(3)] for name in names}

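def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
    """Resize keeping the aspect ratio, then pad to new_shape (height, width).

    Modelled on letterbox() in the YOLOv7 repository; returns the padded image,
    the scale ratio and the (width, height) padding per side.
    """
    shape = im.shape[:2]  # current [height, width]
    # Scale ratio (new / old), identical for both axes
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))  # (width, height)
    # Padding, split evenly between the two sides
    dw = (new_shape[1] - new_unpad[0]) / 2
    dh = (new_shape[0] - new_unpad[1]) / 2
    if shape[::-1] != new_unpad:
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return im, r, (dw, dh)


# Keep the original BGR image for drawing; the model gets an RGB copy.
img = cv2.imread("data/image2.jpg")
if img is None:
    raise FileNotFoundError("could not read data/image2.jpg")
image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Letterbox to the 640x640 export size; ratio and dwdh map the boxes back later.
image, ratio, dwdh = letterbox(image, new_shape=(640, 640))
image = image.transpose((2, 0, 1))  # HWC -> CHW, matching the ONNX (NCHW) input layout
image = np.expand_dims(image, 0)  # add the batch dimension -> (1, 3, 640, 640)
image = np.ascontiguousarray(image)
im = image.astype(np.float32)
im /= 255  # scale pixel values to [0, 1]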

# Allocate tensors.
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

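# Feed the preprocessed image; its shape must match the model input.
input_shape = input_details[0]["shape"]
assert tuple(input_shape.tolist()) == im.shape, f"model expects {input_shape}, got {im.shape}"
interpreter.set_tensor(input_details[0]["index"], im)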

interpreter.invoke()

# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]["index"])

ori_images = [img.copy()]

for i, (batch_id, x0, y0, x1, y1, cls_id, score) in enumerate(output_data):
    image = ori_images[int(batch_id)]
    box = np.array([x0, y0, x1, y1])
    box -= np.array(dwdh * 2)
    box /= ratio
    box = box.round().astype(np.int32).tolist()
    cls_id = int(cls_id)
    score = round(float(score), 3)
    name = names[cls_id]
    color = colors[name]
    name += " " + str(score)
    cv2.rectangle(image, box[:2], box[2:], color, 2)
    cv2.putText(
        image,
        name,
        (box[0], box[1] + 20),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.75,
        [225, 255, 255],
        thickness=2,
    )
# plt.imshow(ori_images[0])
cv2.imshow("test", ori_images[0])
cv2.waitKey(0)

# Destroys all the windows created
cv2.destroyAllWindows()
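Each row of output_data unpacks as (batch_id, x0, y0, x1, y1, cls_id, score) in the coordinates of the 640x640 letterboxed input, which is why the loop above subtracts the padding dwdh and divides by ratio. The plain-Python sketch below isolates that inverse mapping and additionally clips boxes to the original image size, a step the loop above does not perform. `decode_row` is a hypothetical helper, and the example values assume a 1280x720 frame.

```python
def decode_row(row, ratio, dwdh, img_w, img_h, names):
    """Map one (batch_id, x0, y0, x1, y1, cls_id, score) row back to the original image.

    ratio and dwdh are the scale factor and (width, height) padding applied by
    the letterbox step; boxes are clipped to the original image bounds.
    """
    batch_id, x0, y0, x1, y1, cls_id, score = row
    dw, dh = dwdh
    # Undo the padding, then undo the resize.
    box = [(x0 - dw) / ratio, (y0 - dh) / ratio, (x1 - dw) / ratio, (y1 - dh) / ratio]
    # Clip x coordinates to [0, img_w] and y coordinates to [0, img_h].
    box[0::2] = [min(max(v, 0.0), img_w) for v in box[0::2]]
    box[1::2] = [min(max(v, 0.0), img_h) for v in box[1::2]]
    return {
        "batch": int(batch_id),
        "box": [int(round(v)) for v in box],
        "label": names[int(cls_id)],
        "score": round(float(score), 3),
    }


# A 1280x720 frame letterboxed to 640x640: ratio 0.5, 140 px of padding top and bottom.
row = (0, 100.0, 240.0, 300.0, 440.0, 0, 0.9123)
print(decode_row(row, ratio=0.5, dwdh=(0.0, 140.0), img_w=1280, img_h=720, names=["person"]))
# {'batch': 0, 'box': [200, 200, 600, 600], 'label': 'person', 'score': 0.912}
```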


Igor Rendulic

  • Salt Lake City
Explorer, developer, ... Using this website as the bookmarking service for the things that might become useful in the future.