import os
import torch
import h5py
from tqdm import tqdm
from collections import defaultdict
import numpy as np
from itertools import combinations
import pandas as pd
from scipy.stats import spearmanr, kendalltau
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
from sklearn.manifold import TSNE
import seaborn as sns

# (block, projection) module pairs whose activation dumps are processed; both
# names are interpolated into the dump file name by load_layer_data.
valid_blocks = [
    *(["mlp", proj_name] for proj_name in ("up_proj", "down_proj", "gate_proj")),
    *(["self_attn", proj_name] for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj")),
]

def load_layer_data(layer_idx, block, proj, act_dir, act_prefix, mani_dir):
    """Load one module's activation dump from HDF5.

    The file read is ``<act_dir>/<act_prefix>_layer_<layer_idx>_<block>_<proj>.h5``;
    ``mani_dir`` is accepted for call-site compatibility but not used.

    Returns:
        ``None`` if the file does not exist, otherwise ``(labels, vecs, indices)``:
        the ``labels`` dataset decoded to ``str`` in a NumPy array, the ``vecs``
        dataset as a ``torch.Tensor`` and the ``indices`` dataset as read.
    """
    filename = f"{act_prefix}_layer_{layer_idx}_{block}_{proj}.h5"
    dump_path = os.path.join(act_dir, filename)
    if not os.path.exists(dump_path):
        return None

    with h5py.File(dump_path, "r") as handle:
        decoded_labels = [raw.decode() for raw in handle["labels"][:]]
        raw_vecs = handle["vecs"][:]
        sample_indices = handle["indices"][:]

    return np.array(decoded_labels), torch.tensor(raw_vecs), sample_indices


def centered_svd_activations(vecs, labels, excluded=("neutral", "condescension", "sarcastic"), verbose=False):
    """Mean-centre the rows of non-excluded labels and take their thin SVD.

    Args:
        vecs: ``torch.Tensor`` with one row per sample, aligned with ``labels``
            (expected 2-D, as required by ``torch.linalg.svd``).
        labels: iterable of label strings, one per row of ``vecs``.
        excluded: labels whose rows are dropped before the mean and the SVD
            are computed.
        verbose: if True, print the shape of ``vecs`` before and after the
            excluded rows are removed.

    Returns:
        dict with ``"U"``, ``"S"``, ``"Vh"`` from
        ``torch.linalg.svd(centred, full_matrices=False)`` and ``"mean"``, the
        column mean of the kept rows that was subtracted before the SVD.
    """
    excluded = set(excluded)
    # Explicit bool dtype: an empty list would otherwise become a float tensor.
    keep_mask = torch.tensor(
        [label not in excluded for label in labels], dtype=torch.bool, device=vecs.device
    )
    if verbose:
        print(vecs.shape)
    vecs = vecs[keep_mask]
    if verbose:
        print(vecs.shape)

    mean = vecs.mean(dim=0)
    vecs_centered = vecs - mean
    try:
        U, S, Vh = torch.linalg.svd(vecs_centered, full_matrices=False)
    except torch.linalg.LinAlgError:
        # Retry on a copy perturbed by tiny jitter (~1e-6 of the mean
        # absolute value, floored at 1e-12).
        eps = (vecs_centered.abs().mean() * 1e-6).clamp_min(1e-12).item()
        vecs_centered = vecs_centered + eps * torch.randn_like(vecs_centered)
        U = S = Vh = None
        # ``driver`` is only accepted for CUDA inputs, so gesvd is tried there
        # only; any failure falls through to the default solver below.
        if vecs_centered.is_cuda:
            try:
                U, S, Vh = torch.linalg.svd(vecs_centered, full_matrices=False, driver="gesvd")
            except RuntimeError:  # torch.linalg.LinAlgError subclasses RuntimeError
                U = S = Vh = None
        if Vh is None:
            U, S, Vh = torch.linalg.svd(vecs_centered, full_matrices=False)

    return {"U": U, "S": S, "Vh": Vh, "mean": mean}


def project_onto_manifold(vecs, labels, manifold):
    """Express ``vecs`` in the manifold's PC coordinates and average them per label.

    ``manifold`` is a dict with ``"mean"`` (subtracted first) and ``"Vh"``
    (rows are the PC directions), e.g. as returned by
    ``centered_svd_activations``. ``labels`` must support elementwise ``==``
    (a NumPy array), one entry per row of ``vecs``.

    Returns:
        ``(coords, centroids)``: the projected rows and a dict mapping each
        unique label to the mean of its projected rows.
    """
    basis = manifold["Vh"]
    coords = (vecs - manifold["mean"]) @ basis.T
    centroids = {}
    for emotion in np.unique(labels):
        centroids[emotion] = coords[labels == emotion].mean(axis=0)
    return coords, centroids


def get_pc_axis_order(centroids, pc_index):
    """Return the keys of ``centroids`` from highest to lowest coordinate on PC ``pc_index``.

    Ties keep the dict's iteration order (``sorted`` is stable).
    """
    def coordinate(emotion):
        return centroids[emotion][pc_index].item()

    return sorted(centroids, key=coordinate, reverse=True)

def get_pc_axis_distribution_info(projected_vecs, labels, pc_index, exclude=None):
    """Summarise how well classes separate along one principal component.

    Args:
        projected_vecs: 2-D array or ``torch.Tensor`` of PC coordinates, one row
            per sample (e.g. the first output of ``project_onto_manifold``).
        labels: one label per row of ``projected_vecs``.
        pc_index: column of ``projected_vecs`` to analyse.
        exclude: labels whose samples are ignored.

    Returns:
        ``(ordered, inter_class_var, intra_class_var)``:
        ``ordered`` is a list of ``(label, mean coordinate)`` sorted from high
        to low, ``inter_class_var`` is the population variance of the class
        means, and ``intra_class_var`` is the mean of the per-class population
        variances. Both variances are computed in NumPy for tensor and array
        inputs alike, so a single-sample class contributes 0, not nan.
    """
    exclude = set(exclude or [])
    pcs = projected_vecs[:, pc_index]
    if isinstance(pcs, torch.Tensor):
        pcs = pcs.detach().cpu().numpy()
    pcs = np.asarray(pcs)
    labels = np.asarray(labels)

    kept_mask = np.array([label not in exclude for label in labels], dtype=bool)
    pcs = pcs[kept_mask]
    labels = labels[kept_mask]

    values_by_label = {label: pcs[labels == label] for label in np.unique(labels)}
    means = {label: float(v.mean()) for label, v in values_by_label.items()}
    ordered = sorted(means.items(), key=lambda item: item[1], reverse=True)
    inter_class_var = np.var(list(means.values()))
    # ddof=0 (population) variance, matching np.var above and
    # extract_high_snr_components.
    intra_class_var = np.mean([float(v.var()) for v in values_by_label.values()])
    return ordered, inter_class_var, intra_class_var


def get_pc_axis_info(centroids, pc_index, normalize=False, exclude=None):
    """
    Rank emotions by their centroid coordinate on PC[pc_index].

    Returns ``(ordered, inter_class_variance)``: ``ordered`` is a list of
    ``(emotion, value)`` pairs sorted from highest to lowest value, and
    ``inter_class_variance`` is the population variance of those values.

    Options:
      - normalize: z-score the values (population std, plus 1e-8 in the
        denominator) before ranking; the variance is then that of the z-scores
      - exclude: set/list of emotions to skip
    """
    skip = set(exclude or [])
    scores = dict(
        (emotion, vec[pc_index].item())
        for emotion, vec in centroids.items()
        if emotion not in skip
    )

    if normalize:
        raw = np.array(list(scores.values()))
        z_scores = (raw - raw.mean()) / (raw.std() + 1e-8)
        scores = dict(zip(scores.keys(), z_scores))

    ranking = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    spread = np.var(list(scores.values()))
    return ranking, spread


def extract_high_snr_components(vecs, labels, excluded=("condescension", "sarcastic"), snr_thresh=0.1):
    """
    Perform centered SVD on vecs and return reduced U, S, Vh based on high-SNR dimensions.

    Rows whose label is in ``excluded`` are dropped first. For every principal
    component, the SNR is ``var(class means) / (mean(class variances) + 1e-8)``
    (population variances over the samples' PC coordinates), and components
    with SNR strictly above ``snr_thresh`` are kept.

    Returns:
        dict with ``U`` / ``S`` / ``Vh`` restricted to the kept components,
        ``mean`` (row mean subtracted before the SVD), ``snrs`` (SNR of every
        component) and ``indices`` (positions of the kept components).
    """
    excluded = set(excluded)
    keep = [label not in excluded for label in labels]
    vecs = vecs[torch.tensor(keep, dtype=torch.bool)]
    labels = np.asarray(labels)[np.asarray(keep, dtype=bool)]

    mean = vecs.mean(dim=0)
    vecs_centered = vecs - mean
    U, S, Vh = torch.linalg.svd(vecs_centered, full_matrices=False)

    # Coordinates of every sample on every PC: [N, K], K = min(N, D).
    pcs = (vecs_centered @ Vh.T).detach().cpu().numpy()

    # Class statistics for all PCs at once (axis 0 = samples); the class masks
    # are built once instead of once per component.
    classes = np.unique(labels)
    class_means = np.stack([pcs[labels == c].mean(axis=0) for c in classes])
    class_vars = np.stack([pcs[labels == c].var(axis=0) for c in classes])
    inter = class_means.var(axis=0)
    intra = class_vars.mean(axis=0)
    snrs = inter / (intra + 1e-8)

    high_snr_indices = np.where(snrs > snr_thresh)[0]

    return {
        "U": U[:, high_snr_indices],
        "S": S[high_snr_indices],
        "Vh": Vh[high_snr_indices],
        "mean": mean,
        "snrs": snrs,
        "indices": high_snr_indices,
    }

def compare_pc_orders_all_metrics(pc_orders_all_1, pc_orders_all_2, excluded=frozenset({"neutral", "condescension", "sarcastic"}), k=3):
    """Compare per-PC emotion orderings between two runs, key by key.

    Both inputs map keys of the form ``"<proj>_<layer>"`` (e.g. ``"q_proj_3"``)
    to ``{"PC1": order, "PC2": order, "PC3": order}``, where each order is a
    list of labels. Only keys present in both inputs are compared. After
    dropping ``excluded`` labels, each pair of orders is scored by:

    - ``top_k``: overlap of the top-``k`` and bottom-``k`` sets, taking the
      better of the direct and the flipped alignment, in [0, 1];
    - ``spearman`` / ``kendall``: rank correlation mapped to [0, 1] via
      ``(r + 1) / 2``, taking the better of the direct and reversed order
      (0.0 if the label sets differ, 0.5 when the correlation is undefined).

    Returns:
        DataFrame with one row per (layer type, PC), averaged over the layers
        shared by both inputs for that type, followed by three ``"GLOBAL"``
        rows averaged over all shared keys. ``pairs_compared`` is the number
        of pairs averaged (0 gives nan averages).
    """
    pc_names = ["PC1", "PC2", "PC3"]

    def filtered(order):
        return [e for e in order if e not in excluded]

    def top_bottom_sets(order):
        f = filtered(order)
        return set(f[:k]), set(f[-k:])

    def score_top_k(o1, o2):
        t1, b1 = top_bottom_sets(o1)
        t2, b2 = top_bottom_sets(o2)
        same = len(t1 & t2) + len(b1 & b2)
        flipped = len(t1 & b2) + len(b1 & t2)
        return max(same, flipped) / (2 * k)

    def score_rank(o1, o2, stat):
        # Shared by Spearman and Kendall: ranks of o2's labels in o1's order,
        # scored against both the direct and the reversed sequence.
        f1, f2 = filtered(o1), filtered(o2)
        if set(f1) != set(f2):
            return 0.0
        rank_map = {e: i for i, e in enumerate(f1)}
        ranks = [rank_map[e] for e in f2]
        fwd, _ = stat(range(len(f1)), ranks)
        rev, _ = stat(range(len(f1)), ranks[::-1])
        return max((fwd + 1) / 2 if not np.isnan(fwd) else 0.5,
                   (rev + 1) / 2 if not np.isnan(rev) else 0.5)

    def compute_all_scores(pairs):
        results = defaultdict(lambda: defaultdict(list))
        for (_, pcs1), (_, pcs2) in pairs:
            for pc in pc_names:
                o1, o2 = pcs1[pc], pcs2[pc]
                results[pc]["top_k"].append(score_top_k(o1, o2))
                results[pc]["spearman"].append(score_rank(o1, o2, spearmanr))
                results[pc]["kendall"].append(score_rank(o1, o2, kendalltau))
        return results

    def summarize(layer_type, results):
        return [{
            "layer_type": layer_type,
            "PC": pc,
            "avg_top_k": np.mean(results[pc]["top_k"]),
            "avg_spearman": np.mean(results[pc]["spearman"]),
            "avg_kendall": np.mean(results[pc]["kendall"]),
            "pairs_compared": len(results[pc]["top_k"]),
        } for pc in pc_names]

    def organize_by_proj(entries):
        # "q_proj_3" -> layer_groups["q_proj"][3]
        layer_groups = defaultdict(dict)
        for key, pcs in entries.items():
            parts = key.split("_")
            layer_groups["_".join(parts[:2])][int(parts[-1])] = pcs
        return layer_groups

    # Pair entries by key rather than by position in two independently sorted
    # lists, so a key missing from one side cannot misalign the others.
    shared_keys = sorted(set(pc_orders_all_1) & set(pc_orders_all_2))
    global_results = compute_all_scores(
        [((key, pc_orders_all_1[key]), (key, pc_orders_all_2[key])) for key in shared_keys]
    )

    group1, group2 = organize_by_proj(pc_orders_all_1), organize_by_proj(pc_orders_all_2)
    local_results = []
    for proj, entries1 in group1.items():
        entries2 = group2.get(proj)
        if entries2 is None:
            continue
        shared_layers = sorted(set(entries1) & set(entries2))
        pairs = [((f"{proj}_{layer}", entries1[layer]), (f"{proj}_{layer}", entries2[layer]))
                 for layer in shared_layers]
        local_results.extend(summarize(proj, compute_all_scores(pairs)))

    return pd.DataFrame(local_results + summarize("GLOBAL", global_results))


def compute_overlap_from_pc_orders_all_metrics(pc_orders_all, excluded={"neutral", "condescension", "sarcastic"}, k=3):
    """Score how consistently labels are ordered along PC1-PC3 across entries.

    ``pc_orders_all`` maps keys of the form ``"<proj>_<layer>"`` to
    ``{"PC1": order, "PC2": order, "PC3": order}`` (lists of labels). Every
    unordered pair of entries sharing a projection type is compared, as is
    every unordered pair overall. After dropping ``excluded`` labels:

    - ``top_k``: overlap of the top-``k`` / bottom-``k`` sets, the better of
      the direct and flipped alignment, in [0, 1];
    - ``spearman`` / ``kendall``: ``(r + 1) / 2`` for the better of the direct
      and reversed order (0.0 if the label sets differ, 0.5 if undefined).

    Returns:
        DataFrame of per-(projection type, PC) averages followed by three
        ``"GLOBAL"`` rows; ``pairs_compared`` is the number of pairs averaged.
    """
    pc_names = ("PC1", "PC2", "PC3")

    def keep(order):
        return [emotion for emotion in order if emotion not in excluded]

    def extremes(order):
        kept = keep(order)
        return set(kept[:k]), set(kept[-k:])

    def top_k_score(first, second):
        top_a, bottom_a = extremes(first)
        top_b, bottom_b = extremes(second)
        aligned = len(top_a & top_b) + len(bottom_a & bottom_b)
        inverted = len(top_a & bottom_b) + len(bottom_a & top_b)
        return max(aligned, inverted) / (2 * k)

    def rank_score(first, second, stat):
        kept_a, kept_b = keep(first), keep(second)
        if set(kept_a) != set(kept_b):
            return 0.0
        position = {emotion: idx for idx, emotion in enumerate(kept_a)}
        ranks = [position[emotion] for emotion in kept_b]
        forward, _ = stat(range(len(kept_a)), ranks)
        backward, _ = stat(range(len(kept_a)), ranks[::-1])
        forward_score = 0.5 if np.isnan(forward) else (forward + 1) / 2
        backward_score = 0.5 if np.isnan(backward) else (backward + 1) / 2
        return max(forward_score, backward_score)

    def score_pairs(pairs):
        scores = {pc: {"top_k": [], "spearman": [], "kendall": []} for pc in pc_names}
        for (_, left), (_, right) in pairs:
            for pc in pc_names:
                a, b = left[pc], right[pc]
                scores[pc]["top_k"].append(top_k_score(a, b))
                scores[pc]["spearman"].append(rank_score(a, b, spearmanr))
                scores[pc]["kendall"].append(rank_score(a, b, kendalltau))
        return scores

    def summarise(layer_type, scores):
        rows = []
        for pc in pc_names:
            bucket = scores[pc]
            rows.append({
                "layer_type": layer_type,
                "PC": pc,
                "avg_top_k": np.mean(bucket["top_k"]),
                "avg_spearman": np.mean(bucket["spearman"]),
                "avg_kendall": np.mean(bucket["kendall"]),
                "pairs_compared": len(bucket["top_k"]),
            })
        return rows

    by_proj = defaultdict(list)
    everything = []
    for key, pcs in pc_orders_all.items():
        pieces = key.split("_")
        by_proj["_".join(pieces[:2])].append((int(pieces[-1]), pcs))
        everything.append((key, pcs))

    global_scores = score_pairs(combinations(everything, 2))

    table = []
    for proj, entries in by_proj.items():
        table.extend(summarise(proj, score_pairs(combinations(entries, 2))))
    table.extend(summarise("GLOBAL", global_scores))

    return pd.DataFrame(table)

def compare_pc_orders_across_datasets(synth_centroids_all, goemo_centroids_all, goemo_to_synth_map, exclude=frozenset({"neutral", "condescension", "sarcastic", "envy"}), k=3):
    """Compare label orderings on PC1-PC3 between two sets of centroids.

    For every ``layer_key`` in ``synth_centroids_all`` and every PC name
    ``"PC1"``..``"PC3"``, ``synth_centroids_all[layer_key][pc]`` and
    ``goemo_centroids_all[layer_key][pc]`` are dicts mapping a label to its
    centroid tensor. The coordinate on PC ``n`` is read from column ``n - 1``
    of the stacked centroids. GoEmotions labels are renamed through
    ``goemo_to_synth_map``, labels absent from the map are dropped, and
    centroids that share a target label are averaged. Labels in ``exclude``
    are ignored. (layer, PC) combinations with fewer than ``k`` shared labels
    are skipped.

    The sign of a singular vector is arbitrary, so every metric takes the
    better of the direct and the flipped alignment:

    - ``topk_overlap``: overlap of the top-``k`` / bottom-``k`` label sets, in [0, 1];
    - ``spearman`` / ``kendall``: ``max((r + 1) / 2, (1 - r) / 2)``, or 0.5 when
      the correlation is undefined (e.g. constant coordinates).

    Returns:
        DataFrame with columns ``layer``, ``PC``, ``topk_overlap``,
        ``spearman``, ``kendall`` and ``num_shared`` (number of labels compared).
    """
    exclude = set(exclude)

    def sign_invariant(stat):
        # Map a correlation in [-1, 1] to [0, 1], ignoring its sign.
        if np.isnan(stat):
            return 0.5
        return max((stat + 1) / 2, (1 - stat) / 2)

    def descending_labels(values, labels):
        return [label for _, label in sorted(zip(values, labels), reverse=True)]

    results = []
    for layer_key in synth_centroids_all:
        for pc in ["PC1", "PC2", "PC3"]:
            pc_index = int(pc[2:]) - 1
            synth_centroids = synth_centroids_all[layer_key][pc]
            goemo_centroids = goemo_centroids_all[layer_key][pc]

            remapped = defaultdict(list)
            for emo, vec in goemo_centroids.items():
                if emo in goemo_to_synth_map:  # `in` does not insert into a defaultdict
                    remapped[goemo_to_synth_map[emo]].append(vec)

            goemo_mapped = {
                emo: torch.stack(vecs).mean(0)
                for emo, vecs in remapped.items()
                if emo in synth_centroids and emo not in exclude
            }

            shared_labels = sorted((set(synth_centroids) & set(goemo_mapped)) - exclude)
            if len(shared_labels) < k:
                continue

            # Coordinate on the requested PC (not always PC1), as plain floats.
            synth_proj = torch.stack([synth_centroids[l] for l in shared_labels])[:, pc_index]
            goemo_proj = torch.stack([goemo_mapped[l] for l in shared_labels])[:, pc_index]
            synth_proj = synth_proj.detach().cpu().numpy().astype(float)
            goemo_proj = goemo_proj.detach().cpu().numpy().astype(float)

            rho = sign_invariant(spearmanr(synth_proj, goemo_proj)[0])
            tau = sign_invariant(kendalltau(synth_proj, goemo_proj)[0])

            synth_order = descending_labels(synth_proj.tolist(), shared_labels)
            goemo_order = descending_labels(goemo_proj.tolist(), shared_labels)
            top_synth, bottom_synth = set(synth_order[:k]), set(synth_order[-k:])
            top_goemo, bottom_goemo = set(goemo_order[:k]), set(goemo_order[-k:])

            overlap_normal = len(top_synth & top_goemo) + len(bottom_synth & bottom_goemo)
            overlap_flip = len(top_synth & bottom_goemo) + len(bottom_synth & top_goemo)
            topk_overlap = max(overlap_normal, overlap_flip) / (2 * k)

            results.append({
                "layer": layer_key,
                "PC": pc,
                "topk_overlap": topk_overlap,
                "spearman": rho,
                "kendall": tau,
                "num_shared": len(shared_labels)
            })

    return pd.DataFrame(results)

def plot3d_axes(vecs, labels):
    """Scatter per-emotion centroids in a 3-D slice of the activation manifold.

    The manifold is fitted with ``centered_svd_activations(vecs, labels)``.
    Every row of ``vecs`` is then projected onto PCs 4-6 (``Vh`` rows 3..5),
    and the centroid of each label other than "condescension"/"sarcastic" is
    plotted. Axis titles are tentative interpretations of those components;
    the commented-out titles are presumably the ones used for PCs 1-3.
    Colours cycle through tab10, so more than ten labels are supported.
    """
    svd_manifold = centered_svd_activations(vecs, labels)

    # Coordinates on the first six PCs; keep PCs 4-6 (use [:, :3] for PCs 1-3).
    vecs_projected = (vecs - svd_manifold["mean"]) @ svd_manifold["Vh"][:6].T
    # vecs_3d = vecs_projected[:, :3]
    vecs_3d = vecs_projected[:, 3:]

    # Filter out unwanted labels
    valid_mask = np.array([e not in {"condescension", "sarcastic"} for e in labels])
    vecs_3d = vecs_3d[valid_mask]
    labels_filtered = np.array(labels)[valid_mask]

    # Compute centroids
    emotions = np.unique(labels_filtered)
    colors = plt.cm.tab10.colors
    centroids = {
        emotion: vecs_3d[labels_filtered == emotion].mean(axis=0)
        for emotion in emotions
    }

    # Plot
    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection="3d")
    ax.set_title("Emotion Centroids in 3D Latent Space", fontsize=14, weight="bold")
    # ax.set_xlabel("Valence", fontsize=12)
    # ax.set_ylabel("Dominance or Perceived Control", fontsize=12)
    # ax.set_zlabel("Approach–Avoidance Motivation", fontsize=12)
    ax.set_xlabel("Arousal or Urgency", fontsize=12)
    ax.set_ylabel("Emotional Volatility or Temporal Intensity", fontsize=12)
    ax.set_zlabel("Self-Conscious (tentative)", fontsize=12)

    for i, (emotion, coord) in enumerate(centroids.items()):
        x, y, z = (float(c) for c in coord)
        # tab10 has only 10 colours; cycle instead of raising IndexError.
        ax.scatter(x, y, z, color=colors[i % len(colors)], s=70, edgecolor="k", label=emotion)
        ax.text(x, y, z + 0.025, emotion, fontsize=9, ha="center", va="bottom")  # label just above the point

    def filter_tick_labels(ticks):
        # Hide every tick label except the one at the origin.
        return ["" if abs(t) > 1e-6 else "0" for t in ticks]

    # Pin the current tick positions before relabelling them, so the labels
    # stay attached to those ticks.
    xticks, yticks, zticks = ax.get_xticks(), ax.get_yticks(), ax.get_zticks()
    ax.set_xticks(xticks)
    ax.set_xticklabels(filter_tick_labels(xticks))
    ax.set_yticks(yticks)
    ax.set_yticklabels(filter_tick_labels(yticks))
    ax.set_zticks(zticks)
    ax.set_zticklabels(filter_tick_labels(zticks))

    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", title="Emotion")
    plt.tight_layout()
    plt.show()


