import base64
import os
import random
from pydoc import text
from networkx import capacity_scaling
import requests
import torch
from PIL import Image
from io import BytesIO
import pyarrow as pa
import faiss
import open_clip
import pandas as pd
from open_clip import create_model_from_pretrained, get_tokenizer
from sentence_transformers import SentenceTransformer
from search.search_api import SearchAPI


class ArrowMetadataProvider:
    """Serve row metadata from an Arrow IPC file, addressed by row position.

    The file is memory-mapped and read into a single ``pyarrow`` table when the
    provider is constructed; ``get`` then slices individual rows out of it.
    """

    def __init__(self, arrow_file):
        """Memory-map ``arrow_file`` and load all of its record batches.

        Args:
            arrow_file: Path to the Arrow file to read.
        """
        self.table = pa.ipc.RecordBatchFileReader(
            pa.memory_map(arrow_file, "r")
        ).read_all()

    def get(self, ids, cols=None):
        """Return metadata records for the given row positions.

        Args:
            ids: Iterable of integer row positions (e.g. one row of FAISS
                search results).
            cols: Column names to include. Names that do not exist in the
                table are ignored and duplicates are collapsed; the requested
                order is preserved. ``None`` selects every column.

        Returns:
            list[dict]: Records in the order of ``ids``; an empty list when
            ``ids`` is empty.
        """
        available = self.table.schema.names
        if cols is None:
            cols = available
        else:
            known = set(available)
            # dict.fromkeys de-duplicates while keeping the caller's order, so
            # the key order of the returned records is deterministic (a plain
            # set intersection is not).
            cols = [name for name in dict.fromkeys(cols) if name in known]

        rows = [self.table[i : i + 1] for i in ids]
        if not rows:
            # concat_tables rejects an empty list; nothing requested means
            # nothing to return.
            return []
        return pa.concat_tables(rows).select(cols).to_pandas().to_dict("records")


class CsvMetadataProvider:
    """Serve row metadata from a CSV file held in a pandas DataFrame."""

    def __init__(self, csv_file):
        """Read ``csv_file`` into memory with ``pandas.read_csv``."""
        self.metadata_df = pd.read_csv(csv_file)

    def get(self, ids, cols=None):
        """Look up rows by index label and return them as records.

        Args:
            ids: Index labels of the rows to fetch (handed to ``DataFrame.loc``).
            cols: Columns to include; defaults to ``["image_path", "caption"]``.

        Returns:
            list[dict]: One ``{column: value}`` mapping per requested row.
        """
        wanted = ["image_path", "caption"] if cols is None else cols
        selected = self.metadata_df.loc[ids, wanted]
        return selected.to_dict(orient="records")


