import numpy as np
from scipy.linalg import svd
from scipy.stats import entropy
import torch
import pandas as pd
import itertools
from sklearn.metrics import adjusted_rand_score, silhouette_score, davies_bouldin_score
from bertopic_utils import get_custom_bertopic_config
from sklearn.metrics import silhouette_score, davies_bouldin_score

# --- geometry metrics ---
def calculate_effective_rank(E):
    """Return the effective rank of the matrix ``E``.

    The singular values of ``E`` are rescaled to sum to one and treated as a
    discrete probability distribution; the result is the exponential of that
    distribution's entropy as computed by ``scipy.stats.entropy``.
    """
    singular_values = svd(E, full_matrices=False)[1]
    spectrum = singular_values / singular_values.sum()
    spectral_entropy = entropy(spectrum)
    return np.exp(spectral_entropy)

def calculate_uniformity(E, t=2):
    """Return ``log(mean(exp(-t * d)))`` over all row pairs of ``E``.

    A torch tensor is detached and converted to a CPU NumPy array first;
    any other input is used as-is.

    The pairwise term is ``d_ij = 2 - 2 * <e_i, e_j>``. This equals the
    squared Euclidean distance only when the rows have unit L2 norm; the
    rows are not normalised here, so that is left to the caller. The mean
    runs over the full N x N matrix, so the diagonal self-pairs are
    included.

    ``t`` scales the pairwise term inside the exponential.
    """
    points = E.detach().cpu().numpy() if isinstance(E, torch.Tensor) else E
    gram = points @ points.T
    pairwise_term = 2 - 2 * gram
    kernel = np.exp(-t * pairwise_term)
    return np.log(kernel.mean())

def get_isotropy_coeff(E):
    """Return the smallest-to-largest singular value ratio of centred ``E``.

    ``E`` is centred by subtracting its mean along axis 0 before the SVD.
    """
    centred = E - E.mean(0)
    spectrum = svd(centred, full_matrices=False)[1]
    smallest = spectrum.min()
    largest = spectrum.max()
    return smallest / largest

# --- clustering metrics ---
def clustering_quality_metrics(embeddings, labels):
    """Compute silhouette and Davies-Bouldin scores, ignoring outliers.

    Samples labelled ``-1`` are dropped before scoring.

    Args:
        embeddings: Array-like of shape (n_samples, n_features).
        labels: Array-like of n_samples integer cluster labels. Plain Python
            lists are accepted.

    Returns:
        Dict with keys ``"silhouette"`` and ``"davies_bouldin"``. Both are
        ``np.nan`` when the remaining samples cannot be scored (fewer than
        two clusters, fewer than five samples, or as many clusters as
        samples).
    """
    # Coerce first: on a plain list ``labels != -1`` is a single bool, not an
    # element-wise mask, and indexing with it silently selects the wrong data.
    labels = np.asarray(labels)
    embeddings = np.asarray(embeddings)

    keep = labels != -1
    E = embeddings[keep]
    y = labels[keep]

    n_samples = len(y)
    n_clusters = len(np.unique(y))

    # Edge case: not enough clusters/samples to compute metrics. scikit-learn
    # requires 2 <= n_clusters <= n_samples - 1 and raises otherwise.
    if n_clusters < 2 or n_samples < 5 or n_clusters >= n_samples:
        return {"silhouette": np.nan, "davies_bouldin": np.nan}

    return {
        "silhouette": silhouette_score(E, y, metric="euclidean"),
        "davies_bouldin": davies_bouldin_score(E, y),
    }

def clustering_stability(clusterings):
    """Summarise pairwise agreement between several clusterings.

    The adjusted Rand index is computed for every unordered pair of
    clusterings.

    Args:
        clusterings: Sequence of label sequences, one per clustering.

    Returns:
        Tuple ``(mean, std)`` of the pairwise adjusted Rand indices. Both are
        NaN when fewer than two clusterings are given, since there is no pair
        to compare.
    """
    aris = [
        adjusted_rand_score(first, second)
        for first, second in itertools.combinations(clusterings, 2)
    ]
    if not aris:
        # Return NaN explicitly instead of letting np.mean/np.std warn on an
        # empty list.
        return np.float64(np.nan), np.float64(np.nan)
    return np.mean(aris), np.std(aris)

