
import pandas as pd
import numpy as np

import os
import json
import pickle

from tqdm.notebook import tqdm

from scipy.spatial.distance import pdist, cdist
from scipy.stats import wasserstein_distance

from sklearn.utils.extmath import randomized_svd

from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    confusion_matrix,
)

from sklearn.model_selection import StratifiedShuffleSplit

# ============================================================
# ====================== BASE FUNCTIONS ======================
# ============================================================

def loadParquet(fn, unifyYears=False):
    """
    Read a parquet file into a DataFrame, optionally merging the 2020 and
    2022 rows into one prompt-id space.

    Parameters
    ----------
    fn : str or path-like
        Location handed to ``pd.read_parquet``.
    unifyYears : bool
        When True, only rows whose ``year`` is 2020 or 2022 are kept. The
        2022 rows get their ``prompt_id`` shifted by 100 (presumably so the
        ids of the two years cannot collide -- confirm with the data owner)
        and are appended after the 2020 rows with a fresh index.

    Returns
    -------
    pd.DataFrame
    """
    frame = pd.read_parquet(fn)
    if not unifyYears:
        return frame

    rows_2020 = frame.loc[frame['year'] == 2020]
    # copy so the prompt_id shift never writes into the frame just read
    rows_2022 = frame.loc[frame['year'] == 2022].copy()
    rows_2022['prompt_id'] = rows_2022['prompt_id'] + 100

    return pd.concat([rows_2020, rows_2022], ignore_index=True)

def split_by_label(X, y):
    """
    Split embeddings by class label.

    Parameters
    ----------
    X : np.ndarray
        Samples, indexed along the first axis.
    y : array-like
        One label per sample; truthy means hallucination. Both boolean
        and 0/1 integer labels are accepted.

    Returns
    -------
    X_G : np.ndarray
        Genuine responses
    X_H : np.ndarray
        Hallucinated responses
    """
    # Coerce to a boolean mask first: on an integer array `~y` is a bitwise
    # NOT (0 -> -1, 1 -> -2), which would silently fancy-index the wrong
    # rows instead of selecting the complement.
    mask = np.asarray(y).astype(bool)

    X_G = X[~mask]
    X_H = X[mask]

    return X_G, X_H

def extract_prompt_data(df, model_id, prompt_id, embedding_label="response_embeddings"):
    """
    Pull the embedding matrix and the labels of one (model_id, prompt_id).

    Parameters
    ----------
    df : pd.DataFrame
        Must provide the columns ``model_id``, ``prompt_id``,
        ``hallucination`` and the column named by ``embedding_label``.
    model_id, prompt_id
        Values matched by equality against the corresponding columns.
    embedding_label : str
        Column holding one embedding vector per row.

    Returns
    -------
    X : np.ndarray, shape (n_samples, d)
        Row-wise stack of the selected embeddings.
    y : np.ndarray, shape (n_samples,)
        Boolean labels: False = genuine, True = hallucination.
    """
    is_model = df["model_id"] == model_id
    is_prompt = df["prompt_id"] == prompt_id
    selected = df.loc[is_model & is_prompt]

    embeddings = np.stack(selected[embedding_label].values)
    labels = selected["hallucination"].values.astype(bool)

    return embeddings, labels


def fisher_direction(
    X_G,
    X_H,
    lambda_reg=1e-3,
    normalise=True,
    normalise_by_trace=True,
):
    """
    Compute regularised Fisher discriminant direction with
    trace-adaptive regularisation.

    Parameters
    ----------
    X_G, X_H : np.ndarray
        Shape (n_G, d), (n_H, d). A single feature column (d == 1) is
        supported.
    lambda_reg : float
        Dimensionless regularisation strength
    normalise : bool
        Whether to L2-normalise the output direction
    normalise_by_trace : bool
        Whether to scale lambda by the average diagonal entry of S_W
        (trace(S_W) / d)

    Returns
    -------
    v : np.ndarray, shape (d,)
        Solution of (S_W + lambda_eff * I) v = mu_H - mu_G, optionally
        rescaled to unit L2 norm.
    """
    mu_G = X_G.mean(axis=0)
    mu_H = X_H.mean(axis=0)

    # within-class scatter (biased = MLE)
    # np.cov squeezes its result to a 0-d array when there is a single
    # feature; atleast_2d restores the (d, d) shape needed below.
    S_G = np.atleast_2d(np.cov(X_G, rowvar=False, bias=True))
    S_H = np.atleast_2d(np.cov(X_H, rowvar=False, bias=True))

    S_W = S_G + S_H
    d = S_W.shape[0]

    # trace-normalised regularisation
    if normalise_by_trace:
        lambda_eff = lambda_reg * np.trace(S_W) / d
    else:
        lambda_eff = lambda_reg

    S_W_reg = S_W + lambda_eff * np.eye(d)

    v = np.linalg.solve(S_W_reg, mu_H - mu_G)

    if normalise:
        norm = np.linalg.norm(v)
        # identical class means give v == 0; leave it unscaled then
        if norm > 0:
            v = v / norm

    return v

def generate_fixed_test_sets(X, y, n_splits=5, test_fraction=0.2, random_state=42):
    """
    Build ``n_splits`` stratified train/test partitions of (X, y).

    Parameters
    ----------
    X, y : np.ndarray
        Samples and labels, indexed along the first axis.
    n_splits : int
        Number of partitions to draw.
    test_fraction : float
        Passed to ``StratifiedShuffleSplit`` as ``test_size``.
    random_state : int
        Seed forwarded to the splitter.

    Returns
    -------
    trn_splits, tst_splits : list of (X_part, y_part) tuples
        Parallel lists; entry ``i`` of each belongs to the same partition.
    """
    splitter = StratifiedShuffleSplit(n_splits=n_splits, test_size=test_fraction, random_state=random_state)
    index_pairs = list(splitter.split(X, y))

    trn_splits = [(X[trn_idx], y[trn_idx]) for trn_idx, _ in index_pairs]
    tst_splits = [(X[tst_idx], y[tst_idx]) for _, tst_idx in index_pairs]
    return trn_splits, tst_splits

