from scipy.stats import entropy
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
import umap
from sklearn.preprocessing import StandardScaler, LabelEncoder
import os
from sklearn.manifold import MDS


def plot_umap(x, y, outfolder):
    """
    Project `x` to two dimensions with UMAP and save a label-coloured scatter plot.

    One scatter series is drawn per distinct value in `y`, each with its own
    colour sampled from the 'tab10' colormap. The figure is written to
    `UMAP_codevectors.png` inside `outfolder` and closed afterwards.
    """
    # Fixed random_state so repeated runs produce the same embedding.
    reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=10, min_dist=0.5)
    embedding = reducer.fit_transform(x)

    y_values = np.array(y)
    classes = np.unique(y)
    palette = plt.get_cmap('tab10')(np.linspace(0, 1, len(classes)))

    fig, ax = plt.subplots(figsize=(8, 6))
    for colour, cls in zip(palette, classes):
        members = y_values == cls
        ax.scatter(
            embedding[members, 0],
            embedding[members, 1],
            label=cls,
            color=colour,
            s=10,
            alpha=0.7,
        )

    # Legend is anchored outside the axes so it never covers the points.
    ax.legend(
        title="Labels",
        bbox_to_anchor=(1.05, 1),
        loc='upper left',
        fontsize=15,
        title_fontsize=16,
        markerscale=2,
    )

    ax.set_title("UMAP Projection of Code Vectors into 2D", fontsize=16)
    ax.set_xlabel("UMAP Dimension 1", fontsize=14)
    ax.set_ylabel("UMAP Dimension 2", fontsize=14)

    fig.tight_layout()
    fig.savefig(os.path.join(outfolder, "UMAP_codevectors.png"), dpi=300)
    plt.close(fig)



def compute_information_metrics(probs: np.ndarray):
    """
    Compute H(Label), H(Label | Code) and I(Label; Code), all in bits.

    `probs` is a [num_labels x num_code_indices] matrix of non-negative
    weights. Each column is normalised into P(label | code) and the column
    totals are normalised into P(code). Consequently a raw co-occurrence count
    matrix weights each code by how often it occurs, while a matrix whose
    columns already sum to 1 weights every code equally.

    Columns with zero total (codes that never occur) carry no probability
    mass; they are skipped instead of being divided 0/0, which would turn
    every metric into NaN.

    Returns a dict with the keys "H_Label", "H_Label_given_Code" and
    "I(Label; Code)". If the matrix contains no positive mass at all, the
    distributions are undefined and every value is NaN.
    """
    weights = np.asarray(probs, dtype=float)

    # Unnormalised mass per code index (sum over labels) and overall.
    code_mass = weights.sum(axis=0)
    total_mass = code_mass.sum()

    # Written as `not > 0` so that a NaN total also lands in this branch.
    if not total_mass > 0:
        undefined = np.float64("nan")
        return {
            "H_Label": undefined,
            "H_Label_given_Code": undefined,
            "I(Label; Code)": undefined
        }

    # Only codes that actually occur contribute to the conditional entropy.
    occupied = code_mass > 0
    P_code = code_mass[occupied] / total_mass

    # scipy's entropy normalises each column itself and treats 0 * log(0) as 0,
    # so no clipping is needed to avoid log(0).
    per_code_entropy = entropy(weights[:, occupied], base=2, axis=0)

    # H(Label | Code) = sum_c P(c) * H(Label | Code = c)
    H_label_given_code = np.sum(P_code * per_code_entropy)

    # P(label): marginal over codes
    P_label = weights.sum(axis=1) / total_mass

    # H(Label)
    H_label = entropy(P_label, base=2)

    # I(Label; Code) = H(Label) - H(Label | Code)
    I_label_code = H_label - H_label_given_code

    return {
        "H_Label": H_label,
        "H_Label_given_Code": H_label_given_code,
        "I(Label; Code)": I_label_code
    }