def plot_tsne(layer, block, proj, excluded=frozenset({"condescension", "sarcastic"}),
              act_dir="hidden_state_dumps_mean", act_prefix="synth"):
    """Show a 2-D t-SNE scatter of one module's activation dump, coloured by label.

    Args:
        layer, block, proj: identify the dump file read via ``load_layer_data``.
        excluded: labels whose samples are left out of the embedding.
        act_dir: directory holding the dumps. The default is
            ``"hidden_state_dumps_mean"``; the script's main section reads
            ``"<model>_hidden_state_dumps_mean"``, so pass that to plot the
            same data.
        act_prefix: dataset prefix of the dump file (``"synth"`` by default).

    Raises:
        FileNotFoundError: if the requested dump file does not exist.
    """
    result = load_layer_data(layer, block, proj, act_dir, act_prefix, None)
    if result is None:
        raise FileNotFoundError(
            f"no activation dump for layer {layer} ({block}.{proj}, prefix {act_prefix!r}) in {act_dir!r}"
        )
    labels, vecs, _ = result

    keep = np.array([label not in excluded for label in labels], dtype=bool)
    vecs = vecs[torch.as_tensor(keep)]
    labels = labels[keep]

    tsne = TSNE(n_components=2, perplexity=30, init="pca", random_state=0)
    proj_tsne = tsne.fit_transform(vecs.detach().cpu().numpy())

    df_tsne = pd.DataFrame(proj_tsne, columns=["tSNE-1", "tSNE-2"])
    df_tsne["Emotion"] = labels

    emotions = sorted(set(labels))
    palette = sns.color_palette("tab10", n_colors=len(emotions))
    emotion_palette = dict(zip(emotions, palette))

    # t-SNE plot
    sns.scatterplot(data=df_tsne, x="tSNE-1", y="tSNE-2", hue="Emotion", palette=emotion_palette, alpha=0.6, s=20)

    handles, legend_labels = plt.gca().get_legend_handles_labels()
    plt.legend(handles=handles, labels=legend_labels, loc="upper right", ncol=1, labelspacing=0.)
    plt.title("t-SNE Projection of Mean-Pooled Hidden States")

    plt.tight_layout()
    plt.show()


