# SAMPLE USAGE:
# python3 vec_generator_easyocr.py images_dir saving_chars.txt vec.normalized

import os
import sys
import torch
import numpy as np
from PIL import Image
import easyocr

# Load the English EasyOCR reader. Only the recognizer's convolutional
# feature extractor is kept; it is reused as a stand-alone image encoder.
reader = easyocr.Reader(['en'], gpu=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# EasyOCR wraps the recognizer in torch.nn.DataParallel only when it runs on an
# accelerator; on CPU it returns the bare model, which has no `.module`
# attribute. Unwrap when the wrapper is present so both setups work.
_recognizer = getattr(reader.recognizer, "module", reader.recognizer)
encoder = _recognizer.FeatureExtraction.to(device)
encoder.eval()  # inference mode: freezes batch-norm statistics / dropout

def preprocess_image(image_path):
    """Load an image file and turn it into a single-item encoder input batch.

    The image is converted to 8-bit grayscale, resized to 32x32 pixels and
    scaled to the [0, 1] range, then given leading batch and channel axes so
    the result has shape (1, 1, 32, 32) and lives on the module-level
    ``device``.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A float32 ``torch.Tensor`` of shape (1, 1, 32, 32) on ``device``.
    """
    with Image.open(image_path) as source:
        grayscale = source.convert("L").resize((32, 32))
        # uint8 pixels -> floats in [0, 1]
        pixels = np.array(grayscale) / 255.0
        batch = torch.FloatTensor(pixels).unsqueeze(0).unsqueeze(0)
        img = batch.to(device)
    return img

def generate_vectors(image_dir, saving_chars_file, output_file):
    """Encode character images and save the vectors in Word2Vec text format.

    Args:
        image_dir: Directory scanned (non-recursively) for ``.png`` files.
            Files are processed in sorted name order.
        saving_chars_file: UTF-8 text file with one
            ``<char> <ignored> <image_name>`` entry per line, fields separated
            by single spaces. Lines that do not split into exactly three
            fields are ignored.
        output_file: Destination path; overwritten if it already exists.

    The output starts with a ``<count> <dimension>`` header followed by one
    ``<char> <v1> <v2> ...`` line per image that was encoded successfully.
    Images without a mapping entry, and images that fail to load or encode,
    are reported on stdout and left out. The header count reflects only the
    rows actually written, so the file stays loadable by strict Word2Vec
    readers.
    """
    char_mapping = {}
    with open(saving_chars_file, "r", encoding="utf-8") as file:
        for line in file:
            parts = line.strip().split(" ")
            if len(parts) == 3:
                unicode_text, _, image_name = parts
                char_mapping[image_name] = unicode_text

    image_files = [f for f in sorted(os.listdir(image_dir)) if f.endswith(".png")]

    # Dimension reported when no image could be encoded at all.
    default_dimension = 256

    # Rows are buffered so the header can carry the real row count; each row
    # is a short text line, so memory use is proportional to the output size.
    vector_lines = []
    vector_dimension = None

    for image_name in image_files:
        char = char_mapping.get(image_name)
        if not char:
            print(f"Warning: Character not found for image {image_name}. Skipping.")
            continue

        # Best-effort per image: one unreadable file must not abort the run.
        try:
            img_tensor = preprocess_image(os.path.join(image_dir, image_name))
            with torch.no_grad():
                # Average the feature map over its spatial axes to get one
                # value per channel.
                feature_vector = encoder(img_tensor).mean(dim=(2, 3)).squeeze().cpu().numpy()
            vector_str = " ".join(map(str, feature_vector))
            dimension = len(feature_vector)
        except Exception as e:
            print(f"Error processing image {image_name}: {e}")
            continue

        if vector_dimension is None:
            vector_dimension = dimension
        elif dimension != vector_dimension:
            # A row of a different width would corrupt the file for readers.
            print(
                f"Warning: Vector for image {image_name} has dimension "
                f"{dimension}, expected {vector_dimension}. Skipping."
            )
            continue

        vector_lines.append(f"{char} {vector_str}\n")

    if vector_dimension is None:
        vector_dimension = default_dimension

    # Single open: header first, then every buffered row.
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(f"{len(vector_lines)} {vector_dimension}\n")
        f.writelines(vector_lines)

    print(f"Vector file '{output_file}' has been created in Word2Vec format.")

if __name__ == "__main__":
    # Exactly three positional arguments are expected after the script name.
    if len(sys.argv) != 4:
        # Report the name the script was actually invoked with, so the usage
        # line stays correct if the file is renamed.
        script_name = os.path.basename(sys.argv[0])
        print(f"Usage: python3 {script_name} [images_dir] [saving_chars.txt] [output_file]")
        sys.exit(1)

    image_dir = sys.argv[1]
    saving_chars_file = sys.argv[2]
    output_file = sys.argv[3]

    # isdir rather than exists: a regular file passed here would otherwise
    # pass the check and then fail inside os.listdir with a traceback.
    if not os.path.isdir(image_dir):
        print(f"Error: Image directory '{image_dir}' does not exist.")
        sys.exit(1)

    if not os.path.exists(saving_chars_file):
        print(f"Error: Saving chars file '{saving_chars_file}' does not exist.")
        sys.exit(1)

    generate_vectors(image_dir, saving_chars_file, output_file)