import base64
import os
import random
from io import BytesIO

import faiss
import open_clip
import pandas as pd
import pyarrow as pa
import requests
import torch
from open_clip import create_model_from_pretrained, get_tokenizer
from PIL import Image
#from sentence_transformers import SentenceTransformer
from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoProcessor, AutoModel

class SearchAPI:
    """Base class for search back-ends.

    A concrete back-end overrides ``compute_query`` (turn the raw query into
    whatever the retriever consumes) and ``retrieve_query`` (look results up
    for that processed query). ``search`` simply chains the two.
    """

    def __init__(self):
        """Create the base search API; there is no state to set up here."""

    def compute_query(self, query):
        """Turn a raw query into its processed form.

        Must be overridden by subclasses.

        Args:
            query (str or base64): Text query or base64-encoded image.

        Returns:
            Processed query (str or torch.Tensor).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement the compute_query method.")

    def retrieve_query(self, query):
        """Look up results for an already-processed query.

        Must be overridden by subclasses.

        Args:
            query (str or torch.Tensor): Processed query (text or embedding).

        Returns:
            List of search results.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement the retrieve_query method.")

    def search(self, query):
        """Run a full search: process the query, then retrieve its results.

        Args:
            query (str or base64): Text query or base64-encoded image.

        Returns:
            Whatever ``retrieve_query`` returns for the processed query.
        """
        # The processed query is fed straight into retrieval.
        return self.retrieve_query(self.compute_query(query))


class ArrowMetadataProvider:
    """Provide metadata rows for contiguous ids from an Arrow IPC file."""

    def __init__(self, arrow_file):
        """Memory-map ``arrow_file`` and read all of its record batches into one table.

        Args:
            arrow_file (str): Path to an Arrow IPC file.
        """
        self.table = pa.ipc.RecordBatchFileReader(
            pa.memory_map(arrow_file, "r")
        ).read_all()

    def get(self, ids, cols=None):
        """Get metadata for the specified ids.

        Args:
            ids (iterable of int): Row positions in the table, in the order
                the records should be returned.
            cols (iterable of str, optional): Columns to include. Names that
                are not in the table are ignored. Defaults to every column.

        Returns:
            list[dict]: One record per requested row; ``[]`` when ``ids`` is empty.
        """
        ids = list(ids)
        if not ids:
            # pa.concat_tables rejects an empty list, so answer directly.
            return []

        available = self.table.schema.names
        if cols is None:
            cols = available
        else:
            # Keep the caller's column order (a set intersection would make
            # it arbitrary) while dropping unknown names and duplicates.
            known = set(available)
            cols = [name for name in dict.fromkeys(cols) if name in known]

        # Stitch the one-row slices together in the requested order.
        t = pa.concat_tables([self.table[i : i + 1] for i in ids])
        return t.select(cols).to_pandas().to_dict("records")


class CsvMetadataProvider:
    """Provide metadata rows for ids from a CSV file held in a pandas DataFrame."""

    def __init__(self, csv_file):
        # The whole CSV is read once; lookups are served from the DataFrame.
        self.metadata_df = pd.read_csv(csv_file)

    def get(self, ids, cols=None):
        """Get metadata for the specified ids.

        Args:
            ids: Index labels of the rows to fetch (passed to ``DataFrame.loc``).
            cols (list of str, optional): Columns to include. Defaults to
                ``["image_path", "caption"]``.

        Returns:
            list[dict]: One record per requested row.
        """
        wanted = ["image_path", "caption"] if cols is None else cols
        rows = self.metadata_df.loc[ids, wanted]
        return rows.to_dict(orient="records")


