import os
import openai
import pandas as pd
from typing import Union, List, Optional, Tuple
import numpy as np
import logging

# Module-level logger; handlers and levels are configured by the application.
logger = logging.getLogger(__name__)


# One million; used below as the divisor that turns a "per 1M tokens" figure
# into a per-token figure.
M = 1_000_000

# Model names that get_client_model routes to the standard OpenAI client.
OPENAI_EMBEDDING_MODELS = [
    "text-embedding-3-small",
    "text-embedding-3-large",
]

# Model names that get_client_model routes to the Azure OpenAI client. The
# "azure-" prefix is stripped before the name is sent to the API.
AZURE_EMBEDDING_MODELS = [
    "azure-text-embedding-3-small",
    "azure-text-embedding-3-large",
]

# Per-token cost for each model, keyed by the un-prefixed model name and
# multiplied by the token usage reported by the API.
# NOTE(review): presumably USD per token (list price per 1M tokens / M) --
# confirm the currency and keep in sync with current provider pricing.
OPENAI_EMBEDDING_COSTS = {
    "text-embedding-3-small": 0.02 / M,
    "text-embedding-3-large": 0.13 / M,
}


def get_client_model(model_name: str) -> tuple[openai.OpenAI, str]:
    """Resolve an embedding model name to an API client and the model id to send.

    Args:
        model_name: One of the names listed in ``OPENAI_EMBEDDING_MODELS`` or
            ``AZURE_EMBEDDING_MODELS``.

    Returns:
        A ``(client, model_to_use)`` pair. For OpenAI models the name is
        returned unchanged; for Azure models the ``"azure-"`` prefix is
        removed and the client is built from the ``AZURE_OPENAI_API_KEY``,
        ``AZURE_API_VERSION`` and ``AZURE_API_ENDPOINT`` environment variables.

    Raises:
        ValueError: If ``model_name`` is in neither list.
    """
    # Plain OpenAI: the default client is used and the name passes through.
    if model_name in OPENAI_EMBEDDING_MODELS:
        return openai.OpenAI(), model_name

    if model_name not in AZURE_EMBEDDING_MODELS:
        raise ValueError(f"Invalid embedding model: {model_name}")

    # Azure: credentials and endpoint come from the environment.
    azure_client = openai.AzureOpenAI(
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version=os.getenv("AZURE_API_VERSION"),
        azure_endpoint=os.getenv("AZURE_API_ENDPOINT"),
    )
    stripped_name = model_name.split("azure-")[-1]
    return azure_client, stripped_name


