import argparse
import json
import logging
import os

import numpy as np
import pandas as pd
import torch
from PIL import Image

# from fuzzywuzzy import fuzz
from rapidfuzz import fuzz
from sentence_transformers import SentenceTransformer, util
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

# Configure the root logger once at import time: INFO level and above,
# each record prefixed with its timestamp and level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)


def load_model(model_name="sentence-transformers/distiluse-base-multilingual-cased-v2"):
    """Load a SentenceTransformer and move it onto the best available device.

    Args:
        model_name (str): Model identifier (or path) handed to
            ``SentenceTransformer``.

    Returns:
        tuple: ``(model, device)`` where ``device`` is ``"cuda"`` if
        ``torch.cuda.is_available()`` reports true, otherwise ``"cpu"``.
    """
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"
    encoder = SentenceTransformer(model_name)
    encoder.to(target_device)
    logging.info(f"Model loaded to {target_device}")
    return encoder, target_device


def batch_encode_texts(model, texts, device, batch_size=32):
    """
    Encode texts into embeddings, ``batch_size`` texts per ``model.encode`` call.

    Args:
        model (SentenceTransformer): Model used to produce the embeddings.
        texts (list): List of text descriptions or queries.
        device (str | torch.device): Device each batch's embeddings are moved to.
        batch_size (int): Number of texts to process per batch.

    Returns:
        torch.Tensor: Embeddings concatenated along dim 0, one row per text.
        An empty ``texts`` yields an empty ``(0, dim)`` tensor instead of raising.
    """
    if len(texts) == 0:
        # torch.cat([]) raises RuntimeError, so build the empty result explicitly.
        # get_sentence_embedding_dimension() may return None for some models.
        dim = model.get_sentence_embedding_dimension() or 0
        return torch.empty((0, dim), device=device)

    embeddings = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Encoding batches"):
        batch = texts[i : i + batch_size]
        batch_embeddings = model.encode(batch, convert_to_tensor=True).to(device)
        embeddings.append(batch_embeddings)
    return torch.cat(embeddings, dim=0)


def compute_similarity(desc_embeddings, query_embeddings):
    """
    Compute the cosine similarity of each description with its paired query.

    Row ``i`` of ``desc_embeddings`` is compared only with row ``i`` of
    ``query_embeddings``. Only these N pairs are computed; the full N x N
    similarity matrix is never built, so memory stays O(N) rather than O(N^2).

    Args:
        desc_embeddings (torch.Tensor): Embeddings for image descriptions,
            one row per item (a single 1-D vector is treated as one row).
        query_embeddings (torch.Tensor): Embeddings for queries, same shape.

    Returns:
        torch.Tensor: 1-D tensor of cosine similarity scores (1-to-1 match).

    Raises:
        ValueError: if the two inputs do not have the same shape.
    """
    desc = torch.as_tensor(desc_embeddings)
    query = torch.as_tensor(query_embeddings)
    # Mirror util.pytorch_cos_sim, which promotes a 1-D vector to a single row.
    if desc.dim() == 1:
        desc = desc.unsqueeze(0)
    if query.dim() == 1:
        query = query.unsqueeze(0)
    # Without this check, mismatched inputs would either be silently truncated
    # or broadcast (e.g. (1, d) against (n, d)).
    if desc.shape != query.shape:
        raise ValueError(
            f"Embedding shapes differ: {tuple(desc.shape)} vs {tuple(query.shape)}"
        )
    # L2-normalize like util.pytorch_cos_sim (zero vectors score 0), then take
    # the row-wise dot product instead of the diagonal of a full matmul.
    desc = torch.nn.functional.normalize(desc, p=2, dim=1)
    query = torch.nn.functional.normalize(query, p=2, dim=1)
    return (desc * query).sum(dim=1)