class HFFaissSearch(SearchAPI):
    """FAISS retrieval driven by a Hugging Face vision-language model.

    A query image (optionally with a caption/question) is embedded with the
    model, the embedding is searched in a FAISS index, and the hit ids are
    resolved to metadata records.
    """

    def __init__(
        self,
        index_path="/drl_nas2/ckddls1321/data/InfoSeek/unime_text_faiss_index.index",
        metadata_path="/drl_nas2/ckddls1321/data/InfoSeek/unime_text_metadata.csv",
        model_name="DeepGlint-AI/UniME-Phi3.5-V-4.2B",
        search_caption=False,
        top_k=5,
    ):
        """
        Initializes the HFFaissSearch with a FAISS index, metadata, and a Hugging Face model.

        Args:
            index_path (str): Path to the FAISS index file.
            metadata_path (str): Path to the Arrow or CSV metadata file.
            model_name (str): Path or name of the Hugging Face model.
            search_caption (bool): Stored on the instance; not read elsewhere in this class.
            top_k (int): The number of top results to retrieve.

        Raises:
            FileNotFoundError: If the index or metadata file does not exist.
            ValueError: If the metadata path is neither an Arrow nor a CSV file.
        """
        super().__init__()
        self.top_k = top_k
        self.search_caption = search_caption
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        # Define prompt templates
        self.img_caption_prompt = "<|user|>\n<|image_1|> Represent the given image with the following question: <sent><|end|>\n<|assistant|>\n"
        self.img_only_prompt = "<|user|>\n<|image_1|>\nSummary above image in one word: <|end|>\n<|assistant|>\n"  # Image-to-text

        # Load FAISS index
        if not os.path.exists(index_path):
            raise FileNotFoundError(f"FAISS index file not found at {index_path}")
        self.index = faiss.read_index(index_path)
        print(f"Loaded FAISS index from: {index_path}")
        print(
            f"Index dimension: {self.index.d}, Number of vectors: {self.index.ntotal}"
        )

        # Load metadata
        if not os.path.exists(metadata_path):
            raise FileNotFoundError(f"Metadata file not found at {metadata_path}")
        if ".arrow" in metadata_path:
            self.metadata_provider = ArrowMetadataProvider(metadata_path)
            self.cols = ["url", "caption"]
        elif ".csv" in metadata_path:
            self.metadata_provider = CsvMetadataProvider(metadata_path)
            self.cols = ["image_path", "caption"]
        else:
            # Fail here rather than with an AttributeError on the first search.
            raise ValueError(
                f"Unsupported metadata file type (expected .arrow or .csv): {metadata_path}"
            )
        print(f"Loaded metadata from: {metadata_path}")

        # Load Hugging Face model and processor
        self.transform = AutoProcessor.from_pretrained(
            model_name, trust_remote_code=True
        )
        # NOTE(review): the model is pinned to CUDA with flash-attention here,
        # so the CPU fallback in self.device looks unusable in practice — confirm.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="cuda",
            trust_remote_code=True,
            torch_dtype=torch.float16,
            _attn_implementation="flash_attention_2",
        )
        # Left padding keeps the final sequence position on a real token,
        # which compute_query relies on when it reads position -1.
        self.transform.tokenizer.padding_side = "left"
        self.transform.tokenizer.padding = True

    def compute_query(self, query):
        """
        Processes a query to compute an embedding. The query can be for an image only
        or for an image and a corresponding caption.

        Args:
            query (dict): A dictionary containing the query information.
            - For image and question(or caption): {"image_path": str, "caption": str}
            - For image only: {"image_path": str}

        Returns:
            torch.Tensor: The L2-normalized embedding for the query, taken from
            the last hidden layer at the last sequence position.

        Raises:
            ValueError: If ``query`` is not a dict with an ``image_path`` key.
            FileNotFoundError: If the image file does not exist.
        """
        if not isinstance(query, dict) or "image_path" not in query:
            raise ValueError("Query must be a dictionary with an 'image_path' key.")

        image_path = query["image_path"]
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found at: {image_path}")

        # convert() returns an independent image, so the file can be closed.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")

        # A missing or empty caption falls back to the image-only prompt.
        caption = query.get("caption")
        if caption:
            prompt = self.img_caption_prompt.replace("<sent>", caption)
        else:
            prompt = self.img_only_prompt

        inputs = self.transform(
            text=prompt, images=[image], return_tensors="pt", padding=True
        ).to(self.device)

        with torch.no_grad():
            hidden_states = self.model(
                **inputs, output_hidden_states=True, return_dict=True
            ).hidden_states
            embedding = F.normalize(hidden_states[-1][:, -1, :], dim=-1)

        return embedding

    def search(self, query):
        """
        Performs a search for the given query.

        Args:
            query (dict): The query dictionary accepted by ``compute_query``
                ({"image_path": str} with an optional "caption").

        Returns:
            tuple: A tuple containing:
                - list: Metadata records for the hits FAISS actually found.
                - numpy.ndarray: The distances array exactly as returned by FAISS.
        """
        # The model runs in float16 while FAISS works on float32 vectors.
        query_embedding = self.compute_query(query).float().cpu().numpy()
        distances, indices = self.index.search(query_embedding, self.top_k)
        # FAISS labels missing neighbours -1; those have no metadata row.
        hit_ids = [int(i) for i in indices[0] if i >= 0]
        if not hit_ids:
            return [], distances
        results = self.metadata_provider.get(ids=hit_ids, cols=self.cols)
        return results, distances
    