def subsample_training_set(X_train_full, y_train_full, fraction, random_state=None):
    """
    Draw a class-wise random subsample of the training set.

    The same ``fraction`` is kept from each class (at least one sample per
    non-empty class), so the class proportions of the input are roughly
    preserved rather than equalised.

    Parameters
    ----------
    X_train_full : np.ndarray, shape (n, d)
    y_train_full : np.ndarray, shape (n,)
        Labels; truthy means hallucination.
    fraction : float
        Share of each class to keep.
    random_state : int, np.random.Generator or None
        Seed for the sampling. None (the default) draws fresh entropy on
        every call, which is the historical behaviour.

    Returns
    -------
    X_sub : np.ndarray
        Genuine samples followed by hallucinated samples.
    y_sub : np.ndarray of bool
        Matching labels (False block, then True block).
    """
    X_G, X_H = split_by_label(X_train_full, y_train_full)

    # An empty class contributes nothing; forcing one sample out of zero
    # would make rng.choice raise.
    n_G = max(1, int(fraction * len(X_G))) if len(X_G) > 0 else 0
    n_H = max(1, int(fraction * len(X_H))) if len(X_H) > 0 else 0

    rng = np.random.default_rng(random_state)
    # Draw order (genuine first) is kept fixed so a seed gives stable output.
    idx_G = rng.choice(len(X_G), size=n_G, replace=False)
    idx_H = rng.choice(len(X_H), size=n_H, replace=False)

    X_sub = np.vstack([X_G[idx_G], X_H[idx_H]])
    y_sub = np.concatenate([np.zeros(n_G, dtype=bool), np.ones(n_H, dtype=bool)])

    return X_sub, y_sub


# ============================================================
# ============= STRUCTURAL ANALYSIS FUNCTIONS  ===============
# ============================================================

def compute_distance_distributions(X_G, X_H):
    """
    Pairwise-distance samples within and across the two classes.

    Distances use the defaults of ``scipy.spatial.distance.pdist`` /
    ``cdist``. Whenever a distribution cannot be formed (fewer than two
    points for an intra-class set, or an empty class for the inter-class
    set) an empty array is returned in its place.

    Returns
    -------
    D_GG : np.ndarray
        Condensed distances among the rows of X_G.
    D_HH : np.ndarray
        Condensed distances among the rows of X_H.
    D_GH : np.ndarray
        Flattened distances between every X_G row and every X_H row.
    """
    n_G, n_H = len(X_G), len(X_H)

    D_GG = np.array([])
    if n_G > 1:
        D_GG = pdist(X_G)

    D_HH = np.array([])
    if n_H > 1:
        D_HH = pdist(X_H)

    D_GH = np.array([])
    if n_G > 0 and n_H > 0:
        D_GH = cdist(X_G, X_H).ravel()

    return D_GG, D_HH, D_GH

def wasserstein_GG_HH(D_GG, D_HH):
    """
    1-D Wasserstein distance between the two intra-class distance samples.

    Returns ``np.nan`` when either sample is empty.
    """
    if len(D_GG) and len(D_HH):
        return wasserstein_distance(D_GG, D_HH)
    return np.nan

def wasserstein_null_model(X, y, n_permutations=100, random_state=None):
    """
    Sample W(GG, HH) under random relabelling of the points.

    Each round shuffles ``y``, re-splits ``X`` with the shuffled labels and
    records the Wasserstein distance between the two intra-class distance
    distributions. Rounds in which either class has fewer than two points
    are skipped.

    Parameters
    ----------
    X : np.ndarray
    y : np.ndarray of bool
    n_permutations : int
        Number of shuffles attempted.
    random_state : int or None
        Seed for ``np.random.default_rng``.

    Returns
    -------
    np.ndarray
        One value per retained round (may be empty).
    """
    rng = np.random.default_rng(random_state)
    samples = []

    for _ in range(n_permutations):
        shuffled = rng.permutation(y)
        X_Gp, X_Hp = split_by_label(X, shuffled)

        if min(len(X_Gp), len(X_Hp)) >= 2:
            D_GG_p, D_HH_p, _unused = compute_distance_distributions(X_Gp, X_Hp)
            samples.append(wasserstein_GG_HH(D_GG_p, D_HH_p))

    return np.array(samples)

def analyse_prompt(
    df,
    model_id,
    prompt_id,
    lambda_reg=1e-3,
    n_permutations=100,
    random_state=None,
    embedding_label="response_embeddings"
):
    """
    Structural analysis of one (model_id, prompt_id) pair.

    Always reported: class counts, the three distance distributions in
    the original space and ``W_GG_HH``. When both classes hold at least
    two samples, the same quantities are also reported after projecting
    onto the Fisher direction (``*_z`` keys plus ``v_fisher``). The
    permutation null model runs under the same size condition and only if
    ``n_permutations`` is not None; otherwise its three keys are None.

    Returns
    -------
    dict
    """
    X, y = extract_prompt_data(df, model_id, prompt_id, embedding_label=embedding_label)
    X_G, X_H = split_by_label(X, y)
    n_G, n_H = len(X_G), len(X_H)
    enough_samples = n_G >= 2 and n_H >= 2

    # ---- original space distances ----
    D_GG, D_HH, D_GH = compute_distance_distributions(X_G, X_H)
    res = {
        "model_id": model_id,
        "prompt_id": prompt_id,
        "n_G": n_G,
        "n_H": n_H,
        "D_GG": D_GG,
        "D_HH": D_HH,
        "D_GH": D_GH,
        "W_GG_HH": wasserstein_GG_HH(D_GG, D_HH),
    }

    # ---- Fisher space ----
    if enough_samples:
        v = fisher_direction(X_G, X_H, lambda_reg=lambda_reg)
        Z = X @ v

        # column vectors, so the distance helpers see (n, 1) point sets
        Z_G = Z[y == 0].reshape(-1, 1)
        Z_H = Z[y == 1].reshape(-1, 1)
        D_GG_z, D_HH_z, D_GH_z = compute_distance_distributions(Z_G, Z_H)

        res.update({
            "v_fisher": v,
            "D_GG_z": D_GG_z,
            "D_HH_z": D_HH_z,
            "D_GH_z": D_GH_z,
            "W_GG_HH_z": wasserstein_GG_HH(D_GG_z, D_HH_z),
        })

    # ---- null model ----
    null_stats = {"W_null_samples": None, "W_null_mean": None, "W_null_std": None}
    if n_permutations is not None and enough_samples:
        W_null = wasserstein_null_model(
            X, y,
            n_permutations=n_permutations,
            random_state=random_state
        )
        null_stats["W_null_samples"] = W_null
        if len(W_null) > 0:
            null_stats["W_null_mean"] = W_null.mean()
            null_stats["W_null_std"] = W_null.std()
    res.update(null_stats)

    return res

