import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score, davies_bouldin_score, adjusted_rand_score
from scipy.linalg import svd
from scipy.stats import entropy
import itertools
from bertopic import BERTopic

from model import MultiViewEncoder
from loss import l_mv_cl, l_global, l_or
from view_extraction import SymbolicViewExtractor, ConceptEncoder
from evaluation import calculate_effective_rank, calculate_uniformity, get_isotropy_coeff, clustering_stability, run_downstream
from bertopic_utils import get_custom_bertopic_config

# --- Configuration ---
# Embedding width of each input view. The key order also fixes the order in
# which HMVCLDataset iterates the views.
INPUT_DIMS = dict(text=768, clinical=64, procedure=32, pharma=64)
D_SHARED = 64  # presumably the shared latent width (cf. D=64 for MultiViewEncoder) — confirm
BATCH_SIZE = 32
EPOCHS = 50
# Use Apple's Metal (MPS) backend when available, otherwise fall back to CPU.
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# --- Data ---
class HMVCLDataset(Dataset):
    """Dataset holding one pre-stacked float tensor per view plus the raw texts.

    For every view name in ``INPUT_DIMS`` the column ``f"{view}_emb"`` is
    stacked once into a single float tensor, so ``__getitem__`` only has to
    index into those tensors.
    """

    def __init__(self, df):
        # The DataFrame is expected to already hold the vectors of each view [4].
        stacked = {}
        for view in INPUT_DIMS.keys():
            column_values = df[f"{view}_emb"].values
            stacked[view] = torch.tensor(np.stack(column_values), dtype=torch.float)
        self.data = stacked
        self.texts = df["text"].tolist()

    def __len__(self):
        # The number of texts defines the dataset length.
        return len(self.texts)

    def __getitem__(self, idx):
        # Row ``idx`` of every view, keyed by view name in INPUT_DIMS order.
        sample = {}
        for view in INPUT_DIMS.keys():
            sample[view] = self.data[view][idx]
        return sample

# --- Geometric metrics ---
def get_geometry_metrics(E):
    """Compute the geometry metrics of a latent embedding space.

    Args:
        E: embedding matrix with one row per sample, either a ``torch.Tensor``
            or any array-like accepted by ``np.asarray``. It is forwarded
            unchanged to the project metric helpers.

    Returns:
        dict with keys:
            "IC": isotropy coefficient from ``get_isotropy_coeff``.
            "EffRank": effective rank from ``calculate_effective_rank``.
            "Unif": uniformity from ``calculate_uniformity``.
            "Var10": fraction of variance explained by the top PCA components
                (at most 10; fewer when E has fewer rows or columns).
            "Var1": fraction of variance explained by the first PCA component.
    """
    # Accept both torch tensors and numpy arrays (run_experiment produces numpy).
    if isinstance(E, torch.Tensor):
        E_np = E.detach().cpu().numpy()
    else:
        E_np = np.asarray(E)
    n_samples, dim = E_np.shape

    # Isotropic Coefficient (IC)
    ic = get_isotropy_coeff(E)

    # Effective Rank (RankMe)
    eff_rank = calculate_effective_rank(E)

    # Uniformity
    uniformity = calculate_uniformity(E)

    # Explained variance (PCA). PCA cannot request more components than
    # min(n_samples, n_features), so clamp on both dimensions.
    pca = PCA(n_components=min(10, n_samples, dim))
    pca.fit(E_np)
    ratios = pca.explained_variance_ratio_
    var_10 = ratios.sum()
    # First component only; previously the whole ratio array was returned here.
    var_1 = ratios[0]

    return {"IC": ic, "EffRank": eff_rank, "Unif": uniformity, "Var10": var_10, "Var1": var_1}

