import time
from typing import Any, Dict, List, Optional

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from openai import OpenAI
from sklearn.metrics.pairwise import cosine_similarity

from src.core.registry import embeddings_registry
from src.embeddings.base import EmbeddingsInterface
from src.utils.decorator_utils import with_logger


# Register with embeddings registry, not LLM registry since this is an embedding model
@embeddings_registry.register("OpenAI_Embeddings")
class OpenAIEmbeddings(EmbeddingsInterface):
    """
    Generate embeddings for text inputs and compare them by cosine similarity.

    This class uses the ``openai`` client to request embeddings for a list of
    text inputs, computes a cosine similarity matrix between those embeddings,
    and can visualise the similarities as a heatmap.

    Parameters
    ----------
    input : List[str]
        Default list of text strings to generate embeddings for. Used by the
        methods below whenever they are called without an explicit list.
    model : str, optional
        The embedding model to use, defaults to "bedrock-cohere-embed-eng-v3".
    **kwargs
        Extra metadata stored on the instance and merged into ``model_info``.

    Methods
    -------
    load_api_info()
        Not defined in this class (provided elsewhere, presumably by
        ``EmbeddingsInterface``); returns the ``(base_url, api_key)`` pair used
        to build the client.
    generate(input_ls=None)
        Runs the complete pipeline from embeddings to heatmap.
    get_embeddings(input_texts=None)
        Generates embeddings for the input texts via the API.
    similarity_matrix(embeddings, labels=None)
        Computes the labelled cosine similarity matrix between embeddings.
    similarity_heatmap(similarity_df)
        Generates a heatmap visualisation of the similarity matrix.
    model_info
        Property returning a dict describing the model.

    Returns
    -------
    ``generate`` returns a tuple ``(fig, ax)`` with the matplotlib figure and
    axes containing the similarity heatmap visualisation.

    Example
    -------
    >>> texts = ["Hello world", "Hi there", "Completely different topic"]
    >>> embedder = OpenAIEmbeddings(input=texts)  # doctest: +SKIP
    >>> fig, ax = embedder.generate(texts)  # doctest: +SKIP
    """

    @with_logger
    def __init__(
        self, input: List[str], model: str = "bedrock-cohere-embed-eng-v3", **kwargs
    ):
        """
        Store configuration and build the ``OpenAI`` client.

        Raises
        ------
        Exception
            Whatever ``OpenAI(...)`` raises is logged and re-raised unchanged.
        """
        super().__init__()
        logger.info(
            f"Initialising OpenAI Embeddings with model {model} and input list of size {len(input)}"
        )
        self.input = input  # Default texts to generate embeddings for
        self.model = model
        self.kwargs = kwargs

        # NOTE(review): a new client is built per instance; consider whether a
        # shared client could be reused instead.
        base_url, api_key = self.load_api_info()
        # NOTE(review): defined outside this class; presumably exports the
        # credentials for other consumers — confirm against the base class.
        self.configure_environment(base_url, api_key)
        logger.info(f"Using base URL: {base_url}")

        try:
            self.client = OpenAI(
                base_url=base_url,
                api_key=api_key,
            )
            logger.info("OpenAI client initialised successfully")
        except Exception as e:
            logger.error(f"Failed to initialise OpenAI client: {str(e)}", exc_info=True)
            raise

    @with_logger
    def generate(self, input_ls: Optional[List[str]] = None):
        """
        Run the full pipeline: embed texts, build the similarity matrix, plot it.

        Parameters
        ----------
        input_ls : List[str], optional
            Texts to embed. When ``None``, falls back to ``self.input``.

        Returns
        -------
        tuple
            ``(fig, ax)`` — the matplotlib figure and the axes with the heatmap.
        """
        # perf_counter is monotonic, so the measured duration is unaffected by
        # wall-clock adjustments during the run.
        start_time = time.perf_counter()

        # Resolve the texts once so the embeddings and the heatmap labels are
        # guaranteed to come from the same list.
        texts = input_ls if input_ls is not None else self.input
        embeddings = self.get_embeddings(texts)
        similarity_df = self.similarity_matrix(embeddings, labels=texts)
        fig, ax = self.similarity_heatmap(similarity_df)

        time_taken = time.perf_counter() - start_time
        logger.info(
            f"✅ Embeddings heatmap generation: Total time taken = {time_taken:.2f} seconds"
        )
        return fig, ax

    def get_embeddings(self, input_texts: Optional[List[str]] = None) -> Any:
        """
        Request embeddings for ``input_texts`` from the embeddings endpoint.

        Parameters
        ----------
        input_texts : List[str], optional
            Texts to embed. Falls back to ``self.input`` when ``None`` (the
            standalone use case where texts were supplied at construction).

        Returns
        -------
        Any
            The raw response of ``client.embeddings.create``; its ``data``
            attribute is expected to hold one item per text, each with an
            ``embedding`` vector.
        """
        input_ls = input_texts if input_texts is not None else self.input
        return self.client.embeddings.create(model=self.model, input=input_ls)

    def similarity_matrix(
        self, embeddings, labels: Optional[List[str]] = None
    ) -> pd.DataFrame:
        """
        Build a labelled cosine similarity matrix from an embeddings response.

        Parameters
        ----------
        embeddings
            Response object whose ``data`` items each expose ``.embedding``.
            NOTE(review): assumes ``data`` is ordered like the input texts.
        labels : List[str], optional
            Row/column labels, normally the texts that were embedded. Defaults
            to ``self.input`` for backward compatibility.

        Returns
        -------
        pd.DataFrame
            Square matrix of pairwise cosine similarities indexed by ``labels``.

        Raises
        ------
        ValueError
            If the number of labels differs from the number of embeddings.
        """
        vectors = [item.embedding for item in embeddings.data]
        if labels is None:
            labels = self.input
        # Fail with a clear message instead of pandas' opaque shape error (or a
        # silently mislabelled matrix) when labels don't match the embeddings.
        if len(labels) != len(vectors):
            raise ValueError(
                f"Got {len(vectors)} embeddings but {len(labels)} labels; "
                "pass the texts that were embedded via `labels`."
            )
        matrix = cosine_similarity(vectors)
        return pd.DataFrame(matrix, index=labels, columns=labels)

    def similarity_heatmap(self, similarity_df):
        """
        Plot ``similarity_df`` as an annotated heatmap.

        The colour scale is fixed to [-1, 1], the full range of cosine
        similarity. The figure is returned rather than shown so callers decide
        whether to display or save it.

        Returns
        -------
        tuple
            ``(fig, ax)`` — the matplotlib figure and the heatmap axes.
        """
        fig, ax = plt.subplots(figsize=(10, 10))
        ax = sns.heatmap(
            similarity_df, annot=True, cmap="coolwarm", vmin=-1, vmax=1, ax=ax
        )
        ax.set_title(
            f"Embedding Model = {self.model}; Inter-embedding Cosine Similarity Heatmap",
            fontsize=16,
        )
        return fig, ax

    @property
    def model_info(self) -> Dict[str, Any]:
        """
        Get information about the model.

        Returns:
            A dictionary with the model ``name`` and ``vendor``, merged with any
            extra keyword arguments given at construction.
        """
        # NOTE(review): kwargs are merged last, so a "name" or "vendor" kwarg
        # overrides the defaults — confirm this is intended.
        return {
            "name": self.model,
            "vendor": "OpenAI",
            **self.kwargs,
        }
