import torch
import os
import h5py
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression, Ridge, SGDClassifier
from sklearn.metrics import roc_auc_score, classification_report, accuracy_score, r2_score, mean_squared_error, pairwise_distances
# from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
from pprint import pprint
from scipy.spatial import procrustes
from scipy.stats import spearmanr
import warnings
from itertools import combinations
from dataset_classes import *
import json
from multiprocessing import Pool
from functools import partial
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# Suppress every warning category for the remainder of the process.
warnings.filterwarnings("ignore")

# [block, projection] name pairs. The __main__ driver crosses these with the
# selected layer indices to build its task list, and each pair ends up in the
# activation-dump file name built by load_layer_data.
# NOTE(review): presumably these name the MLP / self-attention linear
# sub-modules of a decoder layer -- confirm against the dump producer.
valid_blocks = [
    ["mlp", "up_proj"],
    ["mlp", "down_proj"],
    ["mlp", "gate_proj"],
    ["self_attn", "q_proj"],
    ["self_attn", "k_proj"],
    ["self_attn", "v_proj"],
    ["self_attn", "o_proj"],
]



def load_layer_data(layer_idx, block, proj, act_dir, act_prefix, mani_dir):
    """Read one activation dump (labels, vectors, indices) from an HDF5 file.

    The file is looked up at
    ``<act_dir>/<act_prefix>_layer_<layer_idx>_<block>_<proj>.h5`` and its
    ``labels``, ``vecs`` and ``indices`` datasets are read in full.

    Args:
        layer_idx: Layer index used in the file name.
        block: Block name used in the file name.
        proj: Projection name used in the file name.
        act_dir: Directory containing the dump files.
        act_prefix: File-name prefix.
        mani_dir: Accepted for interface compatibility; not used here.

    Returns:
        ``None`` when the file does not exist; otherwise a tuple
        ``(labels, vecs, indices)`` where ``labels`` is a NumPy array of the
        decoded label strings, ``vecs`` is a ``torch.Tensor`` and ``indices``
        is whatever array the ``indices`` dataset yields.
    """
    file_name = f"{act_prefix}_layer_{layer_idx}_{block}_{proj}.h5"
    h5_path = os.path.join(act_dir, file_name)

    # A missing dump is reported to the caller as None rather than raising.
    if not os.path.exists(h5_path):
        return None

    with h5py.File(h5_path, "r") as handle:
        # Labels are stored as bytes; decode each one to str.
        decoded = [raw.decode() for raw in handle["labels"][:]]
        vectors = handle["vecs"][:]
        row_indices = handle["indices"][:]

    return np.array(decoded), torch.tensor(vectors), row_indices


def stratified_split(labels, val_ratio=0.2, seed=42):
    """Split sample positions into train/validation sets, stratified by label.

    Args:
        labels: Sequence of class labels, one per sample.
        val_ratio: Fraction of samples assigned to the validation set.
        seed: Random state passed to ``train_test_split``.

    Returns:
        ``(train_idx, val_idx)``: two arrays of integer positions into
        ``labels``.
    """
    positions = np.arange(len(labels))
    split = train_test_split(
        positions,
        test_size=val_ratio,
        stratify=labels,
        random_state=seed,
    )
    train_idx, val_idx = split
    return train_idx, val_idx


def probe_one_vs_all(vecs, labels):
    """Fit one binary linear probe per label and report its held-out ROC-AUC.

    For every distinct value in ``labels`` the data is split 80/20
    (stratified on the binary target, no fixed random state), an
    ``SGDClassifier`` with log-loss is fitted, and the AUC on the held-out
    part is recorded.

    Args:
        vecs: Feature matrix, one row per sample.
        labels: NumPy array of labels aligned with ``vecs``.

    Returns:
        Dict mapping each label to its one-vs-rest ROC-AUC.
    """
    scores = {}
    for target in np.unique(labels):
        # 1 for the current label, 0 for every other label.
        is_target = (labels == target).astype(int)
        X_fit, X_held, y_fit, y_held = train_test_split(
            vecs, is_target, stratify=is_target, test_size=0.2
        )

        probe = SGDClassifier(loss="log_loss", max_iter=1000)
        probe.fit(X_fit, y_fit)

        # Column 1 holds the probability of the positive class.
        positive_prob = probe.predict_proba(X_held)[:, 1]
        scores[target] = roc_auc_score(y_held, positive_prob)
    return scores

def compute_centroid_cosine_similarity(synth_centroids, real_centroids, labels):
    """Cosine similarity between matching rows of two centroid matrices.

    Row ``i`` of ``synth_centroids`` is compared with row ``i`` of
    ``real_centroids`` and the result is stored under ``labels[i]``.

    Args:
        synth_centroids: 2-D tensor, one centroid per row.
        real_centroids: 2-D tensor with rows in the same order.
        labels: Iterable of labels, one per centroid row.

    Returns:
        Dict mapping each label to a Python float cosine similarity.
    """
    # The file never imports ``torch.nn.functional as F``; go through the
    # already-imported ``torch`` package so the function is self-contained.
    cosine_similarity = torch.nn.functional.cosine_similarity
    return {
        label: float(
            cosine_similarity(
                synth_centroids[i].unsqueeze(0),
                real_centroids[i].unsqueeze(0),
            ).item()
        )
        for i, label in enumerate(labels)
    }


def stress1(D_high: torch.Tensor, D_low: torch.Tensor):
    """
    Kruskal's Stress-1 between two distance matrices.

    Only entries strictly above the main diagonal are used. The result is
    ``sqrt(sum((D_high - D_low)^2) / sum(D_high^2))`` returned as a Python
    float. Index tensors are created on ``D_high``'s device.
    """
    assert D_high.shape == D_low.shape and D_high.ndim == 2, "Input must be 2D matrices of same shape"

    n_rows, n_cols = D_high.shape
    rows, cols = torch.triu_indices(n_rows, n_cols, offset=1, device=D_high.device)

    reference = D_high[rows, cols]
    residual = reference - D_low[rows, cols]

    ratio = residual.pow(2).sum() / reference.pow(2).sum()
    return torch.sqrt(ratio).item()