def collect_prompt_result(res, min_per_class_plot=5):
    """
    Split analyse_prompt output into:
    - scalar metadata row
    - geometry payload

    Parameters
    ----------
    res : dict
        Must contain ``model_id``, ``prompt_id``, ``n_G`` and ``n_H``.
        All other keys are optional.
    min_per_class_plot : int
        Minimum samples per class for the row to be flagged ``valid_plot``.

    Returns
    -------
    row : dict
        Scalar summary. ``delta_W`` is ``W_GG_HH - W_null_mean``, or None
        when no null mean is available. ``delta_W_z`` is present only when
        ``res`` holds ``W_GG_HH_z``; it is ``W_GG_HH_z - W_null_mean``, or
        None when no null mean is available.
    gs : dict
        Distance arrays and Fisher direction (None for missing entries).
    """
    m = res["model_id"]
    p = res["prompt_id"]

    n_G = res["n_G"]
    n_H = res["n_H"]
    n_total = n_G + n_H

    frac_G = n_G / n_total if n_total > 0 else np.nan
    frac_H = n_H / n_total if n_total > 0 else np.nan

    valid_geom = (n_G >= 2) and (n_H >= 2)
    valid_plot = (n_G >= min_per_class_plot) and (n_H >= min_per_class_plot)

    # ---- scalar row ----
    row = {
        "model_id": m,
        "prompt_id": p,
        "n_total": n_total,
        "n_G": n_G,
        "n_H": n_H,
        "frac_G": frac_G,
        "frac_H": frac_H,
        "W_GG_HH": res.get("W_GG_HH", np.nan),
        "W_GG_HH_z": res.get("W_GG_HH_z", np.nan),
        "valid_geom": valid_geom,
        "valid_plot": valid_plot,
    }

    # optional null-model statistics
    null_mean = res.get("W_null_mean")
    row["W_null_mean"] = null_mean
    row["W_null_std"] = res.get("W_null_std")
    has_null = null_mean is not None

    row["delta_W"] = row["W_GG_HH"] - null_mean if has_null else None

    # The null mean key can be present with a None value (null model
    # skipped or empty), so test the value, not key membership, before
    # subtracting.
    # NOTE(review): the null mean is computed in the original space while
    # W_GG_HH_z lives in the Fisher-projected space -- confirm that this
    # difference is the intended statistic.
    if "W_GG_HH_z" in res:
        row["delta_W_z"] = row["W_GG_HH_z"] - null_mean if has_null else None

    # ---- geometry payload ----
    geometry_keys = ("D_GG", "D_HH", "D_GH", "D_GG_z", "D_HH_z", "D_GH_z", "v_fisher")
    gs = {key: res.get(key) for key in geometry_keys}

    return row, gs

def run_structural_analysis(
    df,
    lambda_reg=1e-3,
    n_permutations=100,
    random_state=42,
    min_per_class_plot=5,
    use_cache=False,
    cache_dir="cache",
    overwrite_cache=False,
    embedding_label="response_embeddings"
):
    """
    Run structural analysis over all (model_id, prompt_id) pairs.

    Parameters
    ----------
    df : pd.DataFrame
    lambda_reg : float
    n_permutations : int
    random_state : int
    min_per_class_plot : int
    use_cache : bool
        Whether to load/save results from filesystem
    cache_dir : str or None
        Directory where cache files are stored
    overwrite_cache : bool
        If True, recompute even if cache exists
    embedding_label : str
        Name of the embedding column to analyse

    Returns
    -------
    results_df : pd.DataFrame
    geometry_store : dict
        (model_id, prompt_id) -> geometry payload
    null_store : dict
        (model_id, prompt_id) -> np.ndarray of null Wasserstein samples

    Raises
    ------
    ValueError
        If ``use_cache`` is True and ``cache_dir`` is None.

    Notes
    -----
    Cache files are named after ``lambda_reg`` only. The remaining run
    parameters are stored in the meta file and compared on load; a cache
    whose recorded parameters differ from the current call (or whose meta
    file cannot be read) is recomputed and overwritten rather than
    returned.
    """
    run_params = {
        "lambda_reg": lambda_reg,
        "n_permutations": n_permutations,
        "random_state": random_state,
        "min_per_class_plot": min_per_class_plot,
        "embedding_label": embedding_label,
    }

    # ---- cache paths ----
    if use_cache:
        if cache_dir is None:
            raise ValueError("cache_dir must be provided when use_cache=True")

        os.makedirs(cache_dir, exist_ok=True)

        results_path = os.path.join(cache_dir, f"results_df-{lambda_reg}.parquet")
        geometry_path = os.path.join(cache_dir, f"geometry_store-{lambda_reg}.pkl")
        null_path = os.path.join(cache_dir, f"null_store-{lambda_reg}.pkl")
        meta_path = os.path.join(cache_dir, f"meta-{lambda_reg}.json")

        cache_exists = all(
            os.path.exists(p)
            for p in [results_path, geometry_path, null_path, meta_path]
        )

        if cache_exists and not overwrite_cache:
            try:
                with open(meta_path, "r") as f:
                    cached_meta = json.load(f)
            except (OSError, ValueError):
                cached_meta = None

            if isinstance(cached_meta, dict):
                # Keys absent from an older meta file are not held against it.
                mismatched = [
                    key for key, value in run_params.items()
                    if key in cached_meta and cached_meta[key] != value
                ]
            else:
                mismatched = ["meta"]

            if not mismatched:
                results_df = pd.read_parquet(results_path)

                # SECURITY: pickle.load executes arbitrary code; only point
                # cache_dir at directories this code wrote itself.
                with open(geometry_path, "rb") as f:
                    geometry_store = pickle.load(f)

                with open(null_path, "rb") as f:
                    null_store = pickle.load(f)

                print("Cache correctly loaded.")

                return results_df, geometry_store, null_store

            print(f"Cache ignored (mismatch on: {', '.join(mismatched)}); recomputing.")

    # ---- computation ----
    rows = []
    geometry_store = {}
    null_store = {}

    grouped = df.groupby(["model_id", "prompt_id"])

    for (m, p), group in tqdm(grouped):
        # The group already holds exactly the rows of (m, p); handing it
        # over avoids re-scanning the full frame for every pair.
        res = analyse_prompt(
            group,
            model_id=m,
            prompt_id=p,
            lambda_reg=lambda_reg,
            n_permutations=n_permutations,
            random_state=random_state,
            embedding_label=embedding_label
        )

        row, gs = collect_prompt_result(
            res,
            min_per_class_plot=min_per_class_plot
        )

        geometry_store[(m, p)] = gs
        null_store[(m, p)] = res.get("W_null_samples")
        rows.append(row)

    results_df = pd.DataFrame(rows)

    # ---- save cache ----
    if use_cache:
        results_df.to_parquet(results_path, index=False)

        with open(geometry_path, "wb") as f:
            pickle.dump(geometry_store, f)

        with open(null_path, "wb") as f:
            pickle.dump(null_store, f)

        meta = dict(run_params, n_rows=len(results_df))

        with open(meta_path, "w") as f:
            json.dump(meta, f, indent=2)

        print("Cache correctly dumped.")

    return results_df, geometry_store, null_store