class HFFaissSearch2(SearchAPI):
    """
    A search class that uses the nvidia/MM-Embed model for generating query embeddings
    and retrieves results from a FAISS index.
    """

    def __init__(
        self,
        index_path,
        metadata_path,
        model_name="nvidia/MM-Embed",
        instruction="Retrieve a Wikipedia paragraph that provides an answer to the given query about the image.",
        top_k=5,
        max_length=4096,
    ):
        """
        Initializes the HFFaissSearch2 with a FAISS index, metadata, and the MM-Embed model.

        Args:
            index_path (str): Path to the FAISS index file.
            metadata_path (str): Path to the Arrow or CSV metadata file.
            model_name (str): Name or path of the Hugging Face model (MM-Embed).
            instruction (str): The instruction string for the MM-Embed model's encoder.
            top_k (int): The number of top results to retrieve.
            max_length (int): The maximum sequence length for the model.

        Raises:
            FileNotFoundError: If the index or metadata file does not exist.
            ValueError: If the metadata path is neither an Arrow nor a CSV file.
        """
        super().__init__()
        self.top_k = top_k
        self.instruction = instruction
        self.max_length = max_length
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name

        # Load FAISS index
        if not os.path.exists(index_path):
            raise FileNotFoundError(f"FAISS index file not found at {index_path}")
        self.index = faiss.read_index(index_path)
        print(f"Loaded FAISS index from: {index_path}")
        print(f"Index dimension: {self.index.d}, Number of vectors: {self.index.ntotal}")

        # Load metadata
        if not os.path.exists(metadata_path):
            raise FileNotFoundError(f"Metadata file not found at {metadata_path}")
        if ".arrow" in metadata_path:
            self.metadata_provider = ArrowMetadataProvider(metadata_path)
            self.cols = ["url", "caption"]
        elif ".csv" in metadata_path:
            self.metadata_provider = CsvMetadataProvider(metadata_path)
            self.cols = ["image_path", "caption"]
        else:
            # Fail here rather than with an AttributeError on the first search.
            raise ValueError(
                f"Unsupported metadata file type (expected .arrow or .csv): {metadata_path}"
            )
        print(f"Loaded metadata from: {metadata_path}")

        # Load Hugging Face model
        print(f"Loading model: {self.model_name}")
        self.model = AutoModel.from_pretrained(
            self.model_name, trust_remote_code=True
        ).to(self.device)
        print("Model loaded successfully.")

    def compute_query(self, query):
        """
        Processes a query to compute an embedding using the MM-Embed model.

        Args:
            query (dict): A dictionary containing the query information.
            - For image and question(or caption): {"image_path": str, "caption": str}
            - For image only: {"image_path": str}

        Returns:
            The ``"hidden_states"`` entry of ``model.encode(...)``, returned
            as-is. NOTE(review): no normalization is applied here; whether the
            model normalizes internally should be confirmed against its docs.

        Raises:
            ValueError: If ``query`` is not a dict with an ``image_path`` key.
            FileNotFoundError: If the image file does not exist.
        """
        if not isinstance(query, dict) or "image_path" not in query:
            raise ValueError("Query must be a dictionary with an 'image_path' key.")

        image_path = query["image_path"]
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found at: {image_path}")

        # convert() returns an independent image, so the file can be closed.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")

        # Single-item query list; "txt" is attached only for a non-empty caption.
        model_query = [{"img": image}]
        caption = query.get("caption")
        if caption:
            model_query[0]["txt"] = caption

        # Generate the embedding
        with torch.no_grad():
            embedding = self.model.encode(
                model_query,
                is_query=True,
                instruction=self.instruction,
                max_length=self.max_length,
            )["hidden_states"]

        return embedding

    def search(self, query):
        """
        Performs a search for the given query.

        Args:
            query (dict): The image and text query dictionary.

        Returns:
            tuple: A tuple containing:
                - list: Metadata records for the hits FAISS actually found.
                - numpy.ndarray: The distances array exactly as returned by FAISS.
        """
        # FAISS works on float32 vectors held in CPU memory.
        query_embedding = self.compute_query(query).float().cpu().numpy()
        distances, indices = self.index.search(query_embedding, self.top_k)
        # FAISS labels missing neighbours -1; those have no metadata row.
        hit_ids = [i for i in indices[0].tolist() if i >= 0]
        if not hit_ids:
            return [], distances
        results = self.metadata_provider.get(ids=hit_ids, cols=self.cols)
        return results, distances




if __name__ == "__main__":
    # Manual smoke test for HFFaissSearch2. Point the paths below at a FAISS
    # index and metadata file that were built with the same embedding model.
    #
    # To exercise the UniME-based HFFaissSearch instead, construct it with:
    #     retriever = HFFaissSearch(
    #         index_path="/drl_nas2/ckddls1321/data/InfoSeek/unime_text_faiss_index.index",
    #         metadata_path="/drl_nas2/ckddls1321/data/InfoSeek/unime_text_metadata.csv",
    #         model_name="DeepGlint-AI/UniME-Phi3.5-V-4.2B",
    #     )
    index_path = "/drl_nas2/ckddls1321/data/InfoSeek/mm_embed_faiss.index"
    metadata_path = "/drl_nas2/ckddls1321/data/InfoSeek/mm_embed_metadata.csv"
    model_name = "nvidia/MM-Embed"
    image_path = "./bird1.png"

    retriever = HFFaissSearch2(
        index_path=index_path,
        metadata_path=metadata_path,
        model_name=model_name,
    )

    # Image plus question; omit "caption" to search with the image alone.
    example_query = {
        "image_path": image_path,
        "caption": "What is the name of this bird?",
    }
    image_text_pairs, distances = retriever.search(example_query)
    print(image_text_pairs)

    # Show each hit's image path next to its caption.
    print("Search results:")
    for image_text_pair in image_text_pairs:
        print(f"Image: {image_text_pair['image_path']}, Text: {image_text_pair['caption']}")