def stress2(D_high, D_low):
    """
    Stress-2 style score between two distance matrices.

    As implemented this is ``sum((D_high - D_low)^2) / sum(D_high^2)`` over
    the entries strictly above the main diagonal, with no square root and no
    mean-centring of the denominator, i.e. the square of ``stress1``.
    Returned as a Python float.
    """
    assert D_high.shape == D_low.shape and D_high.ndim == 2, "Input must be 2D matrices of same shape"

    n_rows, n_cols = D_high.shape
    rows, cols = torch.triu_indices(n_rows, n_cols, offset=1, device=D_high.device)

    reference = D_high[rows, cols]
    squared_error = (reference - D_low[rows, cols]).pow(2)
    squared_reference = reference.pow(2)

    ratio = squared_error.sum() / squared_reference.sum()
    return ratio.item()

def sammon_stress(D_high, D_low, eps=1e-8):
    """
    Sammon stress: squared errors weighted by the inverse original distance.

    Over the entries strictly above the main diagonal, with
    ``d = D_high + eps``, returns ``sum((d - D_low)^2 / d) / sum(d)`` as a
    Python float. ``eps`` keeps the per-pair division finite when an original
    distance is zero.
    """
    assert D_high.shape == D_low.shape and D_high.ndim == 2, "Input must be 2D matrices of same shape"

    n_rows, n_cols = D_high.shape
    rows, cols = torch.triu_indices(n_rows, n_cols, offset=1, device=D_high.device)

    original = D_high[rows, cols] + eps
    embedded = D_low[rows, cols]

    weighted_error = ((original - embedded) ** 2 / original).sum()
    total_original = original.sum()

    return (weighted_error / total_original).item()


def compute_lq_distortion(rho_matrix, q=2):
    """Return the l_q mean of the strictly-upper-triangular entries of a matrix.

    Computes ``(mean(rho^q))^(1/q)`` over the entries above the main diagonal
    of ``rho_matrix`` and returns it as a Python float.
    """
    assert rho_matrix.ndim == 2, "Input must be 2D"

    n_rows, n_cols = rho_matrix.shape
    rows, cols = torch.triu_indices(n_rows, n_cols, offset=1, device=rho_matrix.device)

    pair_values = rho_matrix[rows, cols]
    mean_power = pair_values.pow(q).mean()
    return mean_power.pow(1 / q).item()


def compute_sigma_distortion(X, Y, eps=1e-8):
    """
    Sigma-distortion (Vankadara & von Luxburg, 2019) between two point sets.

    Pairwise Euclidean distances are computed within ``X`` and within ``Y``;
    the per-pair ratio ``d_Y / (d_X + eps)`` is rescaled so that its mean
    over all unordered pairs is 1, and the mean squared deviation from 1 is
    returned as a Python float.
    """

    assert X.shape == Y.shape and X.ndim == 2, "X and Y must be 2D and same shape"

    n_points = X.shape[0]
    dist_x = torch.cdist(X, X, p=2) + eps  # [N, N], eps avoids division by zero
    dist_y = torch.cdist(Y, Y, p=2)        # [N, N]

    rows, cols = torch.triu_indices(n_points, n_points, offset=1, device=X.device)
    pair_ratios = (dist_y / dist_x)[rows, cols]

    # Number of unordered pairs, N(N-1)/2.
    n_pairs = (n_points * (n_points - 1)) / 2
    normalised = n_pairs * pair_ratios / pair_ratios.sum()

    return ((normalised - 1) ** 2).mean().item()



def compute_distortion_metrics(X, Y, eps=1e-8):
    """Compute distance-ratio distortion summaries between two point sets.

    With ``rho = cdist(Y, Y) / (cdist(X, X) + eps)`` and ``rho_vals`` its
    strictly-upper-triangular entries, returns a dict with:

    - ``"worst_case"``: ``max(rho_vals) * max(1 / (rho_vals + eps))``
    - ``"average"``: mean of ``rho_vals`` over the N(N-1)/2 pairs
    - ``"normalized_avg"``: mean of ``rho_vals / (min(rho_vals) + eps)``
    - ``"rho_matrix"``: the full ``[N, N]`` ratio tensor (left on the input
      device)

    The three scalar entries are Python floats.
    """
    assert X.shape == Y.shape and X.ndim == 2, "X and Y must be 2D tensors of the same shape"

    n = X.shape[0]

    dist_x = torch.cdist(X, X, p=2) + eps  # [N, N]
    dist_y = torch.cdist(Y, Y, p=2)        # [N, N]
    rho = dist_y / dist_x                  # [N, N]

    rows, cols = torch.triu_indices(n, n, offset=1, device=X.device)
    rho_vals = rho[rows, cols]  # [N(N-1)/2]

    worst = rho_vals.max() * (1 / (rho_vals + eps)).max()

    # 2 / (N(N-1)) is the reciprocal of the number of unordered pairs.
    pair_weight = 2 / (n * (n - 1))
    average = pair_weight * rho_vals.sum()
    smallest = rho_vals.min() + eps
    normalised_average = pair_weight * (rho_vals / smallest).sum()

    metrics = {}
    metrics["worst_case"] = worst.item()
    metrics["average"] = average.item()
    metrics["normalized_avg"] = normalised_average.item()
    metrics["rho_matrix"] = rho
    return metrics


def compute_pairwise_distance_matrix(centroids: torch.Tensor):
    """
    Return the matrix of Euclidean distances between every pair of rows.

    The result is computed with ``torch.cdist`` and stays on the input's
    device.
    """
    distances = torch.cdist(centroids, centroids, p=2)
    return distances


