import numpy as np
import matplotlib.pyplot as plt
import umap
from pathlib import Path
from sklearn.metrics import silhouette_score, adjusted_rand_score, normalized_mutual_info_score
from sklearn.mixture import GaussianMixture

def load_data(embeddings_folder, model, dataset):
    """Load the saved embeddings and labels for one model/dataset pair.

    Two NumPy files are read from ``embeddings_folder``:
    ``embeddings_{model}_{dataset}.npy`` and ``labels_{dataset}.npy``.
    The labels file name depends only on the dataset, not on the model.

    Args:
        embeddings_folder: Directory (string or path-like) containing the files.
        model: Model name used in the embeddings file name.
        dataset: Dataset name used in both file names.

    Returns:
        Tuple ``(X, y)``: the array loaded from the embeddings file followed
        by the array loaded from the labels file.
    """
    base_dir = Path(embeddings_folder)
    embeddings = np.load(base_dir / f"embeddings_{model}_{dataset}.npy")
    labels = np.load(base_dir / f"labels_{dataset}.npy")
    return embeddings, labels

def encode_labels(y):
    """Map labels to consecutive integer codes in sorted-label order.

    The smallest distinct label becomes 0, the next one 1, and so on, so the
    encoding is deterministic for a given set of labels.

    Args:
        y: One-dimensional sequence or array of labels of a single sortable
            type (for example all strings or all integers).

    Returns:
        Integer NumPy array with one code per element of ``y``. An empty
        input yields an empty integer array.
    """
    # np.unique sorts the distinct labels; with return_inverse=True it also
    # returns, for every element, its index into that sorted list -- which is
    # exactly the code we want, computed without a Python-level loop.
    _, codes = np.unique(np.asarray(y), return_inverse=True)
    return codes

def umap_project(X, n_components=2, random_state=5):
    """Project ``X`` into a low-dimensional space with UMAP.

    Args:
        X: Data handed directly to ``umap.UMAP.fit_transform``
            (presumably a samples-by-features array -- confirm with caller).
        n_components: Number of output dimensions. Defaults to 2; the
            plotting code in this file reads columns 0 and 1 of the result.
        random_state: Seed passed to UMAP. Defaults to 5, the value this
            function previously hard-coded.

    Returns:
        The embedding returned by ``fit_transform``.
    """
    reducer = umap.UMAP(n_components=n_components, random_state=random_state)
    return reducer.fit_transform(X)

def _scatter_by_label(X_umap, labels, cmap, n_colors, title, colorbar_label):
    """Draw and show one scatter of the first two columns of ``X_umap`` coloured by ``labels``."""
    plt.figure(figsize=(8, 6))
    plt.scatter(
        X_umap[:, 0],
        X_umap[:, 1],
        c=labels,
        cmap=cmap,
        # Pin the colour range so integer k falls in the middle of colour
        # bin k. Without this, each figure is normalised to its own label
        # range and the same integer can get different colours in the
        # cluster plot and the true-label plot.
        vmin=-0.5,
        vmax=n_colors - 0.5,
        s=5,
        alpha=0.7,
    )
    plt.title(title)
    plt.xlabel("UMAP-1")
    plt.ylabel("UMAP-2")
    plt.colorbar(label=colorbar_label)
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def cluster_and_plot(X, X_umap, y_numeric, name, n_clusters=6, random_state=23):
    """Cluster the UMAP projection with a GMM, plot it, and print quality scores.

    A full-covariance Gaussian mixture is fitted on ``X_umap``. Two figures
    are shown -- points coloured by GMM cluster and points coloured by true
    label -- using one shared discrete colormap so that a given integer has
    the same colour in both. Four scores are then printed: the silhouette of
    the GMM labels in UMAP space, the silhouette of the same labels measured
    on ``X``, and ARI and NMI between ``y_numeric`` and the GMM labels.

    Args:
        X: Original embeddings, used only for the "Raw" silhouette score.
        X_umap: Projection to cluster and plot; columns 0 and 1 are the
            plotted axes.
        y_numeric: NumPy array of integer-encoded true labels.
        name: Text used in plot titles and as the prefix of printed lines.
        n_clusters: Number of mixture components. Defaults to 6.
        random_state: Seed for the Gaussian mixture. Defaults to 23, the
            value this function previously hard-coded.

    Returns:
        None. Results are shown as figures and printed to stdout.
    """
    gmm = GaussianMixture(
        n_components=n_clusters,
        covariance_type='full',
        random_state=random_state,
    )
    gmm_labels = gmm.fit_predict(X_umap)

    # One colour per integer that can appear in either labelling.
    n_colors = int(max(y_numeric.max(), gmm_labels.max())) + 1
    cmap = plt.get_cmap('tab20', n_colors)

    _scatter_by_label(X_umap, gmm_labels, cmap, n_colors, f"GMM Clusters - {name}", "Cluster")
    _scatter_by_label(X_umap, y_numeric, cmap, n_colors, f"True Labels - {name}", "Label")

    print(f"{name} - Silhouette (UMAP): {silhouette_score(X_umap, gmm_labels):.4f}")
    # Same cluster assignment, but distances measured in the original space.
    print(f"{name} - Silhouette (Raw): {silhouette_score(X, gmm_labels):.4f}")
    print(f"{name} - ARI: {adjusted_rand_score(y_numeric, gmm_labels):.4f}")
    print(f"{name} - NMI: {normalized_mutual_info_score(y_numeric, gmm_labels):.4f}")


if __name__ == "__main__":
    # Evaluate every embedding model on the same dataset: load the saved
    # arrays, encode the labels, project with UMAP, then cluster and report.
    data_dir = "embeddings"
    dataset = "dolphin_reef_unbalanced"

    for model_name in ("dolph2vec", "biolingual", "aves_bio"):
        embeddings, raw_labels = load_data(data_dir, model_name, dataset)
        label_codes = encode_labels(raw_labels)
        projection = umap_project(embeddings)
        run_label = f"{model_name}_{dataset}"
        cluster_and_plot(embeddings, projection, label_codes, run_label)