# --- Clustering metrics ---
def get_clustering_metrics(E, topics):
    """Compute Silhouette, Davies-Bouldin and the outlier ratio of a clustering.

    Points whose topic is -1 (outliers) are excluded from Silhouette and
    Davies-Bouldin, but counted in the outlier ratio.

    Args:
        E: embedding matrix (torch tensor or array-like), one row per point.
        topics: one integer topic label per row of ``E``; -1 marks an outlier.
            Python lists are accepted as well as numpy arrays.

    Returns:
        dict with keys "silhouette", "db" and "outliers" (percentage of points
        labelled -1). When the non-outlier points do not satisfy sklearn's
        ``2 <= n_labels <= n_samples - 1`` requirement, both scores are
        undefined and reported as 0.0.
    """
    if isinstance(E, torch.Tensor):
        E = E.detach().cpu().numpy()
    E = np.asarray(E)
    # A plain list compared to -1 would give a single bool, not a mask.
    topics = np.asarray(topics)

    mask = topics != -1
    outlier_pct = float((~mask).mean() * 100) if topics.size else 100.0

    E_filtered = E[mask]
    y_filtered = topics[mask]

    # silhouette_score and davies_bouldin_score both raise ValueError unless
    # 2 <= n_labels <= n_samples - 1.
    n_labels = np.unique(y_filtered).size
    if not 2 <= n_labels <= len(y_filtered) - 1:
        return {"silhouette": 0.0, "db": 0.0, "outliers": outlier_pct}

    return {
        "silhouette": silhouette_score(E_filtered, y_filtered),
        "db": davies_bouldin_score(E_filtered, y_filtered),
        "outliers": outlier_pct,
    }

# --- Training loop + embedding extraction ---
class EmbeddingDataset(Dataset):
    """Map-style dataset serving one float tensor per view for each DataFrame row.

    Reads the columns ``text_emb``, ``clinical_emb``, ``procedure_emb`` and
    ``pharma_emb`` into Python lists; conversion to tensors happens lazily
    in ``__getitem__``.
    """

    def __init__(self, df):
        # Keep every view as a plain list of per-row vectors.
        self.text, self.clinical, self.procedure, self.pharma = (
            df[column].tolist()
            for column in ("text_emb", "clinical_emb", "procedure_emb", "pharma_emb")
        )

    def __len__(self):
        # The text view defines the dataset length.
        return len(self.text)

    def __getitem__(self, idx):
        named_views = (
            ("text", self.text),
            ("clinical", self.clinical),
            ("procedure", self.procedure),
            ("pharma", self.pharma),
        )
        return {name: torch.tensor(rows[idx], dtype=torch.float) for name, rows in named_views}