# GoEmotions label -> synthetic-dataset label, declared grouped by target label.
# Labels not listed map to "" (defaultdict(str)); note that looking up a
# missing key also inserts it with that "" value.
_GOEMOTIONS_BY_SYNTH_LABEL = {
    "happy": ("amusement", "joy", "love", "gratitude", "optimism"),
    "anger": ("anger", "annoyance"),
    "disgust": ("disapproval", "disgust"),
    "excitement": ("excitement",),
    "fear": ("fear", "nervousness"),
    "sad": ("sadness", "disappointment", "grief", "remorse"),
    "surprise": ("surprise",),
    "neutral": ("neutral",),
}

goemotions_to_synth = defaultdict(str, {
    source: target
    for target, sources in _GOEMOTIONS_BY_SYNTH_LABEL.items()
    for source in sources
})

def summarize_cross_dataset_overlap(overlap_df, metrics=("topk_overlap", "spearman", "kendall")):
    """Average cross-dataset overlap metrics per layer index, per layer type and per PC.

    ``overlap_df`` has the columns produced by ``compare_pc_orders_across_datasets``.
    Its ``layer`` column holds keys like ``"q_proj_12"``, which are split on the
    last underscore into a layer type and an integer layer index. A substring
    match would conflate layer 1 with layers 10-19, 21 and 31; splitting avoids
    that.

    Returns:
        ``(by_layer, by_type, overall)``: DataFrames of metric means indexed by
        (PC, layer_idx), (PC, layer_type) and PC respectively. If
        ``overlap_df`` is empty, three empty frames are returned.
    """
    metrics = list(metrics)
    if overlap_df.empty:
        return (pd.DataFrame(columns=metrics), pd.DataFrame(columns=metrics),
                pd.DataFrame(columns=metrics))
    df = overlap_df.copy()
    key_parts = df["layer"].str.rsplit("_", n=1)
    df["layer_type"] = key_parts.str[0]
    df["layer_idx"] = key_parts.str[1].astype(int)
    by_layer = df.groupby(["PC", "layer_idx"])[metrics].mean()
    by_type = df.groupby(["PC", "layer_type"])[metrics].mean()
    overall = df.groupby("PC")[metrics].mean()
    return by_layer, by_type, overall