def compute_cooccurrence_matrix(all_indices, labels):
    """
    Count how often each first-codebook index co-occurs with each label.

    `labels` and `all_indices` are iterated in lockstep: every code found in
    one entry of `all_indices` is attributed to the matching label. Only the
    first codebook entry of each code (`code[0]`) is counted; further
    codebooks are ignored.

    Each entry of `all_indices` must be an array-like offering `squeeze`,
    `reshape`, `shape` and `ndim`. It is normally expected to squeeze to a
    2-D (frames, codebooks) array. When it does not (for example a
    single-frame sample, whose frame axis `squeeze` would also remove), the
    entry is reshaped to (-1, size of last axis), i.e. the last axis is
    assumed to be the codebook axis.

    Returns:
        mat: [num_labels x num_code_indices] array of raw counts. Rows and
            columns are ordered by first appearance of the label / index.
        probs: `mat` with every column normalised to sum to 1,
            i.e. P(label | code index).
        labels: the distinct labels in row order.
    """
    label_to_idx = {}
    index_to_idx = {}
    pair_counts = defaultdict(int)  # (label row, code column) -> count

    for label, indices in zip(labels, all_indices):
        label_idx = label_to_idx.setdefault(label, len(label_to_idx))

        codes = indices.squeeze()
        if codes.ndim != 2:
            # squeeze() removed (or failed to leave) the frame axis; rebuild a
            # (frames, codebooks) view from the original shape.
            codes = indices.reshape(-1, indices.shape[-1])

        for code in codes:
            first_idx = code[0]  # codebook (0: first CB 1: second CB)
            # Use a plain Python number as the dict key so equal indices always
            # share a column, even for scalar types that do not hash by value.
            if hasattr(first_idx, "item"):
                first_idx = first_idx.item()
            code_idx = index_to_idx.setdefault(first_idx, len(index_to_idx))
            pair_counts[label_idx, code_idx] += 1

    mat = np.zeros((len(label_to_idx), len(index_to_idx)))
    for (label_idx, code_idx), count in pair_counts.items():
        mat[label_idx, code_idx] = count

    # Every column was created by at least one observation, so no column sum is zero.
    probs = mat / mat.sum(axis=0, keepdims=True)

    # Insertion order of the dict equals the assigned row order.
    ordered_labels = list(label_to_idx)
    return mat, probs, ordered_labels




def plot_cooccurrence_matrix(conditional_probs, labels, outfolder):
    """
    Save a heatmap of `conditional_probs` with one row per label.

    Columns are reordered so that those sharing the same dominant label
    (row of the column maximum) sit next to each other, and within such a
    group the column with the highest peak value comes first. The figure is
    written to `Quantized_Vector_Cooccurrence_by_Whistle_Type.png` inside
    `outfolder` and closed afterwards.
    """
    dominant_label = np.argmax(conditional_probs, axis=0)
    peak_prob = np.max(conditional_probs, axis=0)

    # Single sort key: the x1000 offset separates the label groups, and the
    # subtracted peak orders each group by decreasing peak value.
    column_order = np.argsort(dominant_label * 1000 - peak_prob)
    reordered = conditional_probs[:, column_order]

    fig, ax = plt.subplots(figsize=(12, 4))
    sns.heatmap(
        reordered,
        cmap='Blues',
        xticklabels=False,
        yticklabels=labels,
        cbar_kws={"label": "Probability", "shrink": 0.8},
        ax=ax,
    )
    ax.set_title("Quantized Vector Co-occurrence by Whistle Type", fontsize=16)
    ax.set_yticklabels(labels, fontsize=12, rotation=0)

    # Enlarge the colorbar tick labels and caption.
    colorbar = ax.collections[0].colorbar
    colorbar.ax.tick_params(labelsize=12)
    colorbar.set_label("Probability", fontsize=14)

    fig.tight_layout()
    fig.savefig(
        os.path.join(outfolder, "Quantized_Vector_Cooccurrence_by_Whistle_Type.png"),
        dpi=300,
    )
    plt.close(fig)