def train_and_extract_embeddings(
    df, dims, device, tau=0.1, use_orth=False, lambda_orth=10, seed=0, epochs=50,
    batch_size=32, d_shared=64, lr=1e-4, weight_decay=1e-5,
):
    """Train a MultiViewEncoder on ``df`` and return the fused embeddings ``e_k``.

    Args:
        df: DataFrame with the columns read by ``EmbeddingDataset``.
        dims: per-view input dimensions, forwarded to ``MultiViewEncoder``.
        device: torch device used for training and extraction.
        tau: temperature passed to ``l_mv_cl`` and ``l_global``.
        use_orth: add the orthogonality term ``lambda_orth * l_or(z_views)``.
        lambda_orth: weight of the orthogonality term.
        seed: seed for torch and numpy RNGs (also drives DataLoader shuffling).
        epochs: number of passes over the training data.
        batch_size: batch size for both training and extraction.
        d_shared: value passed as ``D`` to ``MultiViewEncoder``.
        lr, weight_decay: AdamW hyper-parameters.

    Returns:
        numpy array obtained by concatenating ``e_k`` over all batches of an
        unshuffled pass, i.e. rows follow the DataFrame row order.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    dataset = EmbeddingDataset(df)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = MultiViewEncoder(dims, D=d_shared).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    model.train()
    for _ in range(epochs):
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()

            e_k, z_views = model(batch)

            # 1. Multi-View Contrastive Loss (Eq. 2)
            loss_mv = l_mv_cl(z_views, tau=tau)

            # 2. Global Alignment Loss (Eq. 3)
            # l_global receives the whole z_views mapping, so a single call per
            # batch gives the same value as summing it once per view and then
            # averaging (up to float rounding), at a fraction of the cost.
            # NOTE(review): if Eq. 3 is meant to average a per-view term, call
            # l_global with z_views[view] for each view instead — confirm
            # against loss.l_global.
            loss_global = l_global(e_k, z_views, tau=tau)

            total_loss = loss_mv + loss_global

            # 3. Orthogonality Regularization (Eq. 5)
            if use_orth:
                total_loss = total_loss + lambda_orth * l_or(z_views)

            total_loss.backward()
            optimizer.step()

    # Final e_k extraction: deterministic, unshuffled pass without gradients.
    model.eval()
    extraction_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    all_e_k = []
    with torch.no_grad():
        for batch in extraction_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            e_k, _ = model(batch)
            all_e_k.append(e_k.cpu())

    return torch.cat(all_e_k, dim=0).numpy()

# --- geometry experiment ---
def run_experiment(df, dims, device, tau=0.1, use_orth=False, lambda_orth=10, seed=0, epochs=50):
    """Train one model and report geometry metrics of its fused embeddings.

    Args:
        df, dims, device, tau, use_orth, lambda_orth: forwarded to
            ``train_and_extract_embeddings``.
        seed: RNG seed for the training run (default 0, as before).
        epochs: number of training epochs (default 50, as before).

    Returns:
        (E_final, metrics): the numpy embedding matrix returned by
        ``train_and_extract_embeddings`` and a ``pd.Series`` with the keys
        "ic", "effective_rank" and "uniformity".
    """
    E_final = train_and_extract_embeddings(
        df, dims, device, tau=tau, use_orth=use_orth,
        lambda_orth=lambda_orth, seed=seed, epochs=epochs,
    )

    # Geometry metrics of the learned space.
    metrics = pd.Series({
        "ic": get_isotropy_coeff(E_final),
        "effective_rank": calculate_effective_rank(E_final),
        "uniformity": calculate_uniformity(E_final),
    })

    return E_final, metrics

# --- exec ---
if __name__ == "__main__":
    # Load the data: expects the *_emb embedding columns and a "text" column.
    # NOTE: read_pickle executes arbitrary code from the file — only load trusted files.
    df = pd.read_pickle("data_dummy.pkl")

    # Single source of truth for the run settings shared below.
    n_runs = 15
    tau = 1.0

    vec, empty_dim_model, hdbscan_model = get_custom_bertopic_config()
    tm = BERTopic(
        language='multilingual',
        vectorizer_model=vec,
        umap_model=empty_dim_model,
        hdbscan_model=hdbscan_model,
        embedding_model=None,
        calculate_probabilities=True,
        verbose=True
    )

    print(f"--- Starting HMV-CL Training & Evaluation ({n_runs} runs) ---")

    df_results_E2, df_stab, all_labels = run_downstream(
        df,
        dims=INPUT_DIMS,
        device=DEVICE,
        bertopic_model=tm,
        tau=tau,
        lambda_orth=0.0,
        n_runs=n_runs,
        epochs=EPOCHS,
        save_labels=True
    )
    _, geo_results = run_experiment(df, dims=INPUT_DIMS, device=DEVICE, tau=tau, use_orth=False)

    print("\n" + "="*40)
    print("EMBEDDING SPACE GEOMETRY")
    print(geo_results)

    print("\nDOWNSTREAM CLUSTERING PERFORMANCE")
    for method in df_stab['method'].unique():
        print(f"\n--- Method: {method} ---")

        # .item() requires exactly one stability row per method.
        stab_row = df_stab.loc[df_stab['method'] == method]
        ari_mean = stab_row['ARI_mean'].item()
        ari_std = stab_row['ARI_std'].item()
        print(f"Stability (ARI): {ari_mean:.2f} ± {ari_std:.2f}")

        method_results = df_results_E2[df_results_E2['method'] == method]
        avg_sil = method_results['silhouette'].mean()
        avg_db = method_results['davies_bouldin'].mean()
        # NOTE(review): assumes run_downstream reports outlier_ratio as a
        # fraction in [0, 1]; drop the * 100 if it is already a percentage.
        avg_outliers = method_results['outlier_ratio'].mean() * 100
        avg_topics = method_results['n_topics'].mean()

        print(f"Silhouette Score: {avg_sil:.3f}")
        print(f"Davies-Bouldin Index: {avg_db:.3f}")
        print(f"Number of Topics: {avg_topics:.1f}")
        print(f"Outliers Ratio: {avg_outliers:.2f}%")

    print("="*40)