# ============================================================
# =================== PROJECTION CLASSES =====================
# ============================================================

class ProjectionBase:
    """
    Interface for maps taking X in R^d to Z in R^k.

    Subclasses override ``fit`` and ``transform``; ``fit_transform`` is
    supplied here in terms of those two.
    """

    def fit(self, X, y=None):
        """Learn the projection from X (and optionally y)."""
        raise NotImplementedError

    def transform(self, X):
        """Project X with the learned map."""
        raise NotImplementedError

    def fit_transform(self, X, y=None):
        """Fit on (X, y), then return the projection of the same X."""
        self.fit(X, y)
        projected = self.transform(X)
        return projected
    
class FisherProjection(ProjectionBase):
    """
    Supervised 1-D projection onto the regularised Fisher direction.

    The constructor arguments are forwarded unchanged to
    ``fisher_direction`` at fit time; the resulting direction is stored
    in ``self.v``.
    """

    def __init__(self, lambda_reg=1e-3, normalise=True, normalise_by_trace=True):
        self.lambda_reg = lambda_reg
        self.normalise = normalise
        self.normalise_by_trace = normalise_by_trace
        self.v = None  # set by fit()

    def fit(self, X, y):
        """Estimate the Fisher direction from the two classes in (X, y)."""
        genuine, hallucinated = split_by_label(X, y)
        options = {
            "lambda_reg": self.lambda_reg,
            "normalise": self.normalise,
            "normalise_by_trace": self.normalise_by_trace,
        }
        self.v = fisher_direction(genuine, hallucinated, **options)

    def transform(self, X):
        """Return the projections of X as a column of shape (n, 1)."""
        return (X @ self.v).reshape(-1, 1)



# class WhitenedPCAProjection(ProjectionBase):
#     def __init__(self, n_components=1):
#         self.n_components = n_components
#         self.mean_ = None
#         self.components_ = None
#         self.singular_values_ = None

#     def fit(self, X, y=None):
#         self.mean_ = X.mean(axis=0)
#         Xc = X - self.mean_

#         U, S, Vt = np.linalg.svd(Xc, full_matrices=False)

#         self.components_ = Vt[:self.n_components]
#         self.singular_values_ = S[:self.n_components]

#     def transform(self, X):
#         Xc = X - self.mean_
#         Z = Xc @ self.components_.T
#         Z /= (self.singular_values_ + 1e-12)
#         return Z

class WhitenedPCAProjection(ProjectionBase):
    """
    Unsupervised projection onto the leading right singular vectors of
    the centred data, with each coordinate divided by its singular value.

    The decomposition is obtained with ``randomized_svd``.
    """

    def __init__(self, n_components=1, n_iter=3, random_state=0):
        self.n_components = n_components
        self.n_iter = n_iter
        self.random_state = random_state
        # learned in fit()
        self.mean_ = None
        self.components_ = None
        self.singular_values_ = None

    def fit(self, X, y=None):
        """Centre X and keep its top singular values / right vectors (y is ignored)."""
        self.mean_ = X.mean(axis=0)
        svd_options = {
            "n_components": self.n_components,
            "n_iter": self.n_iter,
            "random_state": self.random_state,
        }
        _, singular_values, right_vectors = randomized_svd(X - self.mean_, **svd_options)

        self.components_ = right_vectors
        self.singular_values_ = singular_values

    def transform(self, X):
        """Centre with the training mean, project and rescale."""
        Z = (X - self.mean_) @ self.components_.T
        # the small constant keeps the division finite for a zero singular value
        Z /= (self.singular_values_ + 1e-12)
        return Z
    
class RandomProjection(ProjectionBase):
    """
    Data-independent Gaussian random projection.

    ``fit`` only reads the feature count of X; the projection matrix
    ``R`` has i.i.d. normal entries with standard deviation
    ``1 / sqrt(n_components)``.
    """

    def __init__(self, n_components=1, random_state=None):
        self.n_components = n_components
        self.random_state = random_state
        self.R = None  # set by fit()

    def fit(self, X, y=None):
        """Draw a fresh (d, n_components) matrix, seeded by ``random_state``."""
        generator = np.random.default_rng(self.random_state)
        n_features = X.shape[1]
        sigma = 1.0 / np.sqrt(self.n_components)

        self.R = generator.normal(
            loc=0.0,
            scale=sigma,
            size=(n_features, self.n_components),
        )

    def transform(self, X):
        """Return X multiplied by the stored random matrix."""
        return X @ self.R


class SupervisedUMAPProjection(ProjectionBase):
    """
    Thin adapter around ``umap.UMAP`` fitted with labels.

    ``umap`` is imported inside the constructor, so the package is only
    required when this class is instantiated.
    """

    def __init__(
        self,
        n_components=1,
        n_neighbors=15,
        min_dist=0.1,
        random_state=42,
    ):
        import umap

        umap_options = {
            "n_components": n_components,
            "n_neighbors": n_neighbors,
            "min_dist": min_dist,
            "metric": "euclidean",
            "random_state": random_state,
        }
        self.reducer = umap.UMAP(**umap_options)

    def fit(self, X, y):
        """Fit the wrapped reducer on X with y as target."""
        self.reducer.fit(X, y)

    def transform(self, X):
        """Embed X with the fitted reducer."""
        embedded = self.reducer.transform(X)
        return embedded
    