def run_downstream(df, dims, device, bertopic_model, tau=1.0,
                   lambda_orth=0.1, n_runs=15, epochs=50, save_labels=False):
    """Run repeated clustering over three embedding variants and score them.

    Each run trains two embedding variants with ``seed=run`` through
    ``main.train_and_extract_embeddings`` (``use_orth=False`` for "MV-CL",
    ``use_orth=True`` for "MV-CLOR") and also uses the stored ``text_emb``
    column as the "CamemBERT-v2" baseline. Every variant is clustered with
    ``bertopic_model.fit_transform(texts, embeddings)`` and scored.

    Args:
        df: DataFrame with at least the columns ``"text"`` and ``"text_emb"``;
            it is also forwarded to ``train_and_extract_embeddings``.
        dims: Forwarded to ``train_and_extract_embeddings``.
        device: Forwarded to ``train_and_extract_embeddings``.
        bertopic_model: Object whose ``fit_transform(texts, embeddings)``
            returns the topic labels as the first element of a pair.
        tau: Forwarded to ``train_and_extract_embeddings``.
        lambda_orth: Forwarded to the ``use_orth=True`` training call.
        n_runs: Number of repetitions; the run index is used as the seed.
        epochs: Forwarded to ``train_and_extract_embeddings``.
        save_labels: When true, the per-run labels are returned as well.

    Returns:
        ``(df_results, df_stability)``, or
        ``(df_results, df_stability, all_labels)`` when ``save_labels`` is
        true. ``df_results`` has one row per (method, run) with the columns
        ``method``, ``run``, ``n_topics``, ``outlier_ratio`` and the keys
        returned by ``clustering_quality_metrics``. ``df_stability`` has one
        row per method with ``ARI_mean`` and ``ARI_std``. ``all_labels`` maps
        each method to the list of label objects returned by
        ``fit_transform``.
    """
    # Kept as a function-level import as in the original (NOTE(review):
    # presumably to avoid a circular import with ``main`` - confirm), but
    # resolved once instead of on every run.
    from main import train_and_extract_embeddings

    methods = ("CamemBERT-v2", "MV-CL", "MV-CLOR")
    all_results = []
    clusterings_by_method = {name: [] for name in methods}
    all_labels = {name: [] for name in methods}
    texts = df["text"].tolist()

    # The baseline embeddings do not depend on the run, so build them once.
    E_text = np.vstack(df["text_emb"].values)

    for run in range(n_runs):
        print(f"Run {run+1}/{n_runs}")

        # Extraction of embeddings
        E_mvcl = train_and_extract_embeddings(
            df, dims, device, tau=tau, use_orth=False, seed=run, epochs=epochs
        )

        E_mvclor = train_and_extract_embeddings(
            df, dims, device, tau=tau, use_orth=True,
            lambda_orth=lambda_orth, seed=run, epochs=epochs
        )

        # Clustering + metrics: silhouette, Davies-Bouldin, outliers
        for method, E in zip(methods, (E_text, E_mvcl, E_mvclor)):
            topics, _ = bertopic_model.fit_transform(texts, E)
            clusterings_by_method[method].append(topics)
            all_labels[method].append(topics)

            # ``topics`` may be a plain list, on which ``topics == -1`` is a
            # single bool rather than an element-wise comparison. Use an
            # array view for the metrics; the stored labels stay untouched.
            topic_arr = np.asarray(topics)

            quality = clustering_quality_metrics(E, topic_arr)

            outlier_ratio = np.mean(topic_arr == -1)
            distinct_topics = np.unique(topic_arr)
            n_topics = int(np.sum(distinct_topics != -1))

            all_results.append({
                "method": method,
                "run": run,
                "n_topics": n_topics,
                "outlier_ratio": outlier_ratio,
                **quality,
            })

    # ARI
    stability_rows = []
    for method, clusterings in clusterings_by_method.items():
        ari_mean, ari_std = clustering_stability(clusterings)
        stability_rows.append({"method": method, "ARI_mean": ari_mean, "ARI_std": ari_std})

    df_results = pd.DataFrame(all_results)
    df_stability = pd.DataFrame(stability_rows)

    if save_labels:
        return df_results, df_stability, all_labels
    return df_results, df_stability