def read_image_objects(input_file):
    """Load every JSON value from a JSON Lines file.

    Args:
        input_file (str | os.PathLike): Path to a UTF-8 encoded JSONL file.

    Returns:
        list: Parsed JSON values in file order. Blank or whitespace-only lines
        (e.g. a trailing empty line) are skipped instead of raising.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(input_file, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _build_texts(image_objects):
    """Split records into parallel ``(queries, descriptions)`` text lists.

    Each query is ``category``, ``subcategory``, ``topic`` and ``query``
    joined by a space-plus-newline separator. Each description is
    ``record["response_img_desc"]["description"]``.
    """
    queries = []
    descriptions = []
    for img_object in image_objects:
        queries.append(
            " \n".join(
                (
                    img_object["category"],
                    img_object["subcategory"],
                    img_object["topic"],
                    img_object["query"],
                )
            )
        )
        descriptions.append(img_object["response_img_desc"]["description"])
    return queries, descriptions


def process_images_in_batches(
    image_objects,
    model,
    device,
    batch_size=32,
):
    """Annotate each record in place with its description/query similarity.

    Sets ``record["image_desc_query_sim"]`` to the cosine similarity between
    the record's image description and its metadata query, rounded to 4
    decimal places as a Python ``float``.

    Args:
        image_objects (list[dict]): Records with ``category``, ``subcategory``,
            ``topic``, ``query`` and ``response_img_desc.description`` keys.
            They are mutated in place.
        model (SentenceTransformer): Model used to embed both texts.
        device (str | torch.device): Device embeddings are moved to.
        batch_size (int): Number of texts per encoding batch.

    Raises:
        KeyError: if a record lacks one of the keys listed above.
    """
    logging.info(f"Processing {len(image_objects)} images in batches of {batch_size}")

    queries, descriptions = _build_texts(image_objects)

    logging.info("Encoding image descriptions...")
    desc_embeddings = batch_encode_texts(
        model, descriptions, device, batch_size=batch_size
    )
    logging.info("Encoding queries...")
    query_embeddings = batch_encode_texts(model, queries, device, batch_size=batch_size)
    logging.info("Computing similarity scores...")
    similarity_scores = (
        compute_similarity(desc_embeddings, query_embeddings).cpu().numpy()
    )

    for img_object, score in zip(image_objects, similarity_scores):
        # Convert to a Python float *before* rounding: round() on np.float32
        # stays float32, which then serializes as e.g. 0.12349999696016312.
        img_object["image_desc_query_sim"] = round(float(score), 4)


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to native Python types."""

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        Raises:
            TypeError: (from the base class) if ``obj`` is not a supported type.
        """
        # np.bool_ is neither np.integer nor np.floating, so it needs its own branch.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


def save_to_jsonl(jsonl_save_path, image_objects):
    """Write ``image_objects`` to ``jsonl_save_path``, one JSON object per line.

    NumPy values are converted through ``NumpyEncoder``. Non-ASCII characters
    are written verbatim (``ensure_ascii=False``), so the file is opened as
    UTF-8 explicitly instead of relying on the platform default encoding,
    which could fail on non-UTF-8 locales.
    """
    with open(jsonl_save_path, "w", encoding="utf-8") as jsonl_file:
        for item in image_objects:
            jsonl_file.write(
                json.dumps(item, cls=NumpyEncoder, ensure_ascii=False) + "\n"
            )
    logging.info(f"Saved results to {jsonl_save_path}")


def main():
    """Command-line entry point: read records, score them, write the results."""
    parser = argparse.ArgumentParser(
        description="Score how similar each image description is to its query"
    )
    parser.add_argument(
        "--input_file",
        type=str,
        required=True,
        help="Path to the input JSONL file of image records",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        required=True,
        help="Path to the output JSONL file to save results",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="sentence-transformers/distiluse-base-multilingual-cased-v2",
        help="SentenceTransformer model used to embed descriptions and queries",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size",
    )

    args = parser.parse_args()
    # A zero step makes range() raise, and a negative one yields no batches.
    # Both would fail deep inside encoding, so reject them up front.
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")

    image_objects = read_image_objects(args.input_file)
    model, device = load_model(model_name=args.model_name)

    process_images_in_batches(image_objects, model, device, args.batch_size)

    save_to_jsonl(args.output_file, image_objects)
    logging.info("Location matching finished")


if __name__ == "__main__":
    main()