class EmbeddingClient:
    """Compute text embeddings and run simple analyses on them.

    Wraps the client returned by ``get_client_model`` and adds helpers for
    nearest-neighbour lookup, dimensionality reduction, clustering and
    plotting of embedding vectors.
    """

    def __init__(
        self, model_name: str = "text-embedding-3-small", verbose: bool = False
    ):
        """Create the API client for ``model_name``.

        Args:
            model_name: Embedding model name accepted by ``get_client_model``.
            verbose: Stored on the instance as ``self.verbose``.

        Raises:
            ValueError: Propagated from ``get_client_model`` for unknown models.
        """
        self.client, self.model = get_client_model(model_name)
        self.verbose = verbose

    def get_embedding(
        self, code: Union[str, List[str]]
    ) -> Union[Tuple[List[float], float], Tuple[List[List[float]], float]]:
        """Embed one string or a list of strings.

        Args:
            code: A single text, or a list of texts embedded in one request.

        Returns:
            ``(embedding, cost)`` for a single string, or
            ``(list_of_embeddings, cost)`` for a list. ``cost`` is the reported
            token usage multiplied by the model's per-token cost. If anything
            fails, the error is logged and ``([], 0.0)`` (single string) or
            ``([[]], 0.0)`` (list) is returned instead of raising.
        """
        single_code = isinstance(code, str)
        inputs = [code] if single_code else code
        try:
            response = self.client.embeddings.create(
                model=self.model, input=inputs, encoding_format="float"
            )
            cost = response.usage.total_tokens * OPENAI_EMBEDDING_COSTS[self.model]
            if single_code:
                return response.data[0].embedding, cost
            return [d.embedding for d in response.data], cost
        except Exception as e:
            # Best-effort API: callers receive an empty result instead of an
            # exception. Logged at ERROR so the failure is visible with a
            # default logging configuration.
            logger.error(f"Error getting embedding: {e}")
            if single_code:
                return [], 0.0
            return [[]], 0.0

    def get_column_embedding(
        self,
        df: pd.DataFrame,
        column_name: Union[str, List[str]],
    ) -> pd.DataFrame:
        """Add one embedding column per requested text column.

        For every column ``c`` a new column named
        ``f"{c}_embedding_{model}"`` (dashes in the model name replaced by
        underscores) is added, holding the embedding vector of each cell.

        Args:
            df: DataFrame to extend. It is modified in place.
            column_name: A column name or a list of column names to embed.

        Returns:
            The same ``df`` object with the new column(s) added.
        """
        columns = [column_name] if isinstance(column_name, str) else column_name
        model_name_str = self.model.replace("-", "_")

        for col in columns:
            new_col_name = f"{col}_embedding_{model_name_str}"
            # get_embedding returns (embedding, cost); only the vector belongs
            # in the embedding column.
            df[new_col_name] = df[col].apply(
                lambda text: self.get_embedding(text)[0],
            )
        return df

    @staticmethod
    def _cosine_similarities(query, candidates: list) -> np.ndarray:
        """Return the cosine similarity of ``query`` to every row of ``candidates``.

        Pairs in which either vector has zero norm get a similarity of 0.0
        instead of NaN, so they never outrank real matches when sorting.
        """
        query_vec = np.asarray(query, dtype=float)
        matrix = np.asarray(candidates, dtype=float)
        dots = matrix @ query_vec
        norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_vec)
        return np.divide(dots, norms, out=np.zeros_like(dots), where=norms > 0)

    def get_closest_k_neighbors(
        self,
        new_str_query: str,
        embeddings: list,
        top_k: Union[int, str] = 5,
    ) -> tuple[list, list]:
        """Find the stored embeddings most similar to a query string.

        Args:
            new_str_query: Text to embed and compare against ``embeddings``.
            embeddings: Sequence of embedding vectors of equal length.
            top_k: Number of nearest neighbours to return, or ``"random"`` to
                return up to five randomly chosen entries.

        Returns:
            ``(indices, similarities)``: positions in ``embeddings`` and their
            cosine similarity to the query. For an integer ``top_k`` they are
            ordered from most to least similar. Both lists are empty when the
            query could not be embedded, ``embeddings`` is empty, or
            ``top_k`` is 0.

        Raises:
            ValueError: If ``top_k`` is neither an int nor ``"random"``, or is
                a negative int.
        """
        # Validate before embedding so a bad argument does not cost a request.
        if top_k != "random":
            if not isinstance(top_k, int):
                raise ValueError("top_k must be an int or 'random'")
            if top_k < 0:
                raise ValueError("top_k must be a non-negative int or 'random'")

        new_embedding, _ = self.get_embedding(new_str_query)
        if not new_embedding:
            return [], []

        # list() also accepts a pandas Series or ndarray of vectors.
        candidates = list(embeddings)
        if not candidates:
            return [], []

        similarities = self._cosine_similarities(new_embedding, candidates)

        if top_k == "random":
            sample_size = min(5, len(similarities))
            top_idx = np.random.choice(
                len(similarities), size=sample_size, replace=False
            )
            return top_idx.tolist(), similarities[top_idx].tolist()

        # A slice of [-0:] would select everything, so handle 0 explicitly.
        if top_k == 0:
            return [], []

        top_idx = np.argsort(similarities)[-top_k:][::-1]
        return top_idx.tolist(), similarities[top_idx].tolist()

    def get_dim_reduction(
        self,
        embeddings: list,
        method: str = "pca",
        dims: int = 2,
    ):
        """Project embeddings to ``dims`` dimensions.

        The data is standardised with ``StandardScaler`` and then reduced with
        the chosen method. UMAP and t-SNE use ``random_state=42``.

        Args:
            embeddings: List, pandas Series or array of embedding vectors.
            method: ``"pca"``, ``"umap"`` or ``"tsne"`` (case-insensitive).
            dims: Number of output dimensions.

        Returns:
            The output of the reducer's ``fit_transform``.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        method_key = method.lower()
        # Reject unknown methods before doing any numerical work.
        if method_key not in ("pca", "umap", "tsne"):
            raise ValueError("Method must be one of: 'pca', 'umap', 'tsne'")

        if isinstance(embeddings, pd.Series):
            embeddings = embeddings.tolist()

        X = np.array(embeddings) if isinstance(embeddings, list) else embeddings

        from sklearn.preprocessing import StandardScaler

        scaler = StandardScaler()
        X = scaler.fit_transform(X)

        if method_key == "pca":
            from sklearn.decomposition import PCA

            model = PCA(n_components=dims)
        elif method_key == "umap":
            from umap import UMAP

            model = UMAP(n_components=dims, random_state=42)
        else:
            from sklearn.manifold import TSNE

            model = TSNE(n_components=dims, random_state=42)
        return model.fit_transform(X)

    def get_embedding_clusters(
        self,
        embeddings: list,
        num_clusters: int = 4,
        verbose: bool = False,
    ) -> list:
        """Cluster embeddings with a Gaussian mixture model.

        Args:
            embeddings: Embedding vectors to cluster.
            num_clusters: Number of mixture components.
            verbose: When true, log the size of each cluster.

        Returns:
            The cluster assignment of each embedding, as produced by
            ``GaussianMixture.predict``.
        """
        from sklearn.mixture import GaussianMixture

        gmm = GaussianMixture(n_components=num_clusters, random_state=42)
        gmm.fit(embeddings)
        clusters = gmm.predict(embeddings)

        if verbose:
            logger.info(
                f"GMM {num_clusters} Clusters ==> Got {len(embeddings)} "
                f"embeddings with cluster assignments:"
            )
            num_members = pd.Series(clusters).value_counts()
            logger.info(num_members)

        return clusters

    def plot_reduced_embeddings(
        self,
        embeddings: list,
        method: str = "pca",
        num_dims: int = 3,
        title="Embedding",
        cluster_ids: Optional[list] = None,
        cluster_label: str = "Cluster",
        patch_type: Optional[list] = None,
    ):
        """Reduce embeddings to 2 or 3 dimensions and draw a scatter plot.

        Args:
            embeddings: Embedding vectors to plot.
            method: Reduction method passed to ``get_dim_reduction``.
            num_dims: 2 or 3; selects ``plot_2d_scatter`` or
                ``plot_3d_scatter``.
            title: Plot title.
            cluster_ids: Optional per-point cluster ids used for colouring.
            cluster_label: Forwarded to the plotting function.
            patch_type: Optional per-point category used for marker shape.

        Returns:
            The ``(fig, ax)`` pair returned by the plotting function.

        Raises:
            ValueError: If ``num_dims`` is not 2 or 3.
        """
        # Check first so an invalid value does not trigger the reduction.
        if num_dims not in (2, 3):
            raise ValueError(f"Invalid number of dimensions: {num_dims}")

        transformed = self.get_dim_reduction(embeddings, method, num_dims)
        plot_fn = plot_2d_scatter if num_dims == 2 else plot_3d_scatter
        fig, ax = plot_fn(transformed, title, cluster_ids, cluster_label, patch_type)
        return fig, ax


def plot_2d_scatter(
    transformed: np.ndarray,
    title: str = "Embedding",
    cluster_ids: Optional[list] = None,
    cluster_label: str = "Cluster",
    patch_type: Optional[list] = None,
):
    """Draw a 2D scatter plot of reduced embeddings.

    Args:
        transformed: Array whose first two columns are the x and y coordinates.
        title: Plot title.
        cluster_ids: Optional per-point cluster ids. Points with the same id
            share a colour from a fixed palette (the palette repeats when
            there are more distinct ids than colours).
        cluster_label: Currently unused; kept for interface compatibility.
        patch_type: Optional per-point category. Each distinct value gets its
            own marker shape and a legend entry.

    Returns:
        The ``(fig, ax)`` pair of the created figure.
    """
    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap
    from matplotlib.lines import Line2D

    base_colors = [
        "green",
        "red",
        "blue",
        "yellow",
        "purple",
        "orange",
        "brown",
        "pink",
        "gray",
        "cyan",
    ]
    marker_shapes = ["o", "s", "^", "P", "X", "D", "v", "<", ">"]

    fig, ax = plt.subplots(figsize=(10, 7))

    # Map arbitrary cluster ids to consecutive integer codes 0..n-1 and build a
    # colormap with exactly one colour per code.
    color_codes = None
    color_kwargs = {}
    if cluster_ids is not None:
        unique_ids, color_codes = np.unique(cluster_ids, return_inverse=True)
        num_colors = max(len(unique_ids), 1)
        repeats = (num_colors - 1) // len(base_colors) + 1
        cmap = ListedColormap((base_colors * repeats)[:num_colors])
        # Fix the colour range explicitly. Without vmin/vmax every scatter
        # call normalises over its own subset of codes, so the same cluster
        # could be drawn in different colours for different patch types.
        color_kwargs = {"cmap": cmap, "vmin": -0.5, "vmax": num_colors - 0.5}

    legend_handles = []
    if patch_type is not None:
        patch_values = np.array(patch_type)
        for i, patch_val in enumerate(np.unique(patch_values)):
            mask = patch_values == patch_val
            marker = marker_shapes[i % len(marker_shapes)]

            scatter_args = {
                "marker": marker,
                "alpha": 0.6,
                "s": 100,
                "label": str(patch_val),
            }
            if color_codes is not None:
                scatter_args["c"] = color_codes[mask]
                scatter_args.update(color_kwargs)

            ax.scatter(transformed[mask, 0], transformed[mask, 1], **scatter_args)

            # Legend entries show the marker shape only (in black), because
            # colour encodes the cluster, not the patch type.
            legend_handles.append(
                Line2D(
                    [0],
                    [0],
                    marker=marker,
                    color="black",
                    label=str(patch_val),
                    linestyle="None",
                    markersize=10,
                )
            )
    else:
        scatter_args = {"marker": "o", "alpha": 0.6, "s": 100}
        if color_codes is not None:
            scatter_args["c"] = color_codes
            scatter_args.update(color_kwargs)
        ax.scatter(transformed[:, 0], transformed[:, 1], **scatter_args)

    ax.set_xlabel("1st Latent Dim.", fontsize=20)
    ax.set_ylabel("2nd Latent Dim.", fontsize=20)
    ax.set_title(title, fontsize=30)

    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)

    if legend_handles:
        ax.legend(handles=legend_handles, title="Patch Types", loc="best")

    fig.tight_layout()

    return fig, ax


def plot_3d_scatter(
    transformed: np.ndarray,
    title: str = "Embedding",
    cluster_ids: Optional[list] = None,
    cluster_label: str = "Cluster",
    patch_type: Optional[list] = None,
):
    """Draw a 3D scatter plot of reduced embeddings.

    Args:
        transformed: Array whose first three columns are the x, y and z
            coordinates.
        title: Plot title.
        cluster_ids: Optional per-point cluster ids. Points with the same id
            share a colour from a fixed palette (the palette repeats when
            there are more distinct ids than colours).
        cluster_label: Currently unused; kept for interface compatibility.
        patch_type: Optional per-point category. Each distinct value gets its
            own marker shape and a legend entry.

    Returns:
        The ``(fig, ax)`` pair of the created figure.
    """
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D
    from matplotlib.colors import ListedColormap

    base_colors = [
        "green",
        "red",
        "blue",
        "yellow",
        "purple",
        "orange",
        "brown",
        "pink",
        "gray",
        "cyan",
    ]
    marker_shapes = ["o", "s", "^", "P", "X", "D", "v", "<", ">"]

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection="3d", computed_zorder=False)

    # Map arbitrary cluster ids to consecutive integer codes 0..n-1 and build a
    # colormap with exactly one colour per code.
    color_codes = None
    color_kwargs = {}
    if cluster_ids is not None:
        unique_ids, color_codes = np.unique(cluster_ids, return_inverse=True)
        num_colors = max(len(unique_ids), 1)
        repeats = (num_colors - 1) // len(base_colors) + 1
        cmap = ListedColormap((base_colors * repeats)[:num_colors])
        # Fix the colour range explicitly. Without vmin/vmax every scatter
        # call normalises over its own subset of codes, so the same cluster
        # could be drawn in different colours for different patch types.
        color_kwargs = {"cmap": cmap, "vmin": -0.5, "vmax": num_colors - 0.5}

    legend_handles_3d = []
    if patch_type is not None:
        patch_values = np.array(patch_type)
        for i, patch_val in enumerate(np.unique(patch_values)):
            mask = patch_values == patch_val
            marker = marker_shapes[i % len(marker_shapes)]

            scatter_args = {
                "marker": marker,
                "alpha": 0.6,
                "s": 20,
                "label": str(patch_val),
            }
            if color_codes is not None:
                scatter_args["c"] = color_codes[mask]
                scatter_args.update(color_kwargs)

            ax.scatter(
                transformed[mask, 0],
                transformed[mask, 1],
                transformed[mask, 2],
                **scatter_args,
            )

            # Legend entries show the marker shape only (in black), because
            # colour encodes the cluster, not the patch type.
            legend_handles_3d.append(
                Line2D(
                    [0],
                    [0],
                    marker=marker,
                    color="black",
                    label=str(patch_val),
                    linestyle="None",
                    markersize=10,
                )
            )
    else:
        scatter_args = {"marker": "o", "alpha": 0.6, "s": 20}
        if color_codes is not None:
            scatter_args["c"] = color_codes
            scatter_args.update(color_kwargs)
        ax.scatter(
            transformed[:, 0],
            transformed[:, 1],
            transformed[:, 2],
            **scatter_args,
        )

    ax.set_xlabel("1st Latent Dim.", labelpad=-15, fontsize=8)
    ax.set_ylabel("2nd Latent Dim.", labelpad=-15, fontsize=8)
    ax.set_zlabel("3rd Latent Dim.", labelpad=-17, rotation=90, fontsize=8)
    ax.set_title(title, y=0.95)

    if legend_handles_3d:
        ax.legend(
            handles=legend_handles_3d,
            title="Patch Types",
            loc="best",
            bbox_to_anchor=(0.9, 0.5),
        )

    ax.view_init(elev=20, azim=45)

    # Adjust this figure explicitly rather than whichever figure pyplot
    # currently considers active.
    fig.subplots_adjust(left=0.05, right=0.9, top=0.9, bottom=0.05)
    fig.tight_layout()
    return fig, ax