# ============================================================
# ========================= LP CLASS =========================
# ============================================================

class WassersteinLabelPropagator:
    """
    Classify a point by comparing distance distributions in a projected space.

    After fitting, each class keeps its projected training points and the
    distribution of pairwise distances among them. A new point is scored,
    per class, by the Wasserstein distance between (a) its distances to
    that class's training points and (b) that class's reference
    distribution. The class with the smaller score wins.
    """

    def __init__(self, projection: ProjectionBase):
        self.projection = projection
        # all set by fit()
        self.Z_G = None
        self.Z_H = None
        self.ref_G = None
        self.ref_H = None

    def fit(self, X_train, y_train):
        """
        Fit the projection and store per-class references.

        Each class needs at least two training samples: with fewer,
        ``pdist`` yields an empty reference distribution.
        """
        self.projection.fit(X_train, y_train)

        Z = self.projection.transform(X_train)
        Z_G, Z_H = split_by_label(Z, y_train)

        self.Z_G = Z_G
        self.Z_H = Z_H

        self.ref_G = pdist(Z_G)
        self.ref_H = pdist(Z_H)

    def score_point(self, x):
        """Return (W_G, W_H) for a single sample x of shape (d,)."""
        z = self.projection.transform(x[None, :])

        dG = cdist(z, self.Z_G).ravel()
        dH = cdist(z, self.Z_H).ravel()

        W_G = wasserstein_distance(dG, self.ref_G)
        W_H = wasserstein_distance(dH, self.ref_H)

        return W_G, W_H

    def predict_point(self, x):
        """Return 1 (hallucination) when W_H < W_G, else 0."""
        W_G, W_H = self.score_point(x)
        return int(W_H < W_G)

    def predict(self, X):
        """Predict a 0/1 label for every row of X."""
        return np.array([self.predict_point(x) for x in X])

    def margins(self, X):
        """Return signed margins (W_H - W_G) per sample"""
        # Score each sample once; scoring it separately for W_H and W_G
        # would repeat the projection and both Wasserstein computations.
        scores = [self.score_point(x) for x in X]
        return np.array([W_H - W_G for W_G, W_H in scores])

    def abs_margins(self, X):
        """Return absolute margins per sample"""
        return np.abs(self.margins(X))

    def margins_by_class(self, X, y):
        """
        Return a dictionary with signed and absolute margins per class
        Example output:
        {
            0: {"signed": array([...]), "abs": array([...])},
            1: {"signed": array([...]), "abs": array([...])},
        }
        """
        signed = self.margins(X)
        abs_val = np.abs(signed)
        classes = np.unique(y)
        return {
            c: {"signed": signed[y == c], "abs": abs_val[y == c]}
            for c in classes
        }
    

# ============================================================
# ==================== EVALUATOR  CLASS ======================
# ============================================================


class LabelPropagationEvaluator:
    """
    Score a fitted detector on a fixed test set.

    ``detector`` (and ``fisher_ref``, when given) must expose
    ``predict(X)`` and ``margins(X)``.
    """

    def __init__(self, detector, X_test, y_test, fisher_ref=None):
        self.detector = detector
        self.X_test = X_test
        self.y_test = y_test
        self.fisher_ref = fisher_ref

    def evaluate(self, ambiguity_eps=None, per_class=True):
        """
        Compute classification and margin metrics.

        Parameters
        ----------
        ambiguity_eps : float or None
            Samples with absolute margin <= this value count as ambiguous.
            None selects ``mean(margins) - 3 * std(margins)``.
        per_class : bool
            Also report margin statistics for each label found in y_test.

        Returns
        -------
        dict
        """
        X_test, y_test = self.X_test, self.y_test

        y_pred = self.detector.predict(X_test)
        margins = self.detector.margins(X_test)
        abs_margins = np.abs(margins)

        # NOTE(review): this default comes from the *signed* margins but
        # is compared against absolute margins below; whenever it is
        # negative no sample is counted as ambiguous -- confirm intent.
        if ambiguity_eps is None:
            ambiguity_eps = margins.mean() - 3 * margins.std()

        tn, fp, fn, tp = confusion_matrix(y_test, y_pred, labels=[0, 1]).ravel()

        metrics = {
            "accuracy": accuracy_score(y_test, y_pred),
            "f1": f1_score(y_test, y_pred, zero_division=0),
            "precision": precision_score(y_test, y_pred, zero_division=0),
            "recall": recall_score(y_test, y_pred, zero_division=0),
        }
        # confusion matrix entries
        metrics.update({"tn": tn, "fp": fp, "fn": fn, "tp": tp})
        # signed and absolute margins
        metrics.update({
            "mean_margin": margins.mean(),
            "std_margin": margins.std(),
            "mean_abs_margin": abs_margins.mean(),
            "std_abs_margin": abs_margins.std(),
            "ambiguous_frac": np.mean(abs_margins <= ambiguity_eps),
        })

        # optional per-class margin statistics
        if per_class:
            for c in np.unique(y_test):
                in_class = y_test == c
                per_kind = (
                    ("margin", margins[in_class]),
                    ("abs_margin", abs_margins[in_class]),
                )
                for kind, values in per_kind:
                    metrics[f"mean_{kind}_class_{c}"] = values.mean()
                    metrics[f"std_{kind}_class_{c}"] = values.std()

        # optional agreement with the reference (Fisher) detector
        reference = self.fisher_ref
        if reference is not None:
            y_fisher = reference.predict(X_test)
            metrics["agreement_fisher"] = np.mean(y_pred == y_fisher)

            # "confident" = reference |margin| strictly above its median
            ref_abs = np.abs(reference.margins(X_test))
            confident = ref_abs > np.percentile(ref_abs, 50)
            metrics["agreement_fisher_confident"] = np.mean(y_pred[confident] == y_fisher[confident])

        return metrics

# ============================================================
# ====================== EVAL FUNCTIONS ======================
# ============================================================