if __name__ == "__main__":
    model = "llama"
    act_dir = f"{model}_hidden_state_dumps_mean"
    cache_path = f"hidden_state_synth_no_lora_{model}.pt"
    pc_orders_all_synth = {}
    pc_orders_all_goemotions = {}
    pc_orders_all_abbr_goemotions = {}
    manifolds_all = {}
    centroids_all_synth = {}
    centroids_all_goemotions = {}
    centroids_all_abbr_goemotions = {}
    # Dumps are named "<prefix>_layer_<idx>_<block>_<proj>.h5" (see
    # load_layer_data); the layer count is one past the largest synth index.
    n_layers = max(
        int(name.split("_")[2])
        for name in os.listdir(act_dir)
        if "synth" in name and "all" not in name and "layer" in name
    ) + 1
    with tqdm(total=n_layers) as pbar:
        for layer in range(n_layers):
            # if os.path.exists(f"emotional_manifold/{model}_manifold_slice_layer_{layer}.pt"):
            #     try:
            #         layer_result = torch.load(f"emotional_manifold/{model}_manifold_slice_layer_{layer}.pt")
            #         if isinstance(layer_result, list):
            #             [manifolds_all.update(i) for i in layer_result]
            #         elif isinstance(layer_result, dict):
            #             manifolds_all.update(layer_result)
            #         pbar.update(1)
            #         continue
            #     except RuntimeError:
            #         pass
            for block, proj in valid_blocks:
                result = load_layer_data(layer, block, proj, act_dir, "synth", None)
                if result is None:
                    continue
                labels_synth, vecs_synth, indices_synth = result
                result = load_layer_data(layer, block, proj, act_dir, "goemotions", None)
                if result is None:
                    continue
                labels_goemotions, vecs_goemotions, indices_goemotions = result

                # Collapse GoEmotions labels onto the synthetic label set and
                # drop samples whose label has no counterpart (mapped to "").
                new_labels_goemotions = [goemotions_to_synth[label] for label in labels_goemotions]
                goemotions_mask = np.array([label != "" for label in new_labels_goemotions])
                abbr_vecs_goemotions = vecs_goemotions[goemotions_mask]
                abbr_labels_goemotions = np.array(new_labels_goemotions)[goemotions_mask]
                #
                # cur_manifold = extract_high_snr_components(vecs_goemotions, labels_goemotions, excluded={"sarcastic", "condescension"}, snr_thresh=0.1)
                #
                # All three label sets are projected onto the manifold fitted on the synthetic data.
                cur_manifold = centered_svd_activations(vecs_synth, labels_synth)
                projected_hidden_states_synth, centroids_synth = project_onto_manifold(vecs_synth, labels_synth, cur_manifold)
                projected_hidden_states_goemotions, centroids_goemotions = project_onto_manifold(vecs_goemotions, labels_goemotions, cur_manifold)
                projected_hidden_states_abbr_goemotions, centroids_abbr_goemotions = project_onto_manifold(abbr_vecs_goemotions, abbr_labels_goemotions, cur_manifold)

                key = f"{proj}_{layer}"
                pc_orders_all_synth[key] = {f"PC{pc_index + 1}": get_pc_axis_order(centroids_synth, pc_index) for pc_index in range(3)}
                # pc_orders_all_goemotions[key] = {f"PC{pc_index + 1}": get_pc_axis_order(centroids_goemotions, pc_index) for pc_index in range(3)}
                # pc_orders_all_abbr_goemotions[key] = {f"PC{pc_index + 1}": get_pc_axis_order(centroids_abbr_goemotions, pc_index) for pc_index in range(3)}
                # Every PC key maps to the same centroid dict (coordinates on all PCs).
                centroids_all_synth[key] = {f"PC{pc_index + 1}": centroids_synth for pc_index in range(3)}
                centroids_all_goemotions[key] = {f"PC{pc_index + 1}": centroids_goemotions for pc_index in range(3)}
                centroids_all_abbr_goemotions[key] = {f"PC{pc_index + 1}": centroids_abbr_goemotions for pc_index in range(3)}
                manifolds_all[key] = cur_manifold

            # torch.save({k: v for k, v in manifolds_all.items() if k.endswith(f"_{layer}")},
            #            f"emotional_manifold/{model}_manifold_slice_layer_{layer}.pt")
            pbar.update(1)

    torch.save({
        "manifolds": manifolds_all,
        "centroids_synth": centroids_all_synth,
        "centroids_goemotions": centroids_all_goemotions,
        "centroids_abbr_goemotions": centroids_all_abbr_goemotions
    }, cache_path)
    # num = max(int(k.rsplit("_", 1)[-1]) for k in manifolds_all) + 1
    # result = [
    #     {k: v for k, v in manifolds_all.items() if k.endswith(f"_{i}")}
    #     for i in range(num)
    # ]
    # _ = [
    #     torch.save(d, f"emotional_manifold/{model}_manifold_slice_layer_{i}.pt")
    #     for i, d in enumerate(result)
    # ]

    # Reload the cache written just above (a local, trusted file, hence
    # weights_only=False), reading entries by key rather than by value order.
    cached = torch.load(cache_path, weights_only=False)
    manifolds_all = cached["manifolds"]
    centroids_all_synth = cached["centroids_synth"]
    centroids_all_goemotions = cached["centroids_goemotions"]
    centroids_all_abbr_goemotions = cached["centroids_abbr_goemotions"]

    pc_orders_all_synth = {key: {f"PC{pc_index + 1}": get_pc_axis_order(value["PC1"], pc_index) for pc_index in range(3)} for key, value in centroids_all_synth.items()}
    pc_orders_all_goemotions = {key: {f"PC{pc_index + 1}": get_pc_axis_order(value["PC1"], pc_index) for pc_index in range(3)} for key, value in centroids_all_goemotions.items()}

    pc_overlap_synth = compute_overlap_from_pc_orders_all_metrics(pc_orders_all_synth)
    pc_overlap_goemotions = compute_overlap_from_pc_orders_all_metrics(pc_orders_all_goemotions)

    # The abbreviated GoEmotions centroids are already keyed by synthetic
    # labels, so they are compared through an identity map over those labels.
    # goemotions_to_synth would drop every synthetic label that is not also a
    # GoEmotions key (e.g. "happy", "sad"). Empty values are unmapped labels.
    synth_label_map = {label: label for label in set(goemotions_to_synth.values()) if label}
    centroid_pc_overlap = compare_pc_orders_across_datasets(centroids_all_synth, centroids_all_abbr_goemotions, synth_label_map)

    overlap_by_layer, overlap_by_type, overlap_overall = summarize_cross_dataset_overlap(centroid_pc_overlap)
    print(overlap_by_layer)
    print(overlap_by_type)
    print(overlap_overall)

    import IPython; IPython.embed()