def get_centroid_matrix(vecs, labels, shared_labels=None):
    """Compute one mean vector (centroid) per label.

    Args:
        vecs: 2-D tensor with one row per sample.
        labels: Sequence (list or NumPy array) of labels aligned with the
            rows of ``vecs``.
        shared_labels: Optional iterable restricting which labels get a
            centroid. When ``None`` every distinct value in ``labels`` is
            used.

    Returns:
        ``(centroids, unique)`` where ``unique`` is the sorted list of labels
        and ``centroids`` stacks their mean vectors in that order. A label
        with no matching rows yields a NaN centroid (mean of an empty
        selection).
    """
    unique = sorted(set(labels) if shared_labels is None else shared_labels)

    # Convert once so each mask is a single vectorised comparison instead of
    # a Python-level loop over every sample for every label.
    label_arr = np.asarray(labels)

    centroids = []
    for lbl in unique:
        mask = torch.from_numpy(label_arr == lbl).to(vecs.device)
        centroids.append(vecs[mask].mean(0))
    return torch.stack(centroids), unique

class TorchLogReg(torch.nn.Module):
    """Multinomial logistic-regression probe: a single linear layer emitting class logits."""

    def __init__(self, input_dim, num_classes):
        """Create the ``input_dim -> num_classes`` linear layer."""
        super().__init__()
        # The attribute name ``linear`` determines the state_dict keys.
        self.linear = torch.nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Return unnormalised class scores (logits) for ``x``."""
        logits = self.linear(x)
        return logits

def projected_probing(projection_manifold, vecs, labels, dims=(3, 10, -1), already_centered=False, device="cuda", epochs=60):
    """Train a linear softmax probe on (optionally projected) vectors and report validation accuracy.

    A stratified train/validation split is taken once, then for every entry
    of ``dims`` a fresh ``TorchLogReg`` is trained full-batch with AdamW and
    an exponentially decaying learning rate.

    Args:
        projection_manifold: Dict with ``"Vh"`` (basis rows) and ``"mean"``.
            Only read when ``already_centered`` is False, except that
            ``"Vh"`` is also consulted to resolve ``dim == -1``.
        vecs: 2-D tensor of sample vectors.
        labels: Sequence of labels aligned with ``vecs``.
        dims: Projection widths to evaluate; ``-1`` means "all rows of Vh".
        already_centered: When True, ``vecs`` are used as-is (no centring and
            no projection).
        device: Device on which the probe is trained.
        epochs: Number of full-batch optimisation steps.

    Returns:
        Dict ``{"dim_<d>": validation_accuracy}`` with one entry per ``dims``
        element.
    """
    train_idx, val_idx = stratified_split(labels)
    labels = np.array(labels)
    label_set = sorted(set(labels))
    label_to_idx = {lbl: i for i, lbl in enumerate(label_set)}
    y = torch.tensor([label_to_idx[l] for l in labels], device=device)

    centered_vecs = vecs if already_centered else vecs - projection_manifold["mean"]
    # The probe and targets live on `device`; make sure the features do too
    # (a no-op when they are already there).
    centered_vecs = centered_vecs.to(device)

    results = {}
    for dim in dims:
        d = dim if dim != -1 else projection_manifold["Vh"].shape[0]

        if already_centered:
            # No projection is applied, so skip building the unused basis.
            projected = centered_vecs
        else:
            basis = projection_manifold["Vh"][:d].T.to(device)
            projected = centered_vecs @ basis

        X_train = projected[train_idx]
        y_train = y[train_idx]
        X_val = projected[val_idx]
        y_val = y[val_idx]

        # Size the probe from the data actually fed to it: with
        # already_centered=True (or d larger than the basis) the feature
        # width is not necessarily `d`.
        in_features = projected.shape[-1]
        model = TorchLogReg(in_features, len(label_set)).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
        loss_fn = torch.nn.CrossEntropyLoss()

        for _ in range(epochs):
            model.train()
            optimizer.zero_grad()
            out = model(X_train)
            loss = loss_fn(out, y_train)
            loss.backward()
            optimizer.step()
            scheduler.step()  # decay the learning rate after every step

        model.eval()
        with torch.no_grad():
            preds = model(X_val).argmax(dim=1)
            acc = (preds == y_val).float().mean().item()

        results[f"dim_{d}"] = acc

    return results


# def centered_svd_activations(vecs, labels, excluded=("neutral", "condescension", "sarcastic")):
def centered_svd_activations(vecs, labels, excluded=()):
    """Mean-centre the vectors (minus excluded labels) and take their thin SVD.

    Args:
        vecs: 2-D tensor with one row per sample.
        labels: Sequence of labels aligned with ``vecs``.
        excluded: Labels whose rows are dropped before centring.

    Returns:
        Dict with keys ``"U"``, ``"S"``, ``"Vh"`` (from
        ``torch.linalg.svd(..., full_matrices=False)``) and ``"mean"`` (the
        mean row of the retained vectors).
    """
    device = vecs.device

    # Vectorised membership test: True for rows whose label is NOT excluded.
    keep = ~np.isin(np.array(labels), excluded)
    retained = vecs[torch.from_numpy(keep).to(device)]

    mean = retained.mean(dim=0)
    U, S, Vh = torch.linalg.svd(retained - mean, full_matrices=False)

    decomposition = {"U": U, "S": S, "Vh": Vh, "mean": mean}
    return decomposition





def align_projected_centroid_geometry_torch(vect_array, label_array, svd_dict, k=10):
    """Fit a linear map between the centroid distance matrices of every dataset pair.

    For each pair of keys in ``label_array`` that share at least two labels,
    per-label centroids are projected onto the first ``k`` rows of the
    matching ``svd_dict`` basis, their pairwise distance matrices ``D1`` and
    ``D2`` are built, and ``W`` is solved from ``D1 @ W ~= D2`` by least
    squares.

    Args:
        vect_array: Dict ``key -> 2-D tensor`` of sample vectors.
        label_array: Dict ``key -> sequence of labels`` aligned with the
            vectors.
        svd_dict: Dict ``key -> basis tensor`` whose first ``k`` rows are
            used for projection.
        k: Number of basis rows to project onto.

    Returns:
        Dict keyed by ``(p1, p2)`` with ``"mse"``, ``"mean_cosine"``, ``"W"``
        and ``"shared"`` (the sorted shared labels). Pairs with fewer than
        two shared labels are omitted.
    """
    results = {}
    if not vect_array:
        # Nothing to align (and no tensor to read the device from).
        return results

    device = next(iter(vect_array.values())).device
    # The file never imports ``torch.nn.functional as F``; reach it through
    # the already-imported ``torch`` package.
    functional = torch.nn.functional

    def compute_centroids(vecs, labels, shared):
        """Stack the mean vector of every label in ``shared``."""
        return torch.stack([
            vecs[torch.from_numpy(labels == lbl).to(device)].mean(0)
            for lbl in shared
        ])

    for p1, p2 in combinations(label_array.keys(), 2):
        labels1, labels2 = np.array(label_array[p1]), np.array(label_array[p2])
        shared = sorted(set(labels1) & set(labels2))
        if len(shared) < 2:
            # A distance matrix needs at least two centroids.
            continue

        X1 = compute_centroids(vect_array[p1], labels1, shared) @ svd_dict[p1][:k].T
        X2 = compute_centroids(vect_array[p2], labels2, shared) @ svd_dict[p2][:k].T

        D1 = torch.cdist(X1, X1)
        D2 = torch.cdist(X2, X2)

        # Least-squares solution of D1 @ W = D2.
        W = torch.linalg.lstsq(D1, D2).solution
        D2_pred = D1 @ W

        mse = functional.mse_loss(D2_pred, D2).item()
        cosine = functional.cosine_similarity(D2_pred, D2, dim=1).mean().item()

        results[(p1, p2)] = {
            "mse": mse,
            "mean_cosine": cosine,
            "W": W.detach(),
            "shared": shared
        }

    return results


def compute_linear_alignment_map(hidden_synth, labels_synth, hidden_real, labels_real, shared_labels=None):
    """
    Compute the least-squares linear map W that sends real vectors onto synthetic ones.

    Samples are paired label by label: for every shared label the first
    ``n = min(#synth, #real)`` rows of each side are matched in order. With
    the paired matrices ``X`` (synth) and ``Y`` (real), ``W`` solves
    ``Y @ W ~= X``. All inputs are assumed to already be on the same device.

    Args:
        hidden_synth: 2-D tensor of synthetic vectors.
        labels_synth: Labels aligned with ``hidden_synth``.
        hidden_real: 2-D tensor of real vectors.
        labels_real: Labels aligned with ``hidden_real``.
        shared_labels: Labels to pair on; defaults to the sorted intersection
            of both label sets.

    Returns:
        ``(W, stats)`` where ``stats`` holds ``"mse"`` and ``"mean_cosine"``
        of ``Y @ W`` against ``X``, the ``"singular_values"`` of ``W`` as a
        list, and its ``"condition_number"`` (largest / smallest singular
        value).

    Raises:
        RuntimeError: If no label yields at least one paired sample.
    """
    device = hidden_synth.device
    labels_synth = np.array(labels_synth)
    labels_real = np.array(labels_real)

    if shared_labels is None:
        shared_labels = sorted(set(labels_synth) & set(labels_real))

    synth_batches = [hidden_synth[torch.from_numpy(labels_synth == lbl).to(device)] for lbl in shared_labels]
    real_batches = [hidden_real[torch.from_numpy(labels_real == lbl).to(device)] for lbl in shared_labels]

    X, Y = [], []
    for xs, ys in zip(synth_batches, real_batches):
        # Truncate the larger side so rows pair up one-to-one.
        n = min(len(xs), len(ys))
        if n < 1:
            continue
        X.append(xs[:n])
        Y.append(ys[:n])

    if not X:
        # torch.cat would raise the same exception type with a less helpful
        # message.
        raise RuntimeError("No paired samples: the datasets share no label with data on both sides")

    X = torch.cat(X, dim=0)
    Y = torch.cat(Y, dim=0)

    W = torch.linalg.lstsq(Y, X).solution
    X_hat = Y @ W

    # The file never imports ``torch.nn.functional as F``; reach it through
    # the already-imported ``torch`` package.
    functional = torch.nn.functional
    mse = functional.mse_loss(X_hat, X).item()
    cosine = functional.cosine_similarity(X_hat, X, dim=-1).mean().item()
    _, S, _ = torch.linalg.svd(W)

    return W, {
        "mse": mse,
        "mean_cosine": cosine,
        "singular_values": S.cpu().tolist(),
        "condition_number": (S[0] / S[-1]).item()
    }


def analyze_topology_for_layer_block(label_array, vect_array, svd_manifold, prefixes, dims=(50,), device="cuda"):
    """Compare every real dataset against the synthetic one for a single layer/projection.

    The synthetic reference is read from ``label_array["synth"]`` /
    ``vect_array["synth"]``. For every other entry of ``prefixes`` the
    ``"<prefix>_cleaned"`` labels and vectors are used; prefixes without such
    an entry are skipped. Metrics are computed once on the raw vectors
    (``"raw"``) and once per entry of ``dims`` (``"proj_<dim>"``) after
    projecting onto the first ``dim`` rows of ``svd_manifold["Vh"]``.

    Args:
        label_array: Dict ``key -> labels``.
        vect_array: Dict ``key -> tensor of vectors``.
        svd_manifold: Dict with ``"Vh"`` and ``"mean"`` for the synthetic
            data. NOTE: both entries are moved to ``device`` in place.
        prefixes: Dataset prefixes to process.
        dims: Projection widths to evaluate.
        device: Device on which all tensors are placed.

    Returns:
        ``{"metrics": {prefix: {proj_mode: {...}}}}`` where each inner dict
        has the keys ``"centroid_geometry"``, ``"centroid_alignment"``,
        ``"full_geometry"``, ``"linear_alignment"`` and ``"probe"``.
    """
    results = {"metrics": {}}

    # synth_labels, synth_vecs = label_array["synth_llama"], vect_array["synth_llama"].to(device)
    synth_labels, synth_vecs = label_array["synth"], vect_array["synth"].to(device)
    svd_manifold["Vh"] = svd_manifold["Vh"].to(device)
    svd_manifold["mean"] = svd_manifold["mean"].to(device)

    with tqdm(total=len(prefixes)) as pbar:
        for prefix in prefixes:
            print(prefix)
            # The synthetic set is the reference; it is not compared with itself.
            # if prefix == "synth_llama":
            if prefix == "synth":
                pbar.update(1)
                continue

            cleaned = prefix + "_cleaned"
            if cleaned not in label_array or cleaned not in vect_array:
                pbar.update(1)
                continue

            real_labels = label_array[cleaned]
            real_vecs = vect_array[cleaned].to(device)
            shared = sorted(set(synth_labels) & set(real_labels))
            # NOTE(review): the minimum-shared-labels guard below is disabled,
            # so an empty `shared` reaches the centroid code -- confirm that
            # this cannot happen for the datasets in use.
            # if len(shared) < 3:
            #     pbar.update(1)
            #     continue

            # SVD basis fitted on the real dataset itself (used only for the
            # full-space linear alignment further down).
            real_svd_manifold = centered_svd_activations(real_vecs, real_labels)
            real_svd_manifold["Vh"] = real_svd_manifold["Vh"].to(device)
            real_svd_manifold["mean"] = real_svd_manifold["mean"].to(device)

            results["metrics"][prefix] = {}

            # Cache projections
            print("Projecting Vectors")
            # Synthetic and real vectors projected with the SYNTHETIC basis ...
            X_synth_proj = {dim: (synth_vecs - svd_manifold["mean"]) @ svd_manifold["Vh"][:dim].T for dim in dims}
            X_real_proj = {dim: (real_vecs - svd_manifold["mean"]) @ svd_manifold["Vh"][:dim].T for dim in dims}
            # ... and real vectors projected with their OWN basis.
            X_real_own_proj = {dim: (real_vecs - real_svd_manifold["mean"]) @ real_svd_manifold["Vh"][:dim].T for dim in dims}

            synth_np = np.array(synth_labels)
            real_np = np.array(real_labels)

            # `None` stands for the un-projected ("raw") vectors.
            # for dim in list(dims):
            for dim in [None] + list(dims):
                if dim is None:
                    X_synth = synth_vecs
                    X_real = real_vecs
                    proj_mode = "raw"
                    # Identity basis / zero mean placeholder for the probe call.
                    proj_dict = {
                        "Vh": torch.eye(synth_vecs.shape[-1], device=device),
                        "mean": torch.zeros_like(synth_vecs[0])
                    }
                else:
                    X_synth = X_synth_proj[dim]
                    X_real = X_real_proj[dim]
                    proj_mode = f"proj_{dim}"
                    proj_dict = {
                        "Vh": svd_manifold["Vh"],
                        "mean": svd_manifold["mean"]
                    }

                C_synth, _ = get_centroid_matrix(X_synth, synth_labels, shared)
                C_real, _ = get_centroid_matrix(X_real, real_labels, shared)

                ### --- CENTROID GEOMETRY METRICS ---
                D_synth = compute_pairwise_distance_matrix(C_synth)
                D_real = compute_pairwise_distance_matrix(C_real)
                distort_c = compute_distortion_metrics(C_synth, C_real)

                centroid_geom = {
                    "stress1": stress1(D_synth, D_real),
                    "stress2": stress2(D_synth, D_real),
                    "sammon_stress": sammon_stress(D_synth, D_real),
                    "distortion_avg": distort_c["average"],
                    "distortion_l2": compute_lq_distortion(distort_c["rho_matrix"], q=2),
                    "distortion_sigma": compute_sigma_distortion(C_synth, C_real),
                }

                ### --- CENTROID ALIGNMENT METRICS ---
                sim_dict = compute_centroid_cosine_similarity(C_synth, C_real, shared)
                centroid_align = {
                    "cosine_similarity": float(torch.tensor(list(sim_dict.values()), device=device).mean().item())
                }

                print("Aligning Vectors")
                # Identity bases are passed because X_synth / X_real are
                # already in the space being compared.
                align_res = align_projected_centroid_geometry_torch(
                    {"synth": X_synth, cleaned: X_real},
                    {"synth": synth_labels, cleaned: real_labels},
                    {"synth": torch.eye(X_synth.shape[-1], device=device), cleaned: torch.eye(X_real.shape[-1], device=device)},
                    k=10
                )
                align_key = ("synth", cleaned)
                if align_key in align_res:
                    centroid_align.update({
                        "alignment_mse": align_res[align_key]["mse"],
                        "alignment_cosine": align_res[align_key]["mean_cosine"],
                        "alignment_W": align_res[align_key]["W"],
                    })

                ### --- FULL SPACE GEOMETRY METRICS ---
                # Pair samples label by label, truncating the larger side so
                # both matrices have the same number of rows.
                X_synth_filt, X_real_filt = [], []
                for lbl in shared:
                    xs = X_synth[torch.from_numpy(synth_np == lbl).to(device)]
                    ys = X_real[torch.from_numpy(real_np == lbl).to(device)]
                    n = min(len(xs), len(ys))
                    if n > 0:
                        X_synth_filt.append(xs[:n])
                        X_real_filt.append(ys[:n])
                X_synth_filt = torch.cat(X_synth_filt, dim=0)
                X_real_filt = torch.cat(X_real_filt, dim=0)

                D_synth_f = compute_pairwise_distance_matrix(X_synth_filt)
                D_real_f = compute_pairwise_distance_matrix(X_real_filt)
                distort_f = compute_distortion_metrics(X_synth_filt, X_real_filt)
                print("Calculating Stress")
                full_geom = {
                    "stress1": stress1(D_synth_f, D_real_f),
                    "stress2": stress2(D_synth_f, D_real_f),
                    "sammon_stress": sammon_stress(D_synth_f, D_real_f),
                    "distortion_avg": distort_f["average"],
                    "distortion_l2": compute_lq_distortion(distort_f["rho_matrix"], q=2),
                    "distortion_sigma": compute_sigma_distortion(X_synth_filt, X_real_filt),
                }

                ### --- PROBING ---
                print("Calculating probes")
                # X_real is passed as already centred/projected, so the probe
                # is trained directly on it.
                probe_metrics = projected_probing(proj_dict, X_real, real_labels, dims=[X_real.shape[-1]], device=device, already_centered=True)

                ### --- FULLSPACE LINEAR ALIGNMENT ---
                full_align = {}
                # From here on X_real is the projection onto the real
                # dataset's OWN basis (raw mode keeps the raw vectors).
                if dim is not None:
                    X_real = X_real_own_proj[dim]
                print("Calculating full space alignment")
                W, linear_stats = compute_linear_alignment_map(
                    hidden_synth=X_synth,
                    labels_synth=synth_labels,
                    hidden_real=X_real,
                    labels_real=real_labels,
                    shared_labels=shared
                )
                full_align.update({
                    "alignment_W": W.detach().cpu(),
                    "alignment_mse": linear_stats["mse"],
                    "alignment_cosine": linear_stats["mean_cosine"],
                    "alignment_condition": linear_stats["condition_number"],
                    "singular_values": linear_stats["singular_values"],
                })

                results["metrics"][prefix][proj_mode] = {
                    "centroid_geometry": centroid_geom,
                    "centroid_alignment": centroid_align,
                    "full_geometry": full_geom,
                    "linear_alignment": full_align,
                    "probe": probe_metrics
                }

            pbar.update(1)

    return results


def analyze_model_topology_for_layer_block(label_array, vect_array, svd_manifold, prefixes, dims=(50,), device="cuda"):
    """Compare another model's activations against the LLaMA reference for one layer/projection.

    The reference is read from ``label_array["synth_llama"]`` /
    ``vect_array["synth_llama"]``. For every other entry of ``prefixes`` the
    ``"<prefix>_cleaned"`` labels and vectors are used; prefixes without such
    an entry, or sharing fewer than three labels with the reference, are
    skipped. Each side is projected with its own SVD basis.

    Args:
        label_array: Dict ``key -> labels``.
        vect_array: Dict ``key -> tensor of vectors``.
        svd_manifold: Pair ``(llama_manifold, model_manifold)``; each is a
            dict with ``"Vh"`` and ``"mean"``. NOTE: both dicts have their
            entries moved to ``device`` in place.
        prefixes: Dataset prefixes to process.
        dims: Projection widths to evaluate.
        device: Device on which all tensors are placed.

    Returns:
        ``{"metrics": {prefix: {"proj_<dim>": {...}}}}`` with one entry per
        element of ``dims``. Each inner dict has the keys
        ``"centroid_geometry"``, ``"centroid_alignment"``,
        ``"full_geometry"``, ``"linear_alignment"`` and ``"probe"``.
    """
    results = {"metrics": {}}

    synth_labels, synth_vecs = label_array["synth_llama"], vect_array["synth_llama"].to(device)
    llama_svd_manifold, model_svd_manifold = svd_manifold

    llama_svd_manifold["Vh"] = llama_svd_manifold["Vh"].to(device)
    llama_svd_manifold["mean"] = llama_svd_manifold["mean"].to(device)

    model_svd_manifold["Vh"] = model_svd_manifold["Vh"].to(device)
    model_svd_manifold["mean"] = model_svd_manifold["mean"].to(device)

    # The reference side is projected with the LLaMA basis, the compared side
    # with the other model's basis (both already on `device`).
    svd_manifold = llama_svd_manifold
    real_svd_manifold = model_svd_manifold

    with tqdm(total=len(prefixes)) as pbar:
        for prefix in prefixes:
            print(prefix)
            # The reference set is not compared with itself.
            if prefix == "synth_llama":
                pbar.update(1)
                continue

            cleaned = prefix + "_cleaned"
            if cleaned not in label_array or cleaned not in vect_array:
                pbar.update(1)
                continue

            real_labels = label_array[cleaned]
            real_vecs = vect_array[cleaned].to(device)
            shared = sorted(set(synth_labels) & set(real_labels))
            if len(shared) < 3:
                pbar.update(1)
                continue

            results["metrics"][prefix] = {}

            # Cache projections
            print("Projecting Vectors")
            X_synth_proj = {dim: (synth_vecs - svd_manifold["mean"]) @ svd_manifold["Vh"][:dim].T for dim in dims}
            X_real_own_proj = {dim: (real_vecs - real_svd_manifold["mean"]) @ real_svd_manifold["Vh"][:dim].T for dim in dims}

            synth_np = np.array(synth_labels)
            real_np = np.array(real_labels)

            for dim in list(dims):
                X_synth = X_synth_proj[dim]
                X_real = X_real_own_proj[dim]
                proj_mode = f"proj_{dim}"
                proj_dict = {
                    "Vh": svd_manifold["Vh"],
                    "mean": svd_manifold["mean"]
                }

                C_synth, _ = get_centroid_matrix(X_synth, synth_labels, shared)
                C_real, _ = get_centroid_matrix(X_real, real_labels, shared)

                ### --- CENTROID GEOMETRY METRICS ---
                D_synth = compute_pairwise_distance_matrix(C_synth)
                D_real = compute_pairwise_distance_matrix(C_real)
                distort_c = compute_distortion_metrics(C_synth, C_real)

                print("Calculating Stress")
                centroid_geom = {
                    "stress1": stress1(D_synth, D_real),
                    "stress2": stress2(D_synth, D_real),
                    "sammon_stress": sammon_stress(D_synth, D_real),
                    "distortion_avg": distort_c["average"],
                    "distortion_l2": compute_lq_distortion(distort_c["rho_matrix"], q=2),
                    "distortion_sigma": compute_sigma_distortion(C_synth, C_real),
                }

                ### --- CENTROID ALIGNMENT METRICS ---
                sim_dict = compute_centroid_cosine_similarity(C_synth, C_real, shared)
                centroid_align = {
                    "cosine_similarity": float(torch.tensor(list(sim_dict.values()), device=device).mean().item())
                }

                print("Aligning Vectors")
                # Identity bases are passed because X_synth / X_real are
                # already projected.
                align_res = align_projected_centroid_geometry_torch(
                    {"synth": X_synth, cleaned: X_real},
                    {"synth": synth_labels, cleaned: real_labels},
                    {"synth": torch.eye(X_synth.shape[-1], device=device), cleaned: torch.eye(X_real.shape[-1], device=device)},
                    k=10
                )
                align_key = ("synth", cleaned)
                if align_key in align_res:
                    centroid_align.update({
                        "alignment_mse": align_res[align_key]["mse"],
                        "alignment_cosine": align_res[align_key]["mean_cosine"],
                        "alignment_W": align_res[align_key]["W"],
                    })

                ### --- FULL SPACE GEOMETRY METRICS ---
                # Pair samples label by label, truncating the larger side so
                # both matrices have the same number of rows.
                X_synth_filt, X_real_filt = [], []
                for lbl in shared:
                    xs = X_synth[torch.from_numpy(synth_np == lbl).to(device)]
                    ys = X_real[torch.from_numpy(real_np == lbl).to(device)]
                    n = min(len(xs), len(ys))
                    if n > 0:
                        X_synth_filt.append(xs[:n])
                        X_real_filt.append(ys[:n])
                X_synth_filt = torch.cat(X_synth_filt, dim=0)
                X_real_filt = torch.cat(X_real_filt, dim=0)

                D_synth_f = compute_pairwise_distance_matrix(X_synth_filt)
                D_real_f = compute_pairwise_distance_matrix(X_real_filt)
                distort_f = compute_distortion_metrics(X_synth_filt, X_real_filt)

                full_geom = {
                    "stress1": stress1(D_synth_f, D_real_f),
                    "stress2": stress2(D_synth_f, D_real_f),
                    "sammon_stress": sammon_stress(D_synth_f, D_real_f),
                    "distortion_avg": distort_f["average"],
                    "distortion_l2": compute_lq_distortion(distort_f["rho_matrix"], q=2),
                    "distortion_sigma": compute_sigma_distortion(X_synth_filt, X_real_filt),
                }

                ### --- PROBING ---
                probe_metrics = projected_probing(proj_dict, X_real, real_labels, dims=[X_real.shape[-1]], device=device, already_centered=True)

                ### --- FULLSPACE LINEAR ALIGNMENT ---
                # X_real already is the own-basis projection for this dim.
                full_align = {}

                print("Calculating full space alignment")
                W, linear_stats = compute_linear_alignment_map(
                    hidden_synth=X_synth,
                    labels_synth=synth_labels,
                    hidden_real=X_real,
                    labels_real=real_labels,
                    shared_labels=shared
                )
                full_align.update({
                    "alignment_W": W.detach().cpu(),
                    "alignment_mse": linear_stats["mse"],
                    "alignment_cosine": linear_stats["mean_cosine"],
                    "alignment_condition": linear_stats["condition_number"],
                    "singular_values": linear_stats["singular_values"],
                })

                # Store inside the dim loop so every projection width is kept.
                # Previously this ran once after the loop and only the last
                # dim's metrics survived.
                results["metrics"][prefix][proj_mode] = {
                    "centroid_geometry": centroid_geom,
                    "centroid_alignment": centroid_align,
                    "full_geometry": full_geom,
                    "linear_alignment": full_align,
                    "probe": probe_metrics
                }

            pbar.update(1)

    return results


def save_results(fname, results, model, layer, block, proj, save_dir="analysis_results"):
    """Serialise ``results`` to ``fname`` with ``torch.save``.

    Args:
        fname: Destination path of the saved file.
        results: Object to serialise.
        model, layer, block, proj: Accepted for interface compatibility; the
            destination is fully determined by ``fname``.
        save_dir: Directory that is created if missing (kept for backward
            compatibility; the file itself is written to ``fname``).
    """
    os.makedirs(save_dir, exist_ok=True)

    # The file goes to `fname`, which need not live under `save_dir`; make
    # sure its own parent directory exists too.
    target_dir = os.path.dirname(fname)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    torch.save(results, fname)



def run_topology_analysis(model, layer, block, proj, prefixes):
    """Load all dumps for one layer/projection, run the synth-vs-real analysis and save it.

    For every prefix whose dump exists, the labels and vectors are stored
    under the prefix. For every prefix other than ``"synth"`` the labels are
    additionally remapped through ``prefixes_to_class[prefix].remap_to_synth``
    and the rows not mapped to ``"other"`` are stored under
    ``"<prefix>_cleaned"``. ``prefixes_to_class`` is a module-level name
    defined by the ``__main__`` driver.
    """
    fname = f"analysis_results/layer_alignment_statsv4_all_{model}_{layer}_{block}_{proj}.pt"
    dump_dir = f"{model}_hidden_state_dumps_mean"

    label_array = {}
    vect_array = {}
    for prefix in prefixes:
        loaded = load_layer_data(layer, block, proj, dump_dir, prefix, None)
        if loaded is None:
            # No dump for this prefix; leave it out of the analysis.
            continue

        raw_labels, raw_vecs = loaded[0], loaded[1]
        label_array[prefix] = raw_labels
        vect_array[prefix] = raw_vecs

        if prefix != "synth":
            # Map dataset-specific labels onto the synthetic label set and
            # drop everything that has no counterpart there.
            remapped = prefixes_to_class[prefix].remap_to_synth(raw_labels)
            keep = remapped != "other"
            cleaned_key = prefix + "_cleaned"
            label_array[cleaned_key] = remapped[keep]
            vect_array[cleaned_key] = raw_vecs[keep]

    print("Loaded Datasets\nLoading SVD Space")
    manifold_file = torch.load(f"emotional_manifold/{model}_manifold_slice_layer_{layer}.pt")
    svd = manifold_file[f"{proj}_{layer}"]
    print("SVD Space Loaded")

    results = analyze_topology_for_layer_block(label_array, vect_array, svd, prefixes, dims=(50,))
    save_results(fname, results, model, layer, block, proj)


def run_model_topology_analysis(model, layer, block, proj, prefixes):
    """Compare ``model`` against the LLaMA reference for one layer/projection and save the result.

    Only ``prefixes[0]`` is used: its dump is loaded once from
    ``llama_hidden_state_dumps_mean`` and once from
    ``<model>_hidden_state_dumps_mean``. Nothing is done when the output file
    already exists.

    Raises:
        FileNotFoundError: If either activation dump is missing.
    """
    fname = f"analysis_results/llama_base_layer_model_alignment_stats_{model}_{layer}_{block}_{proj}.pt"
    if os.path.exists(fname):
        # Results for this configuration were computed by an earlier run.
        return
    base_prefix = prefixes[0]  # "synth"
    llama_key = base_prefix + "_llama"
    model_key = base_prefix + f"_{model}"

    label_array, vect_array = {}, {}

    llama_dir = "llama_hidden_state_dumps_mean"
    model_dir = f"{model}_hidden_state_dumps_mean"
    llama_result = load_layer_data(layer, block, proj, llama_dir, base_prefix, None)
    model_result = load_layer_data(layer, block, proj, model_dir, base_prefix, None)

    # load_layer_data signals a missing dump with None; fail with a clear
    # message instead of an opaque "NoneType is not subscriptable".
    for result, directory in ((llama_result, llama_dir), (model_result, model_dir)):
        if result is None:
            raise FileNotFoundError(
                f"Missing activation dump in {directory!r} for prefix {base_prefix!r}, "
                f"layer {layer}, {block}.{proj}"
            )

    label_array[llama_key] = llama_result[0]
    vect_array[llama_key] = llama_result[1]
    label_array[model_key] = model_result[0]
    vect_array[model_key] = model_result[1]

    # The analysis looks datasets up under "<prefix>_cleaned"; here the data
    # needs no label remapping, so the same arrays are registered again.
    label_array[llama_key+"_cleaned"] = llama_result[0]
    vect_array[llama_key+"_cleaned"] = llama_result[1]
    label_array[model_key+"_cleaned"] = model_result[0]
    vect_array[model_key+"_cleaned"] = model_result[1]

    # Use just these two keys
    active_prefixes = [llama_key, model_key]

    # Use LLaMA as reference SVD
    llama_svd = torch.load(f"emotional_manifold/manifold_slice_layer_{layer}.pt")[f"{proj}_{layer}"]
    model_svd = torch.load(f"emotional_manifold/{model}_manifold_slice_layer_{layer}_with_exclusions.pt")[f"{proj}_{layer}"]

    # Evaluate at 50 dimensions and at the largest width both models support.
    llama_dim = vect_array[llama_key].shape[1]
    model_dim = vect_array[model_key].shape[1]
    min_dim = min(llama_dim, model_dim)
    dims = (50, min_dim)

    results = analyze_model_topology_for_layer_block(label_array, vect_array, [llama_svd, model_svd], active_prefixes, dims=dims)
    save_results(fname, results, model, layer, block, proj)




# Script entry point: run the synth-vs-real topology analysis for four
# consecutive layers (starting at --start-layer) across every entry of
# valid_blocks.
if __name__ == "__main__":
    # prefixes = ["synth", "hindi", "semeval", "german", "italian", "emoevent_es", "emoevent_en", "french", "twitter", "goemotions"]
    prefixes = ["synth", "math"]
    # prefixes = ["synth"]
    # Module-level mapping read by run_topology_analysis to remap each
    # dataset's labels onto the synthetic label set.
    # NOTE(review): "math" is listed in `prefixes` but has no entry here, so
    # run_topology_analysis raises KeyError if a "math" dump is found --
    # confirm which dataset class it should use.
    prefixes_to_class = {"synth": synth_text_dataset, "semeval": semeval_dataset, "german": german_plays_dataset,
                         "italian": italian_dataset, "emoevent_es": emoevents_dataset, "emoevent_en": emoevents_dataset,
                         "french": french_dataset, "twitter": twitter_dataset, "goemotions": go_emotions_dataset,
                         "hindi": hindi_dataset}

    import argparse
    parser = argparse.ArgumentParser(description="Select model and layers.")
    # NOTE(review): both options are required=True, so their `default`
    # values are never used.
    parser.add_argument("--model", type=str, default="llama_instruct", required=True, help="Model name")
    parser.add_argument("--start-layer", type=int, default=0, required=True, help="Index of the starting layer")
    # parser.add_argument("--end-layer", type=int, default=32, required=True, help="Index of the ending layer (inclusive)")

    args = parser.parse_args()

    model = args.model
    start_layer = args.start_layer
    # range() below is end-exclusive, so layers start_layer .. start_layer + 3 are processed.
    end_layer = start_layer + 4
    # start_layer = 0
    # end_layer = 1

    # model = "mistralai"

    # graph_centroids(prefixes)

    # One task per (layer, block, projection) combination.
    tasks = [(l, b, p) for l in range(start_layer, end_layer) for b, p in valid_blocks]
    # Only referenced by the commented-out parallel variant below.
    from joblib import Parallel, delayed, parallel_backend

    # with parallel_backend("loky", n_jobs=8):
    #     Parallel(verbose=10)(delayed(run_topology_analysis)(model, layer, block, proj, prefixes) for (layer, block, proj) in tasks)
    #
    #
    #
    # run_topology_analysis(16, "mlp", "up_proj", prefixes)
    # Tasks are run sequentially in this process.
    for (layer, block, proj) in tqdm(tasks):
        run_topology_analysis(model, layer, block, proj, prefixes)