def run_label_propagation_experiment(
    df,
    model_id,
    prompt_id,
    projector_class,
    projector_kwargs=None,
    train_fractions=None,
    n_iter=10,
    test_fraction=0.2,
    n_splits=5,
    ref_lambda_reg=None,
    logskip=False,
    random_state=42,
    embedding_label="response_embeddings"
):
    """
    Run label propagation experiment with fixed test sets and multiple
    random subsamples of the training set per fraction.

    Parameters
    ----------
    df : pd.DataFrame
        Full dataset
    model_id : int
        Model index
    prompt_id : int
        Prompt index
    projector_class : class
        Projection class wrapped by WassersteinLabelPropagator; instances
        must implement .fit(X, y) and .transform(X)
    projector_kwargs : dict
        Optional keyword arguments passed to the projector constructor
    train_fractions : list of float | None
        Fractions of the training set to subsample. None means the full
        training set only: [1.0] is used and n_iter is forced to 1
    n_iter : int
        Number of random subsamples per fraction
    test_fraction : float
        Fraction of data for test
    n_splits : int
        Number of fixed test splits
    ref_lambda_reg : float | None
        Value of the lambda_reg for the reference fisher projector. If None, no reference is evaluated
    logskip : bool
        Whether to display skipped (model, prompt) couples; default to False
    random_state : int
        RNG seed for the train/test splits (the per-fraction subsampling
        is not seeded here)
    embedding_label : str
        Name of the embedding column

    Returns
    -------
    pd.DataFrame or None
        Metrics for all splits, fractions, and iterations. None when the
        pair is skipped because a class has fewer than 5 samples.
    """

    if projector_kwargs is None:
        projector_kwargs = {}

    if train_fractions is None:
        train_fractions = [1.0]
        n_iter = 1

    # ---- extract full dataset ----
    X, y = extract_prompt_data(df, model_id, prompt_id, embedding_label=embedding_label)
    n_H = int(np.sum(y))
    n_G = len(y) - n_H
    if n_H < 5 or n_G < 5:
        # Report the skip only when asked to: the condition used to be
        # negated, printing by default and staying silent on logskip=True.
        if logskip:
            print(f"Skipping model {model_id}, prompt {prompt_id} due to unbalancedeness (n_G = {n_G}, n_H = {n_H})")
        return None

    # ---- fixed stratified splits ----
    trn_sets, tst_sets = generate_fixed_test_sets(
        X, y, n_splits=n_splits, test_fraction=test_fraction, random_state=random_state
    )

    results = []

    # ---- loop over fixed test sets ----
    for test_id, ((X_train_full, y_train_full), (X_test, y_test)) in enumerate(zip(trn_sets, tst_sets)):

        for tf in train_fractions:
            for iter_id in range(n_iter):
                # per-class random subsample of the training set
                X_sub, y_sub = subsample_training_set(X_train_full, y_train_full, tf)

                # both classes need >= 2 points to form reference distances
                X_G_sub, X_H_sub = split_by_label(X_sub, y_sub)
                if len(X_G_sub) < 2 or len(X_H_sub) < 2:
                    continue

                # Fisher reference detector
                fisher_detector = None
                if ref_lambda_reg is not None:
                    fisher_proj = FisherProjection(lambda_reg=ref_lambda_reg)
                    fisher_detector = WassersteinLabelPropagator(fisher_proj)
                    fisher_detector.fit(X_sub, y_sub)

                # ---- fit detector ----
                projector = projector_class(**projector_kwargs)
                detector = WassersteinLabelPropagator(projector)
                detector.fit(X_sub, y_sub)

                # ---- evaluate ----
                evaluator = LabelPropagationEvaluator(detector, X_test, y_test, fisher_ref=fisher_detector)
                metrics = evaluator.evaluate()

                # add metadata
                metrics.update({
                    "train_fraction": tf,
                    "iter_id": iter_id,
                    "test_set_id": test_id,
                    "model_id": model_id,
                    "prompt_id": prompt_id,
                    "n_train": len(X_sub)
                })

                results.append(metrics)

    return pd.DataFrame(results)

def run_full_label_propagation_study(
    df,
    model_ids,
    prompt_ids_by_model,
    projector_class,
    projector_kwargs=None,
    train_fractions=None,
    ref_lambda_reg=None,
    n_iter=10,
    test_fraction=0.2,
    n_splits=5,
    random_state=42,
    use_cache=False,
    cache_dir="cache/label_propagation",
    overwrite_cache=False,
    logskip=False,
    embedding_label="response_embeddings"
):
    """
    Run the label propagation experiment for every listed (model, prompt)
    pair and concatenate the per-pair results.

    With ``use_cache`` each pair is stored in its own parquet file inside
    ``cache_dir``, named after the model id, prompt id and projector class
    name; an existing file is reused unless ``overwrite_cache`` is set.
    A ``meta.json`` describing the call is written at the end.

    Returns
    -------
    pd.DataFrame
        Empty when no pair produced results.
    """
    projector_kwargs = {} if projector_kwargs is None else projector_kwargs

    if train_fractions is None:
        train_fractions, n_iter = [1.0], 1

    if use_cache:
        os.makedirs(cache_dir, exist_ok=True)

    collected = []

    for mid in tqdm(model_ids, desc="Model"):
        for pid in tqdm(prompt_ids_by_model[mid], desc="Prompt"):

            cache_path = None
            if use_cache:
                cache_path = os.path.join(
                    cache_dir,
                    f"model={mid}__prompt={pid}__detector={projector_class.__name__}.parquet",
                )
                reuse = os.path.exists(cache_path) and not overwrite_cache
                if reuse:
                    collected.append(pd.read_parquet(cache_path))
                    continue

            # ---- compute ----
            res_df = run_label_propagation_experiment(
                df=df,
                model_id=mid,
                prompt_id=pid,
                projector_class=projector_class,
                projector_kwargs=projector_kwargs,
                train_fractions=train_fractions,
                n_iter=n_iter,
                test_fraction=test_fraction,
                n_splits=n_splits,
                ref_lambda_reg=ref_lambda_reg,
                logskip=logskip,
                random_state=random_state,
                embedding_label=embedding_label
            )

            # skipped pairs (None) and empty frames are neither kept nor cached
            if res_df is not None and len(res_df) > 0:
                collected.append(res_df)

                # ---- save subcache ----
                if use_cache:
                    res_df.to_parquet(cache_path, index=False)

    if not collected:
        return pd.DataFrame()

    results_lp = pd.concat(collected, ignore_index=True)

    # ---- global metadata (optional) ----
    if use_cache:
        meta = {
            "model_ids": list(model_ids),
            "train_fractions": train_fractions,
            "n_iter": n_iter,
            "test_fraction": test_fraction,
            "n_splits": n_splits,
            "random_state": random_state,
            "n_cached_pairs": len(collected),
            "projector_class": projector_class.__name__,
            "projector_kwargs": projector_kwargs,
            "ref_lambda_reg": ref_lambda_reg,
            "embedding_label": embedding_label
        }
        meta_file = os.path.join(cache_dir, "meta.json")
        with open(meta_file, "w") as f:
            json.dump(meta, f, indent=2)

    return results_lp