class FaissSearch(SearchAPI):
    """Image (optionally image + caption) retrieval over a FAISS index.

    Queries are embedded with an open_clip image encoder. When
    ``search_caption`` is enabled, a SentenceTransformer embedding of the
    query caption is concatenated to the image embedding before the lookup.
    """

    def __init__(
        self,
        index_path="/drl_nas2/ckddls1321/data/laion5b-index/populated.index",
        metadata_path="/drl_nas2/ckddls1321/data/laion5b-index/metadata.arrow",
        model_name="hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224",
        text_model="Alibaba-NLP/gte-modernbert-base",
        search_caption=False,
        top_k=5,
    ):
        """
        Load the FAISS index, the metadata provider and the encoder model(s).

        Args:
            index_path (str): Path to the FAISS index file.
            metadata_path (str): Path to the metadata file; ``.arrow`` and
                ``.csv`` files are supported.
            model_name (str): open_clip model identifier, passed to
                ``create_model_from_pretrained`` and ``get_tokenizer``.
            text_model (str): SentenceTransformer model name; only loaded when
                ``search_caption`` is true.
            search_caption (bool): Append a caption embedding to the image
                embedding in ``compute_query``.
            top_k (int): Minimum number of neighbours ``retrieve_query`` asks
                the index for.

        Raises:
            FileNotFoundError: If the index or metadata file does not exist.
            ValueError: If the metadata file is neither ``.arrow`` nor ``.csv``.
        """
        super().__init__()  # Initialize the parent class
        self.top_k = top_k
        self.search_caption = search_caption

        if not os.path.exists(index_path):
            raise FileNotFoundError(f"FAISS index file not found at {index_path}")
        self.index = faiss.read_index(index_path)
        print(f"Loaded FAISS index from: {index_path}")

        # Load metadata
        if not os.path.exists(metadata_path):
            raise FileNotFoundError(f"Metadata file not found at {metadata_path}")
        metadata_format = self._metadata_format(metadata_path)
        if metadata_format == "arrow":
            self.metadata_provider = ArrowMetadataProvider(metadata_path)
            self.cols = ["url", "caption"]
        elif metadata_format == "csv":
            self.metadata_provider = CsvMetadataProvider(metadata_path)
            self.cols = ["image_path", "caption"]
        else:
            # Fail here rather than with an AttributeError on the first search.
            raise ValueError(
                f"Unsupported metadata file type (expected .arrow or .csv): {metadata_path}"
            )
        print(f"Loaded metadata from: {metadata_path}")

        # Load model
        self.model, self.preprocess = create_model_from_pretrained(model_name)
        self.tokenizer = get_tokenizer(model_name)
        self.model.eval()  # Set to evaluation mode
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

        # Always define these so the attributes exist even without captions.
        self.prompt_name = None
        self.text_model = None
        if self.search_caption:
            if "Qwen3" in text_model:
                # Qwen3 text models are encoded with the "query" prompt.
                self.prompt_name = "query"
            # SentenceTransformer documents ``device`` as a string.
            self.text_model = SentenceTransformer(text_model, device=str(self.device))

    @staticmethod
    def _metadata_format(metadata_path):
        """Return ``"arrow"`` or ``"csv"`` for the path, or ``None`` if unknown."""
        lowered = metadata_path.lower()
        arrow_at = lowered.rfind(".arrow")
        csv_at = lowered.rfind(".csv")
        if arrow_at < 0 and csv_at < 0:
            return None
        # When both markers occur, the rightmost one is closest to the actual
        # file extension, so it decides.
        return "arrow" if arrow_at > csv_at else "csv"

    @staticmethod
    def _open_rgb(source):
        """Open a path or file-like object and return it as an RGB image."""
        with Image.open(source) as opened:
            return opened.convert("RGB")

    def _load_image(self, query):
        """Turn any supported query form into an RGB ``PIL.Image``."""
        if isinstance(query, dict):
            if "image_path" not in query:
                raise ValueError("Query dict must contain an 'image_path' entry")
            source = query["image_path"]
            if isinstance(source, list):
                # Several candidate images: sample one of them.
                source = random.choice(source)
            return self._open_rgb(source)

        if isinstance(query, Image.Image):
            # Same colour mode as the images loaded from disk or base64.
            return query.convert("RGB")

        if isinstance(query, str):
            # "<prefix>,<base64>" strings are decoded inline. An existing file
            # whose name happens to contain a comma is still read from disk.
            is_inline = "," in query and (
                query.startswith("data:") or not os.path.isfile(query)
            )
            if is_inline:
                payload = query.split(",", 1)[1]
                return self._open_rgb(BytesIO(base64.b64decode(payload)))
            return self._open_rgb(query)  # file path

        raise TypeError(f"Unsupported query type: {type(query).__name__}")

    def compute_query(self, query):
        """
        Compute the query embedding for an image, optionally with its caption.

        Args:
            query (dict or str or PIL.Image.Image): One of
                - a dict with ``"image_path"`` (a path or a list of paths, one
                  of which is picked at random) and optionally ``"caption"``;
                - a ``PIL.Image.Image``;
                - a string: either ``"<prefix>,<base64 data>"`` or a file path.

        Returns:
            torch.Tensor: The L2-normalized image embedding with a leading
            batch dimension of 1. When ``search_caption`` is enabled, the
            normalized caption embedding is concatenated along the last
            dimension (an empty caption is used for non-dict queries).

        Raises:
            TypeError: If ``query`` is not one of the supported types.
            ValueError: If ``query`` is a dict without ``"image_path"``.
        """
        image = self._load_image(query)
        pixels = self.preprocess(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            image_features = self.model.encode_image(pixels)
            image_features /= image_features.norm(dim=-1, keepdim=True)  # Normalize

        if self.search_caption:
            # Only dict queries can carry a caption.
            caption = query.get("caption", "") if isinstance(query, dict) else ""
            text_features = self.text_model.encode(
                caption,
                normalize_embeddings=True,
                convert_to_tensor=True,
                convert_to_numpy=False,
                prompt_name=self.prompt_name,
            )
            if text_features.ndim == 1:
                text_features = text_features.unsqueeze(0)
            image_features = torch.concat([image_features, text_features], dim=-1)

        return image_features

    def retrieve_query(self, query, top_k=5):
        """
        Retrieve the nearest neighbours of a query embedding from the index.

        Args:
            query (torch.Tensor): Query embedding as produced by
                ``compute_query``; a 1-D vector is treated as a single query.
            top_k (int): Number of neighbours requested. The larger of this
                value and the instance-level ``top_k`` is used.

        Returns:
            tuple: ``(metadata, distances)`` where ``metadata`` is a list of
            metadata records and ``distances`` the matching index distances,
            both restricted to the neighbours the index actually returned.
        """
        top_k = max(top_k, self.top_k)
        # FAISS expects a contiguous float32 matrix of shape (n_queries, dim).
        query_features = query.detach().cpu().float().contiguous().numpy()
        if query_features.ndim == 1:
            query_features = query_features.reshape(1, -1)

        distances, indices = self.index.search(query_features, top_k)

        # FAISS fills missing neighbours with id -1; drop them from ids and
        # distances together so the two results stay aligned.
        found = indices[0] >= 0
        if not found.any():
            return [], distances[0][found]
        metadata = self.metadata_provider.get(indices[0][found], cols=self.cols)
        return metadata, distances[0][found]


def main():
    """Run one image-only search against the InfoSeek mixed index and print the hits."""
    # Alternative configurations kept for reference:
    #   MedTrinity demo
    #     index:    /drl_nas2/ckddls1321/data/MedTrinity-demo/faiss.index
    #     metadata: /drl_nas2/ckddls1321/data/MedTrinity-demo/metadata.csv
    #     model:    hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224
    #     image:    https://radiologykey.com/wp-content/uploads/2016/07/B9780750675376500068_gr3.jpg
    #   WikiWeb2M
    #     index:    /home/ckddls1321/.cache/indexes/WikiWeb2M/faiss.index
    #     metadata: /home/ckddls1321/.cache/indexes/WikiWeb2M/merged_metadata.csv
    #     model:    hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K
    #     image:    https://open-vision-language.github.io/oven/assets/oven-images/material1.png
    #   InfoSeek, image-only index
    #     index:    /drl_nas2/ckddls1321/data/InfoSeek/image_faiss.index
    #     metadata: /drl_nas2/ckddls1321/data/InfoSeek/image_metadata.csv
    #     model:    hf-hub:timm/ViT-SO400M-16-SigLIP2-384
    #     image:    ./bird1.png
    settings = {
        "index_path": "/drl_nas2/ckddls1321/data/InfoSeek/mixed_faiss.index",
        "metadata_path": "/drl_nas2/ckddls1321/data/InfoSeek/mixed_metadata.csv",
        "model_name": "hf-hub:timm/ViT-SO400M-16-SigLIP2-384",
        "search_caption": False,
    }
    query_image = "./bird1.png"

    retriever = FaissSearch(**settings)

    # To query with a remote image instead, download it with requests.get(url)
    # and pass Image.open(BytesIO(response.content)).convert("RGB") to search().
    #
    # To search with a caption as well, pass a dict such as
    #   {"image_path": query_image, "caption": "This bird is Brown-crested_Flycatcher, ..."}
    # and build the retriever with search_caption=True.
    results, _distances = retriever.search(query_image)  # image-only search
    print(results)

    # One line per retrieved image/caption pair.
    print("Search results:")
    for hit in results:
        print(f"Image: {hit['image_path']}, Text: {hit['caption']}")


# Run the demo search only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