# ============================================================
# ============== LAMBDA SENSITIVITY EVALUATION  ==============
# ============================================================

def run_lambda_sensitivity_experiment(
    df,
    model_id,
    prompt_id,
    lambda_values,
    test_fraction=0.2,
    n_splits=5,
    logskip=False,
    random_state=42,
    embedding_label="response_embeddings",
):
    """
    Sensitivity analysis of Fisher regularisation parameter lambda.
    Uses full training set (no subsampling).

    Parameters
    ----------
    df : pd.DataFrame
        Full dataset
    model_id, prompt_id
        Pair to analyse
    lambda_values : iterable of float
        Regularisation strengths to evaluate on every split
    test_fraction : float
        Fraction of data for test
    n_splits : int
        Number of fixed stratified splits
    logskip : bool
        Whether to print a message when the pair is skipped
    random_state : int
        Seed for the splits
    embedding_label : str
        Name of the embedding column

    Returns
    -------
    pd.DataFrame or None
        One row per (split, lambda). None when a class has fewer than 5
        samples or when no result was produced.
    """

    # ---- extract full dataset ----
    X, y = extract_prompt_data(df, model_id, prompt_id, embedding_label=embedding_label)
    n_H = int(np.sum(y))
    n_G = len(y) - n_H
    if n_H < 5 or n_G < 5:
        # Report the skip only when asked to: the condition used to be
        # negated, printing by default and staying silent on logskip=True.
        if logskip:
            print(
                f"Skipping model {model_id}, prompt {prompt_id} "
                f"(n_G={n_G}, n_H={n_H})"
            )
        return None

    # ---- fixed stratified splits ----
    trn_sets, tst_sets = generate_fixed_test_sets(
        X,
        y,
        n_splits=n_splits,
        test_fraction=test_fraction,
        random_state=random_state,
    )

    results = []

    for split_id, ((X_train, y_train), (X_test, y_test)) in enumerate(zip(trn_sets, tst_sets)):

        for lambda_reg in lambda_values:

            fisher_proj = FisherProjection(lambda_reg=lambda_reg)
            fisher_detector = WassersteinLabelPropagator(fisher_proj)
            fisher_detector.fit(X_train, y_train)

            # NOTE(review): the detector is also passed as its own
            # reference, so the "agreement_fisher*" metrics compare it
            # with itself. Kept so the output columns stay unchanged.
            evaluator = LabelPropagationEvaluator(fisher_detector, X_test, y_test, fisher_ref=fisher_detector)
            metrics = evaluator.evaluate()

            metrics.update({
                "lambda_reg": lambda_reg,
                "split_id": split_id,
                "model_id": model_id,
                "prompt_id": prompt_id,
                "n_train": len(X_train),
                "n_test": len(X_test),
            })

            results.append(metrics)

    if not results:
        return None

    return pd.DataFrame(results)

def run_full_lambda_sensitivity_study(
    df,
    model_ids,
    prompt_ids_by_model,
    lambda_values,
    test_fraction=0.2,
    n_splits=5,
    random_state=42,
    use_cache=False,
    cache_dir="cache/lambda_sensitivity",
    overwrite_cache=False,
    logskip=False,
    embedding_label="response_embeddings"
):
    """
    Run the lambda sensitivity experiment for every listed (model, prompt)
    pair and concatenate the per-pair results.

    With ``use_cache`` each pair is stored in its own parquet file inside
    ``cache_dir``, named after the model id and prompt id; an existing
    file is reused unless ``overwrite_cache`` is set. A ``meta.json``
    describing the call is written at the end.

    Returns
    -------
    pd.DataFrame
        Empty when no pair produced results.
    """
    if use_cache:
        os.makedirs(cache_dir, exist_ok=True)

    collected = []

    for mid in tqdm(model_ids, desc="Model"):
        for pid in tqdm(prompt_ids_by_model[mid], desc="Prompt"):

            cache_path = None
            if use_cache:
                cache_path = os.path.join(cache_dir, f"model={mid}__prompt={pid}.parquet")
                reuse = os.path.exists(cache_path) and not overwrite_cache
                if reuse:
                    collected.append(pd.read_parquet(cache_path))
                    continue

            res_df = run_lambda_sensitivity_experiment(
                df=df,
                model_id=mid,
                prompt_id=pid,
                lambda_values=lambda_values,
                test_fraction=test_fraction,
                n_splits=n_splits,
                random_state=random_state,
                logskip=logskip,
                embedding_label=embedding_label
            )

            # skipped pairs (None) and empty frames are neither kept nor cached
            if res_df is not None and len(res_df) > 0:
                collected.append(res_df)

                if use_cache:
                    res_df.to_parquet(cache_path, index=False)

    if not collected:
        return pd.DataFrame()

    results = pd.concat(collected, ignore_index=True)

    if use_cache:
        meta = {
            "model_ids": list(model_ids),
            "lambda_values": list(lambda_values),
            "test_fraction": test_fraction,
            "n_splits": n_splits,
            "random_state": random_state,
            "n_cached_pairs": len(collected),
            "embedding_label": embedding_label
        }
        meta_file = os.path.join(cache_dir, "meta.json")
        with open(meta_file, "w") as f:
            json.dump(meta, f, indent=2)

    return results

# ============================================================
# ================== PLOTTING  CAPABILITIES ==================
# ============================================================


def build_model_size_order(model_names):
    """
    Order models by the parameter count encoded in their names.

    The size is the first number directly followed by 'B'/'b' (billions)
    that is not the start of a longer word, e.g. ``"llama-13B"`` -> 13,
    ``"Llama-2-70B-chat"`` -> 70, ``"qwen-1.5B"`` -> 1.5. Names without
    such a token fall back to the legacy rule: second dash-separated
    field with its last character dropped, parsed as an integer.

    Parameters
    ----------
    model_names : dict
        model_id -> model name

    Returns
    -------
    model_order : list
        model_ids ordered by increasing parameter size
    model_rank : dict
        model_id -> rank

    Raises
    ------
    ValueError
        If no size can be extracted from a name.
    """
    import re

    size_pattern = re.compile(r"(\d+(?:\.\d+)?)\s*[bB](?![A-Za-z])")

    def _parse_size(name):
        # number before 'B', wherever it sits in the name
        match = size_pattern.search(name)
        if match is not None:
            return float(match.group(1))
        try:
            return int(name.split('-')[1][:-1])
        except (IndexError, ValueError) as err:
            raise ValueError(
                f"Cannot extract parameter count from model name {name!r}"
            ) from err

    sizes = {mid: _parse_size(name) for mid, name in model_names.items()}

    model_order = sorted(sizes, key=sizes.get)
    model_rank = {m: i for i, m in enumerate(model_order)}

    return model_order, model_rank

def select_prompt_by_fraction(
    df_model,
    mode="balanced",
    require_valid_plot=True,
    require_valid_geom=False
):
    """
    Pick one row (prompt) of a single model according to ``frac_G``.

    Parameters
    ----------
    df_model : pd.DataFrame
        Rows of one model; needs ``frac_G`` plus the boolean columns
        ``valid_plot`` / ``valid_geom`` when the matching filter is on.
    mode : {"balanced", "most_genuine", "most_hallucinated"}
        ``frac_G`` closest to 0.5, largest, or smallest respectively.
    require_valid_plot, require_valid_geom : bool
        Restrict the candidates to rows where the flag is True.

    Returns
    -------
    pd.Series or None
        The chosen row, or None when no candidate survives the filters.

    Raises
    ------
    ValueError
        For an unknown ``mode`` (only checked when candidates exist).
    """
    candidates = df_model.copy()

    filters = (
        (require_valid_plot, "valid_plot"),
        (require_valid_geom, "valid_geom"),
    )
    for enabled, flag_column in filters:
        if enabled:
            candidates = candidates[candidates[flag_column]]

    if candidates.empty:
        return None

    pickers = {
        "balanced": lambda frac: (frac - 0.5).abs().idxmin(),
        "most_genuine": lambda frac: frac.idxmax(),
        "most_hallucinated": lambda frac: frac.idxmin(),
    }
    if mode not in pickers:
        raise ValueError(f"Unknown mode: {mode}")

    chosen_index = pickers[mode](candidates["frac_G"])
    return candidates.loc[chosen_index]

def select_representative_prompts(results_df, model_id, require_valid_plot=True, require_valid_geom=True):
    """
    For one model, return the prompt selected under each of the three
    modes of ``select_prompt_by_fraction``.

    Returns
    -------
    dict
        Keys "balanced", "most_genuine", "most_hallucinated"; each value
        is the selected row or None.
    """
    df_model = results_df[results_df["model_id"] == model_id]
    modes = ("balanced", "most_genuine", "most_hallucinated")

    return {
        mode: select_prompt_by_fraction(
            df_model,
            mode,
            require_valid_plot=require_valid_plot,
            require_valid_geom=require_valid_geom,
        )
        for mode in modes
    }


def reorder_selected_keys_by_model_size(selected_keys_dict, model_rank):
    """
    Sort each panel's (model_id, prompt_id) tuples by model size rank.

    Parameters
    ----------
    selected_keys_dict : dict
        panel name -> iterable of (model_id, prompt_id) tuples
    model_rank : dict
        model_id -> rank (smaller rank first)

    Returns
    -------
    dict
        Same panels, each mapped to a new sorted list.
    """
    def _rank_of(key):
        return model_rank[key[0]]

    return {
        panel: sorted(keys, key=_rank_of)
        for panel, keys in selected_keys_dict.items()
    }

def aggregate_metric_over_prompts(
    df,
    metric="f1",
    score_metric='accuracy',
    agg_prompts=True,
    agg_train_frac=False,
    agg_models=False
):
    """
    Aggregate a metric over test splits / subsamples and, optionally,
    over prompts, train fractions and models.

    With the defaults this returns one row per (model_id, train_fraction).
    When all three ``agg_*`` flags are True a single row summarising the
    whole frame is returned, with no grouping columns.

    Parameters
    ----------
    df : pd.DataFrame
        Output of label propagation experiments
    metric : str
        Metric to aggregate (e.g. 'f1')
    score_metric : str
        Second metric to aggregate (e.g. 'accuracy')
    agg_prompts : bool
        If True, aggregate over prompts (drop prompt_id from the grouping)
    agg_train_frac : bool
        If True, aggregate over train fractions
    agg_models : bool
        If True, aggregate over models

    Returns
    -------
    agg_df : pd.DataFrame
        Columns:
        - the grouping columns still in use, in the order model_id,
          prompt_id, train_fraction
        - metric_mean
        - metric_std
        - score_mean
        - score_std
        - mean_n_train
        - std_n_train
        - n_runs
    """
    grouping_flags = (
        ("model_id", agg_models),
        ("prompt_id", agg_prompts),
        ("train_fraction", agg_train_frac),
    )
    group_cols = [col for col, collapsed in grouping_flags if not collapsed]

    named_aggs = {
        "metric_mean": (metric, "mean"),
        "metric_std": (metric, "std"),
        "score_mean": (score_metric, "mean"),
        "score_std": (score_metric, "std"),
        "mean_n_train": ("n_train", "mean"),
        "std_n_train": ("n_train", "std"),
        "n_runs": (metric, "count"),
    }

    if group_cols:
        return df.groupby(group_cols).agg(**named_aggs).reset_index()

    # Everything is collapsed: groupby rejects an empty key list, so put
    # all rows in one constant group and drop the dummy key afterwards.
    single_group = np.zeros(len(df), dtype=int)
    return df.groupby(single_group).agg(**named_aggs).reset_index(drop=True)