{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad59dd92",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "import pickle\n",
    "import json\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.colors as mcolors\n",
    "import matplotlib.cm as cm\n",
    "from matplotlib.patches import Rectangle, Patch\n",
    "from matplotlib.transforms import offset_copy\n",
    "\n",
    "from scipy.spatial.distance import pdist, cdist\n",
    "from scipy.stats import wasserstein_distance\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from sklearn.metrics import (\n",
    "    accuracy_score,\n",
    "    f1_score,\n",
    "    precision_score,\n",
    "    recall_score,\n",
    "    confusion_matrix,\n",
    ")\n",
    "\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "\n",
    "import networkx as nx\n",
    "\n",
    "from sklearn.model_selection import StratifiedShuffleSplit\n",
    "\n",
    "import HALL_lib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bf87ecc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Route all figure text through LaTeX. With usetex on, TeX special characters\n",
    "# in labels (e.g. '_', '%', '&') must be escaped or placed in math mode.\n",
    "plt.rcParams['text.usetex'] = True\n",
    "# Preamble used for the LaTeX text rendering (loads mathtools).\n",
    "plt.rcParams['text.latex.preamble'] = r'''\n",
    "\\usepackage{mathtools}\n",
    "% more packages here\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b600274",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============== NEWDATASET ==============\n",
    "\n",
    "# cache_dir = \"cache-newdataset\"\n",
    "# maxResponsesPerPrompt = 500\n",
    "\n",
    "# model_names = {\n",
    "#     0 : \"Llama-8B\",\n",
    "#     1 : \"Phi-14B\",\n",
    "#     2 : \"Mistral-7B\",\n",
    "#     3 : \"Solar-11B\",\n",
    "#     4 : \"Gemma-9B\",\n",
    "#     5 : \"Qwen-15B\",\n",
    "#     6 : \"DeepSeek-7B\",\n",
    "#     7 : \"Yi-9B\"\n",
    "# }\n",
    "\n",
    "# ============== OLDDATASET ==============\n",
    "\n",
    "# cache_dir = \"cache-olddataset-20-22\"\n",
    "# maxResponsesPerPrompt=100\n",
    "\n",
    "# model_names = {\n",
    "#     0: 'Mistral-7B',\n",
    "#     1: 'Gemma-9B',\n",
    "#     2: 'Solar-11B',\n",
    "#     3: 'Phi-14B',\n",
    "#     4: 'Qwen-32B',\n",
    "#     5: 'Gemma-27B',\n",
    "#     6: 'Qwen-15B'\n",
    "# }\n",
    "\n",
    "# ============== ICML DATASET (active configuration) ==============\n",
    "\n",
    "# Directory passed to run_structural_analysis as cache_dir\n",
    "cache_dir = \"cache-icml\"\n",
    "# Expected number of responses per (model, prompt); drawn as the reference\n",
    "# line in plot_response_count_histogram\n",
    "maxResponsesPerPrompt=150\n",
    "\n",
    "# model_id -> display name; the size token in each name is parsed by\n",
    "# build_model_size_order to order the models\n",
    "model_names = {\n",
    "    0: 'Mistral-7B',\n",
    "    1: 'Gemma-9B',\n",
    "    2: 'Solar-11B',\n",
    "    3: 'Phi-14B',\n",
    "    4: 'Qwen-15B',\n",
    "    5: 'Gemma-27B',\n",
    "    6: 'Qwen-32B',\n",
    "    7: \"DeepSeek-7B\",\n",
    "    8: \"Llama-8B\",\n",
    "    9: \"Yi-9B\"\n",
    "}\n",
    "\n",
    "# Regularisation strength passed as lambda_reg to run_structural_analysis\n",
    "best_reg_lambda = 1.9\n",
    "\n",
    "# NOTE(review): not referenced in the visible part of the notebook - confirm where it is used\n",
    "do_descriptors = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a810ae8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset location, relative to the notebook's working directory\n",
    "dataset_fn = \"../llm-hallucinations/dataset_icml.parquet\"\n",
    "\n",
    "# Load through the project helper. Later cells read the columns model_id, prompt_id,\n",
    "# response_embeddings and hallucination (plus false_premise when present).\n",
    "# NOTE(review): the effect of unifyYears=True is defined in HALL_lib and not visible here.\n",
    "df = HALL_lib.loadParquet(dataset_fn, unifyYears=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c884fc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0475af0",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a8fced1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7f99912",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_model_size_order(model_names):\n",
    "    \"\"\"\n",
    "    Returns:\n",
    "      model_order: list of model_ids ordered by increasing parameter size\n",
    "      model_rank: dict model_id -> rank\n",
    "    \"\"\"\n",
    "    sizes = {}\n",
    "    for mid, name in model_names.items():\n",
    "        # extract number before 'B'\n",
    "        size = int(name.split('-')[1][:-1])\n",
    "        sizes[mid] = size\n",
    "\n",
    "    model_order = sorted(sizes, key=lambda m: sizes[m])\n",
    "    model_rank = {m: i for i, m in enumerate(model_order)}\n",
    "\n",
    "    return model_order, model_rank\n",
    "\n",
    "# Size-based ordering of the configured models\n",
    "model_order, model_rank = build_model_size_order(model_names)\n",
    "\n",
    "# prompt_id -> false_premise flag, or None when the dataset has no such column\n",
    "if \"false_premise\" in df.columns:\n",
    "    false_premise = dict(iter(df[['prompt_id', \"false_premise\"]].value_counts().keys()))\n",
    "else:\n",
    "    false_premise = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac6ecf0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_prompt_data(df, model_id, prompt_id):\n",
    "    \"\"\"\n",
    "    Extract embeddings X and labels y for a given (model_id, prompt_id).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    X : np.ndarray, shape (n_samples, d)\n",
    "    y : np.ndarray, shape (n_samples,)\n",
    "        0 = genuine, 1 = hallucination\n",
    "    \"\"\"\n",
    "    sub = df[\n",
    "        (df[\"model_id\"] == model_id) &\n",
    "        (df[\"prompt_id\"] == prompt_id)\n",
    "    ]\n",
    "\n",
    "    X = np.stack(sub[\"response_embeddings\"].values)\n",
    "    y = sub[\"hallucination\"].values.astype(bool)\n",
    "\n",
    "    return X, y\n",
    "\n",
    "\n",
    "def split_by_label(X, y):\n",
    "    \"\"\"\n",
    "    Split embeddings by class label.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    X_G : np.ndarray\n",
    "        Genuine responses\n",
    "    X_H : np.ndarray\n",
    "        Hallucinated responses\n",
    "    \"\"\"\n",
    "    X_G = X[~y]\n",
    "    X_H = X[y]\n",
    "\n",
    "    return X_G, X_H\n",
    "\n",
    "def compute_distance_distributions(X_G, X_H):\n",
    "    \"\"\"\n",
    "    Compute intra- and inter-class distance distributions.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    D_GG : np.ndarray\n",
    "    D_HH : np.ndarray\n",
    "    D_GH : np.ndarray\n",
    "    \"\"\"\n",
    "    D_GG = pdist(X_G) if len(X_G) > 1 else np.array([])\n",
    "    D_HH = pdist(X_H) if len(X_H) > 1 else np.array([])\n",
    "    D_GH = cdist(X_G, X_H).ravel() if len(X_G) > 0 and len(X_H) > 0 else np.array([])\n",
    "\n",
    "    return D_GG, D_HH, D_GH\n",
    "\n",
    "\n",
    "def fisher_direction(\n",
    "    X_G,\n",
    "    X_H,\n",
    "    lambda_reg=1e-3,\n",
    "    normalise=True,\n",
    "    normalise_by_trace=True,\n",
    "):\n",
    "    \"\"\"\n",
    "    Compute regularised Fisher discriminant direction with\n",
    "    trace-adaptive regularisation.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X_G, X_H : np.ndarray\n",
    "        Shape (n_G, d), (n_H, d)\n",
    "    lambda_reg : float\n",
    "        Dimensionless regularisation strength\n",
    "    normalise : bool\n",
    "        Whether to L2-normalise the output direction\n",
    "    normalise_by_trace : bool\n",
    "        Whether to normalise lambda parameter by the average trace of S_W\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    v : np.ndarray, shape (d,)\n",
    "    \"\"\"\n",
    "    mu_G = X_G.mean(axis=0)\n",
    "    mu_H = X_H.mean(axis=0)\n",
    "\n",
    "    # within-class scatter (biased = MLE)\n",
    "    S_G = np.cov(X_G, rowvar=False, bias=True)\n",
    "    S_H = np.cov(X_H, rowvar=False, bias=True)\n",
    "\n",
    "    S_W = S_G + S_H\n",
    "    d = S_W.shape[0]\n",
    "\n",
    "    # trace-normalised regularisation\n",
    "    if normalise_by_trace:\n",
    "        trace_SW = np.trace(S_W)\n",
    "        lambda_eff = lambda_reg * trace_SW / d\n",
    "    else:\n",
    "        lambda_eff = lambda_reg\n",
    "\n",
    "    S_W_reg = S_W + lambda_eff * np.eye(d)\n",
    "\n",
    "    v = np.linalg.solve(S_W_reg, mu_H - mu_G)\n",
    "\n",
    "    if normalise:\n",
    "        norm = np.linalg.norm(v)\n",
    "        if norm > 0:\n",
    "            v /= norm\n",
    "\n",
    "    return v\n",
    "\n",
    "def wasserstein_GG_HH(D_GG, D_HH):\n",
    "    \"\"\"\n",
    "    Wasserstein distance between intra-class distance distributions.\n",
    "    \"\"\"\n",
    "    if len(D_GG) == 0 or len(D_HH) == 0:\n",
    "        return np.nan\n",
    "    return wasserstein_distance(D_GG, D_HH)\n",
    "\n",
    "def wasserstein_null_model(X, y, n_permutations=100, random_state=None):\n",
    "    \"\"\"\n",
    "    Sample W(GG, HH) under random permutations of the labels.\n",
    "\n",
    "    Each iteration shuffles y, re-splits X with split_by_label and recomputes\n",
    "    the Wasserstein distance between the intra-class distance distributions.\n",
    "    Iterations that leave fewer than two samples in either class are skipped.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        One value per retained permutation (possibly empty).\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(random_state)\n",
    "    samples = []\n",
    "\n",
    "    for _ in range(n_permutations):\n",
    "        shuffled = rng.permutation(y)\n",
    "        X_Gp, X_Hp = split_by_label(X, shuffled)\n",
    "\n",
    "        if min(len(X_Gp), len(X_Hp)) >= 2:\n",
    "            D_GG_p, D_HH_p = compute_distance_distributions(X_Gp, X_Hp)[:2]\n",
    "            samples.append(wasserstein_GG_HH(D_GG_p, D_HH_p))\n",
    "\n",
    "    return np.array(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df56008b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def analyse_prompt(\n",
    "    df,\n",
    "    model_id,\n",
    "    prompt_id,\n",
    "    lambda_reg=1e-3,\n",
    "    n_permutations=100,\n",
    "    random_state=None\n",
    "):\n",
    "    \"\"\"\n",
    "    Geometric analysis of one (model_id, prompt_id) pair.\n",
    "\n",
    "    Computes the intra-/inter-class distance distributions and their\n",
    "    Wasserstein distance in the original embedding space, repeats this on the\n",
    "    1-D Fisher projection when both classes have at least two samples, and\n",
    "    (unless n_permutations is None) draws a permutation null for the\n",
    "    original-space Wasserstein distance.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        Always present: model_id, prompt_id, n_G, n_H, D_GG, D_HH, D_GH,\n",
    "        W_GG_HH, W_null_samples, W_null_mean, W_null_std (the W_null_* values\n",
    "        are None when no null statistics are available).\n",
    "        Only with >= 2 samples per class: v_fisher, D_GG_z, D_HH_z, D_GH_z,\n",
    "        W_GG_HH_z.\n",
    "    \"\"\"\n",
    "    X, y = extract_prompt_data(df, model_id, prompt_id)\n",
    "    X_G, X_H = split_by_label(X, y)\n",
    "    n_G, n_H = len(X_G), len(X_H)\n",
    "    enough_per_class = n_G >= 2 and n_H >= 2\n",
    "\n",
    "    # ---- original space distances ----\n",
    "    D_GG, D_HH, D_GH = compute_distance_distributions(X_G, X_H)\n",
    "\n",
    "    res = {\n",
    "        \"model_id\": model_id,\n",
    "        \"prompt_id\": prompt_id,\n",
    "        \"n_G\": n_G,\n",
    "        \"n_H\": n_H,\n",
    "        \"D_GG\": D_GG,\n",
    "        \"D_HH\": D_HH,\n",
    "        \"D_GH\": D_GH,\n",
    "        \"W_GG_HH\": wasserstein_GG_HH(D_GG, D_HH),\n",
    "    }\n",
    "\n",
    "    # ---- Fisher space ----\n",
    "    if enough_per_class:\n",
    "        v = fisher_direction(X_G, X_H, lambda_reg=lambda_reg)\n",
    "        Z = X @ v\n",
    "\n",
    "        # projections as column vectors so the distance helpers accept them\n",
    "        D_GG_z, D_HH_z, D_GH_z = compute_distance_distributions(\n",
    "            Z[y == 0][:, None],\n",
    "            Z[y == 1][:, None],\n",
    "        )\n",
    "\n",
    "        res.update({\n",
    "            \"v_fisher\": v,\n",
    "            \"D_GG_z\": D_GG_z,\n",
    "            \"D_HH_z\": D_HH_z,\n",
    "            \"D_GH_z\": D_GH_z,\n",
    "            \"W_GG_HH_z\": wasserstein_GG_HH(D_GG_z, D_HH_z),\n",
    "        })\n",
    "\n",
    "    # ---- null model ----\n",
    "    W_null = None\n",
    "    if n_permutations is not None and enough_per_class:\n",
    "        W_null = wasserstein_null_model(\n",
    "            X, y,\n",
    "            n_permutations=n_permutations,\n",
    "            random_state=random_state\n",
    "        )\n",
    "\n",
    "    has_samples = W_null is not None and len(W_null) > 0\n",
    "    res[\"W_null_samples\"] = W_null\n",
    "    res[\"W_null_mean\"] = W_null.mean() if has_samples else None\n",
    "    res[\"W_null_std\"] = W_null.std() if has_samples else None\n",
    "\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e157d1de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def collect_prompt_result(res, min_per_class_plot=5):\n",
    "    \"\"\"\n",
    "    Split analyse_prompt output into:\n",
    "    - scalar metadata row\n",
    "    - geometry payload\n",
    "    \"\"\"\n",
    "    m = res[\"model_id\"]\n",
    "    p = res[\"prompt_id\"]\n",
    "\n",
    "    n_G = res[\"n_G\"]\n",
    "    n_H = res[\"n_H\"]\n",
    "    n_total = n_G + n_H\n",
    "\n",
    "    frac_G = n_G / n_total if n_total > 0 else np.nan\n",
    "    frac_H = n_H / n_total if n_total > 0 else np.nan\n",
    "\n",
    "    valid_geom = (n_G >= 2) and (n_H >= 2)\n",
    "    valid_plot = (n_G >= min_per_class_plot) and (n_H >= min_per_class_plot)\n",
    "\n",
    "    # ---- scalar row ----\n",
    "    row = {\n",
    "        \"model_id\": m,\n",
    "        \"prompt_id\": p,\n",
    "        \"n_total\": n_total,\n",
    "        \"n_G\": n_G,\n",
    "        \"n_H\": n_H,\n",
    "        \"frac_G\": frac_G,\n",
    "        \"frac_H\": frac_H,\n",
    "        \"W_GG_HH\": res.get(\"W_GG_HH\", np.nan),\n",
    "        \"W_GG_HH_z\": res.get(\"W_GG_HH_z\", np.nan),\n",
    "        \"valid_geom\": valid_geom,\n",
    "        \"valid_plot\": valid_plot,\n",
    "    }\n",
    "\n",
    "    # optional null-model statistics\n",
    "    row[\"W_null_mean\"] = res[\"W_null_mean\"]\n",
    "    row[\"W_null_std\"] = res[\"W_null_std\"]\n",
    "    if row[\"W_null_mean\"] is not None:\n",
    "        row[\"delta_W\"] = row[\"W_GG_HH\"] - row[\"W_null_mean\"]\n",
    "    else:\n",
    "        row[\"delta_W\"] = None\n",
    "\n",
    "    if \"W_GG_HH_z\" in res and \"W_null_mean\" in res:\n",
    "        row[\"delta_W_z\"] = row[\"W_GG_HH_z\"] - row[\"W_null_mean\"]\n",
    "\n",
    "    # ---- geometry payload ----\n",
    "    gs = {\n",
    "        \"D_GG\": res.get(\"D_GG\"),\n",
    "        \"D_HH\": res.get(\"D_HH\"),\n",
    "        \"D_GH\": res.get(\"D_GH\"),\n",
    "        \"D_GG_z\": res.get(\"D_GG_z\"),\n",
    "        \"D_HH_z\": res.get(\"D_HH_z\"),\n",
    "        \"D_GH_z\": res.get(\"D_GH_z\"),\n",
    "        \"v_fisher\": res.get(\"v_fisher\"),\n",
    "    }\n",
    "\n",
    "    return row, gs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac84be75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# reloadFromCache = True\n",
    "\n",
    "# if reloadFromCache:\n",
    "# else:\n",
    "# rows = []\n",
    "# geometry_store = {}\n",
    "# null_store = {}\n",
    "\n",
    "# for (m, p), _ in df.groupby([\"model_id\", \"prompt_id\"]):\n",
    "#     res = analyse_prompt(\n",
    "#         df,\n",
    "#         model_id=m,\n",
    "#         prompt_id=p,\n",
    "#         lambda_reg=1e-3,\n",
    "#         n_permutations=100,\n",
    "#         random_state=42\n",
    "#     )\n",
    "#     row, gs = collect_prompt_result(res, min_per_class_plot=5)\n",
    "    \n",
    "#     geometry_store[(m, p)] = gs\n",
    "#     null_store[(m, p)] = res['W_null_samples']\n",
    "#     rows.append(row)\n",
    "\n",
    "# results_df = pd.DataFrame(rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b67494f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_structural_analysis(\n",
    "    df,\n",
    "    analyse_prompt,\n",
    "    collect_prompt_result,\n",
    "    lambda_reg=1e-3,\n",
    "    n_permutations=100,\n",
    "    random_state=42,\n",
    "    min_per_class_plot=5,\n",
    "    use_cache=False,\n",
    "    cache_dir=\"cache\",\n",
    "    overwrite_cache=False,\n",
    "):\n",
    "    \"\"\"\n",
    "    Run structural analysis over all (model_id, prompt_id) pairs.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pd.DataFrame\n",
    "    analyse_prompt : callable\n",
    "    collect_prompt_result : callable\n",
    "    lambda_reg : float\n",
    "    n_permutations : int\n",
    "    random_state : int\n",
    "    min_per_class_plot : int\n",
    "    use_cache : bool\n",
    "        Whether to load/save results from filesystem\n",
    "    cache_dir : str or None\n",
    "        Directory where cache files are stored\n",
    "    overwrite_cache : bool\n",
    "        If True, recompute even if cache exists\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    results_df : pd.DataFrame\n",
    "    geometry_store : dict\n",
    "        (model_id, prompt_id) -> geometry payload\n",
    "    null_store : dict\n",
    "        (model_id, prompt_id) -> np.ndarray of null Wasserstein samples\n",
    "    \"\"\"\n",
    "\n",
    "    # ---- cache paths ----\n",
    "    if use_cache:\n",
    "        if cache_dir is None:\n",
    "            raise ValueError(\"cache_dir must be provided when use_cache=True\")\n",
    "\n",
    "        os.makedirs(cache_dir, exist_ok=True)\n",
    "\n",
    "        results_path = os.path.join(cache_dir, f\"results_df-{lambda_reg}.parquet\")\n",
    "        geometry_path = os.path.join(cache_dir, f\"geometry_store-{lambda_reg}.pkl\")\n",
    "        null_path = os.path.join(cache_dir, f\"null_store-{lambda_reg}.pkl\")\n",
    "        meta_path = os.path.join(cache_dir, f\"meta-{lambda_reg}.json\")\n",
    "\n",
    "        cache_exists = all(\n",
    "            os.path.exists(p)\n",
    "            for p in [results_path, geometry_path, null_path, meta_path]\n",
    "        )\n",
    "\n",
    "        if cache_exists and not overwrite_cache:\n",
    "            results_df = pd.read_parquet(results_path)\n",
    "\n",
    "            with open(geometry_path, \"rb\") as f:\n",
    "                geometry_store = pickle.load(f)\n",
    "\n",
    "            with open(null_path, \"rb\") as f:\n",
    "                null_store = pickle.load(f)\n",
    "\n",
    "            print(\"Cache correctly loaded.\")\n",
    "\n",
    "            return results_df, geometry_store, null_store\n",
    "\n",
    "    # ---- computation ----\n",
    "    rows = []\n",
    "    geometry_store = {}\n",
    "    null_store = {}\n",
    "\n",
    "    grouped = df.groupby([\"model_id\", \"prompt_id\"])\n",
    "\n",
    "    for (m, p), _ in grouped:\n",
    "        res = analyse_prompt(\n",
    "            df,\n",
    "            model_id=m,\n",
    "            prompt_id=p,\n",
    "            lambda_reg=lambda_reg,\n",
    "            n_permutations=n_permutations,\n",
    "            random_state=random_state,\n",
    "        )\n",
    "\n",
    "        row, gs = collect_prompt_result(\n",
    "            res,\n",
    "            min_per_class_plot=min_per_class_plot\n",
    "        )\n",
    "\n",
    "        geometry_store[(m, p)] = gs\n",
    "        null_store[(m, p)] = res.get(\"W_null_samples\")\n",
    "        rows.append(row)\n",
    "\n",
    "    results_df = pd.DataFrame(rows)\n",
    "\n",
    "    # ---- save cache ----\n",
    "    if use_cache:\n",
    "        results_df.to_parquet(results_path, index=False)\n",
    "\n",
    "        with open(geometry_path, \"wb\") as f:\n",
    "            pickle.dump(geometry_store, f)\n",
    "\n",
    "        with open(null_path, \"wb\") as f:\n",
    "            pickle.dump(null_store, f)\n",
    "\n",
    "        meta = {\n",
    "            \"lambda_reg\": lambda_reg,\n",
    "            \"n_permutations\": n_permutations,\n",
    "            \"random_state\": random_state,\n",
    "            \"min_per_class_plot\": min_per_class_plot,\n",
    "            \"n_rows\": len(results_df),\n",
    "        }\n",
    "\n",
    "        with open(meta_path, \"w\") as f:\n",
    "            json.dump(meta, f, indent=2)\n",
    "\n",
    "        print(\"Cache correctly dumped.\")\n",
    "\n",
    "    return results_df, geometry_store, null_store\n",
    "\n",
    "# Run the analysis (or load it from cache_dir) with the selected regularisation\n",
    "# strength; overwrite_cache=False reuses an existing cache when one is found.\n",
    "results_df, geometry_store, null_store = run_structural_analysis(\n",
    "    df,\n",
    "    analyse_prompt,\n",
    "    collect_prompt_result,\n",
    "    lambda_reg=best_reg_lambda,\n",
    "    n_permutations=100,\n",
    "    random_state=42,\n",
    "    min_per_class_plot=5,\n",
    "    use_cache=True,\n",
    "    cache_dir=cache_dir,\n",
    "    overwrite_cache=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "905b860f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c33707e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def select_prompt_by_fraction(\n",
    "    df_model,\n",
    "    mode=\"balanced\",\n",
    "    require_valid_plot=True,\n",
    "    require_valid_geom=False\n",
    "):\n",
    "    \"\"\"\n",
    "    Select a prompt for a single model based on frac_G.\n",
    "    \"\"\"\n",
    "    df_sel = df_model.copy()\n",
    "\n",
    "    if require_valid_plot:\n",
    "        df_sel = df_sel[df_sel[\"valid_plot\"]]\n",
    "\n",
    "    if require_valid_geom:\n",
    "        df_sel = df_sel[df_sel[\"valid_geom\"]]\n",
    "\n",
    "    if df_sel.empty:\n",
    "        return None\n",
    "\n",
    "    if mode == \"balanced\":\n",
    "        idx = (df_sel[\"frac_G\"] - 0.5).abs().idxmin()\n",
    "    elif mode == \"most_genuine\":\n",
    "        idx = df_sel[\"frac_G\"].idxmax()\n",
    "    elif mode == \"most_hallucinated\":\n",
    "        idx = df_sel[\"frac_G\"].idxmin()\n",
    "    else:\n",
    "        raise ValueError(f\"Unknown mode: {mode}\")\n",
    "\n",
    "    return df_sel.loc[idx]\n",
    "\n",
    "def select_representative_prompts(results_df, model_id, require_valid_plot=True, require_valid_geom=True):\n",
    "    \"\"\"\n",
    "    Select the balanced, most genuine and most hallucinated prompt rows of\n",
    "    one model.\n",
    "\n",
    "    Returns a dict with the keys \"balanced\", \"most_genuine\" and\n",
    "    \"most_hallucinated\"; each value is what select_prompt_by_fraction returns\n",
    "    for that mode (a row, or None when nothing passes the filters).\n",
    "    \"\"\"\n",
    "    df_model = results_df[results_df[\"model_id\"] == model_id]\n",
    "    filters = {\n",
    "        \"require_valid_plot\": require_valid_plot,\n",
    "        \"require_valid_geom\": require_valid_geom,\n",
    "    }\n",
    "\n",
    "    return {\n",
    "        mode: select_prompt_by_fraction(df_model, mode, **filters)\n",
    "        for mode in (\"balanced\", \"most_genuine\", \"most_hallucinated\")\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67fc84d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_wasserstein_vs_null(\n",
    "    W_obs,\n",
    "    W_null,\n",
    "    title=None,\n",
    "    ax=None\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot observed Wasserstein distance against null distribution.\n",
    "    \"\"\"\n",
    "    if ax is None:\n",
    "        fig, ax = plt.subplots(figsize=(6, 4))\n",
    "\n",
    "    W_null_sorted = np.sort(W_null)\n",
    "    x = np.linspace(0, 1, len(W_null_sorted))\n",
    "\n",
    "    # null distribution\n",
    "    ax.plot(\n",
    "        x,\n",
    "        W_null_sorted,\n",
    "        label=\"Random labelling\",\n",
    "        linewidth=2\n",
    "    )\n",
    "\n",
    "    # observed value\n",
    "    ax.axhline(\n",
    "        W_obs,\n",
    "        linestyle=\"--\",\n",
    "        linewidth=2,\n",
    "        label=\"Observed\"\n",
    "    )\n",
    "\n",
    "    ax.set_xlabel(\"Null distribution quantile\")\n",
    "    ax.set_ylabel(\"Wasserstein distance $W(D_{GG}, D_{HH})$\")\n",
    "\n",
    "    if title is not None:\n",
    "        ax.set_title(title)\n",
    "\n",
    "    ax.legend()\n",
    "    ax.grid(True)\n",
    "\n",
    "    return ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffa597c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the valid prompt of model 0 whose frac_G is closest to 0.5.\n",
    "# NOTE: the value is None when no prompt of that model passes the validity\n",
    "# filters, in which case the lookups below fail.\n",
    "row = select_representative_prompts(results_df, model_id=0)[\"balanced\"]\n",
    "\n",
    "m = row[\"model_id\"]\n",
    "p = row[\"prompt_id\"]\n",
    "\n",
    "# Re-run the analysis for this pair with 200 label permutations;\n",
    "# lambda_reg is left at analyse_prompt's default here.\n",
    "res = analyse_prompt(\n",
    "    df,\n",
    "    model_id=m,\n",
    "    prompt_id=p,\n",
    "    n_permutations=200,\n",
    "    random_state=42\n",
    ")\n",
    "\n",
    "# Observed original-space distance against its permutation null\n",
    "plot_wasserstein_vs_null(\n",
    "    W_obs=res[\"W_GG_HH\"],\n",
    "    W_null=res[\"W_null_samples\"],\n",
    "    title=f\"Model {m}, Prompt {p}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1023b937",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_hallucination_ratio_histogram(results_df):\n",
    "    \"\"\"\n",
    "    Histogram of the per-(model, prompt) hallucination rates (frac_H, NaNs\n",
    "    dropped) in five equal-width bins over [0, 1].\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "    \"\"\"\n",
    "    frac_H = results_df[\"frac_H\"].dropna()\n",
    "\n",
    "    bins = np.linspace(0, 1, 6)  # 0, 0.2, ..., 1\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(6, 4))\n",
    "    ax.hist(frac_H, bins=bins, edgecolor=\"black\")\n",
    "\n",
    "    # The notebook enables text.usetex, where a bare '_' in a label makes the\n",
    "    # LaTeX run fail; the symbols are therefore typeset in math mode (this\n",
    "    # form also renders with matplotlib's built-in mathtext).\n",
    "    ax.set_xlabel(r\"Hallucination rate ($n_H / n_{\\mathrm{total}}$)\")\n",
    "    ax.set_ylabel(\"Number of (model, prompt) pairs\")\n",
    "    ax.set_xticks(bins)\n",
    "\n",
    "    ax.set_title(\"Distribution of hallucination rates\")\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "plot_hallucination_ratio_histogram(results_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5df92b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_hallucination_rate_heatmap(results_df, model_names):\n",
    "    \"\"\"\n",
    "    Heatmap of hallucination rates: one row per model, one column per prompt,\n",
    "    columns ordered by increasing mean rate across models.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    results_df : pd.DataFrame\n",
    "        Needs model_id, prompt_id and frac_H, with unique (model_id, prompt_id)\n",
    "        pairs (required by pivot).\n",
    "    model_names : dict\n",
    "        model_id -> label used on the y axis\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "    \"\"\"\n",
    "    rates = results_df.pivot(index=\"model_id\", columns=\"prompt_id\", values=\"frac_H\")\n",
    "\n",
    "    # sort prompts by mean hallucination rate\n",
    "    prompt_order = rates.mean(axis=0).sort_values().index\n",
    "    rates = rates[prompt_order]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(21, 7))\n",
    "\n",
    "    im = ax.imshow(rates.values, aspect=\"auto\", cmap=\"RdBu_r\", vmin=0, vmax=1)\n",
    "\n",
    "    ax.set_yticks(range(len(rates.index)))\n",
    "    ax.set_yticklabels([model_names[m] for m in rates.index])\n",
    "\n",
    "    ax.set_xticks(list(range(len(prompt_order))))\n",
    "    ax.set_xticklabels(list(prompt_order), rotation=90, ha='center')\n",
    "    ax.set_xlabel(\"Prompts (sorted by hallucination rate)\")\n",
    "\n",
    "    cbar = fig.colorbar(im, ax=ax)\n",
    "    cbar.set_label(\"Hallucination rate\")\n",
    "\n",
    "    ax.set_title(\"Hallucination rates per model and prompt\")\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "plot_hallucination_rate_heatmap(results_df, model_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71941837",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_hallucination_rate_heatmap_by_premise(\n",
    "    results_df,\n",
    "    model_names,\n",
    "    false_premise_map,\n",
    "    ratio=(12, 4),\n",
    "    scale=1.5,\n",
    "):\n",
    "    \"\"\"\n",
    "    Two side-by-side hallucination-rate heatmaps, split by false premise.\n",
    "\n",
    "    Left panel: prompts whose false_premise flag is falsy.\n",
    "    Right panel: prompts whose false_premise flag is truthy.\n",
    "    Within each panel the prompts are ordered by increasing mean rate.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    results_df : pd.DataFrame\n",
    "        Needs model_id, prompt_id and frac_H\n",
    "    model_names : dict\n",
    "        model_id -> label used on the y axis\n",
    "    false_premise_map : dict or None\n",
    "        prompt_id -> flag; nothing is drawn when None\n",
    "    ratio : tuple\n",
    "        Base (width, height) of the figure\n",
    "    scale : float\n",
    "        Multiplier applied to ratio to obtain figsize\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (fig, axes), or None when false_premise_map is None\n",
    "    \"\"\"\n",
    "    if false_premise_map is None:\n",
    "        return\n",
    "\n",
    "    # ---- base pivot ----\n",
    "    rates = results_df.pivot(index=\"model_id\", columns=\"prompt_id\", values=\"frac_H\")\n",
    "\n",
    "    def _panel(premise_value):\n",
    "        # prompts of this group that appear in the pivot, ordered by mean rate\n",
    "        members = [\n",
    "            p for p, v in false_premise_map.items()\n",
    "            if bool(v) == premise_value and p in rates.columns\n",
    "        ]\n",
    "        ordered = rates[members].mean(axis=0).sort_values().index\n",
    "        return rates[ordered]\n",
    "\n",
    "    pivot_false = _panel(False)\n",
    "    pivot_true = _panel(True)\n",
    "\n",
    "    # ---- plotting ----\n",
    "    fig, axes = plt.subplots(1, 2, figsize=[scale * x for x in ratio], sharey=True)\n",
    "\n",
    "    image_style = {\"aspect\": \"auto\", \"cmap\": \"RdBu_r\", \"vmin\": 0, \"vmax\": 1}\n",
    "    axes[0].imshow(pivot_false.values, **image_style)\n",
    "    im1 = axes[1].imshow(pivot_true.values, **image_style)\n",
    "\n",
    "    # ---- y-axis (models) ----\n",
    "    axes[0].set_yticks(range(len(rates.index)))\n",
    "    axes[0].set_yticklabels([model_names[m] for m in rates.index])\n",
    "\n",
    "    # ---- x-axis and titles ----\n",
    "    for ax, panel_title in zip(axes, (\"False premise = False\", \"False premise = True\")):\n",
    "        ax.set_xticks([])\n",
    "        ax.set_xlabel(\"Prompts (sorted by hallucination rate)\")\n",
    "        ax.set_title(panel_title)\n",
    "\n",
    "    # ---- colourbar ----\n",
    "    cbar = fig.colorbar(im1, ax=axes, fraction=0.025, pad=0.02)\n",
    "    cbar.set_label(\"Hallucination rate\")\n",
    "\n",
    "    fig.suptitle(\"Hallucination rates per model and prompt\", y=1.02)\n",
    "\n",
    "    return fig, axes\n",
    "\n",
    "plot_hallucination_rate_heatmap_by_premise(\n",
    "    results_df,\n",
    "    model_names,\n",
    "    false_premise\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49bdb282",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_response_count_histogram(results_df, expected=500):\n",
    "    \"\"\"\n",
    "    Histogram (20 bins) of the number of responses per (model, prompt) pair,\n",
    "    with a dashed red vertical line at the expected count.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "    \"\"\"\n",
    "    response_counts = results_df[\"n_total\"]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(6, 4))\n",
    "\n",
    "    ax.hist(response_counts, bins=20, edgecolor=\"black\")\n",
    "    ax.axvline(expected, linestyle=\"--\", color=\"red\", label=\"Expected\")\n",
    "    ax.legend()\n",
    "\n",
    "    ax.set(\n",
    "        xlabel=\"Number of responses per (model, prompt)\",\n",
    "        ylabel=\"Count\",\n",
    "        title=\"Response count per (model, prompt)\",\n",
    "    )\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "plot_response_count_histogram(results_df, expected=maxResponsesPerPrompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a0a08c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_validity_coverage(results_df):\n",
    "    \"\"\"\n",
    "    Bar chart with three bars: all (model, prompt) pairs, those flagged\n",
    "    valid_geom and those flagged valid_plot.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "    \"\"\"\n",
    "    labels = [\"Total\", \"Valid geometry\", \"Valid plot\"]\n",
    "    heights = [\n",
    "        len(results_df),\n",
    "        results_df[\"valid_geom\"].sum(),\n",
    "        results_df[\"valid_plot\"].sum(),\n",
    "    ]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(5, 4))\n",
    "    ax.bar(labels, heights)\n",
    "\n",
    "    ax.set_ylabel(\"Number of (model, prompt) pairs\")\n",
    "    ax.set_title(\"Coverage of valid configurations\")\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "def plot_validity_by_model(results_df, model_names):\n",
    "    \"\"\"\n",
    "    Grouped bar chart of the fraction of prompts flagged valid_geom and\n",
    "    valid_plot for each model.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    results_df : pd.DataFrame\n",
    "        Needs model_id, valid_geom and valid_plot\n",
    "    model_names : dict\n",
    "        model_id -> label used on the x axis\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "    \"\"\"\n",
    "    grouped = results_df.groupby(\"model_id\")[[\"valid_geom\", \"valid_plot\"]].mean()\n",
    "\n",
    "    # pandas uses the column names as legend entries; the raw names contain '_',\n",
    "    # which LaTeX rejects in text mode (the notebook enables text.usetex)\n",
    "    legend_ready = grouped.rename(\n",
    "        columns={\"valid_geom\": \"Valid geometry\", \"valid_plot\": \"Valid plot\"}\n",
    "    )\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(6, 4))\n",
    "    legend_ready.plot.bar(ax=ax)\n",
    "\n",
    "    # pandas labels the x axis with the index name ('model_id'); replace it\n",
    "    # with an underscore-free label for the same reason\n",
    "    ax.set_xlabel(\"Model\")\n",
    "    ax.set_ylabel(\"Fraction\")\n",
    "    ax.set_xticklabels([model_names[m] for m in grouped.index], rotation=30)\n",
    "    ax.set_title(\"Validity fraction per model\")\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "plot_validity_coverage(results_df)\n",
    "plot_validity_by_model(results_df, model_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ae0d466",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31074aea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_wasserstein_ordered_with_null(\n",
    "    df_model,\n",
    "    null_store,\n",
    "    model_id,\n",
    "    ax=None,\n",
    "    axc=None,\n",
    "    title=None,\n",
    "    show_fliers=False,\n",
    "    cmap_name=\"berlin\",\n",
    "    requireValidPlot=False,\n",
    "    false_premise=None\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot observed Wasserstein distances ordered by magnitude,\n",
    "    with corresponding null distributions as boxplots,\n",
    "    colour-coded by fraction of genuine responses.\n",
    "    \"\"\"\n",
    "    # ---- filter prompts with available null model ----\n",
    "    rows = []\n",
    "    for _, r in df_model.iterrows():\n",
    "        if requireValidPlot and not r[\"valid_plot\"]:\n",
    "            continue\n",
    "        key = (model_id, r[\"prompt_id\"])\n",
    "        if key in null_store and len(null_store[key]) > 0:\n",
    "            rows.append(r)\n",
    "\n",
    "    if len(rows) == 0:\n",
    "        raise ValueError(\"No prompts with available null distributions\")\n",
    "\n",
    "    df = (\n",
    "        df_model\n",
    "        .loc[[r.name for r in rows]]\n",
    "        .sort_values(\"W_GG_HH\")\n",
    "        .reset_index(drop=True)\n",
    "    )\n",
    "\n",
    "    # ---- data ----\n",
    "    W_obs = df[\"W_GG_HH\"].values\n",
    "    prompt_ids = df[\"prompt_id\"].values\n",
    "    genuine = df[\"frac_G\"].values\n",
    "    x = np.arange(len(W_obs))\n",
    "\n",
    "    W_null_list = [\n",
    "        null_store[(model_id, p)]\n",
    "        for p in prompt_ids\n",
    "    ]\n",
    "\n",
    "    # ---- colour mapping ----\n",
    "    cmap = plt.get_cmap(cmap_name)\n",
    "    norm = mcolors.TwoSlopeNorm(vmin=0.0, vcenter=0.5, vmax=1.0)\n",
    "    colors = cmap(norm(genuine))\n",
    "\n",
    "    # ---- figure / axes ----\n",
    "    if ax is None:\n",
    "        fig, (ax, axc) = plt.subplots(\n",
    "            2, 1,\n",
    "            sharex=True,\n",
    "            height_ratios=[3, 1],\n",
    "            figsize=(15, 5),\n",
    "            \n",
    "        )\n",
    "    else:\n",
    "        fig = ax.figure\n",
    "\n",
    "    # ---- null boxplots ----\n",
    "    bp = ax.boxplot(\n",
    "        W_null_list,\n",
    "        positions=x,\n",
    "        widths=0.6,\n",
    "        showfliers=show_fliers,\n",
    "        patch_artist=True,\n",
    "        medianprops=dict(color=\"black\"),\n",
    "        whiskerprops=dict(alpha=0.6),\n",
    "        capprops=dict(alpha=0.6),\n",
    "        label=\"$H_0$ reference\"\n",
    "    )\n",
    "\n",
    "    # colour each box\n",
    "    for patch, c in zip(bp[\"boxes\"], colors):\n",
    "        patch.set_facecolor(c)\n",
    "        patch.set_alpha(0.35)\n",
    "\n",
    "    # ---- observed scatter ----\n",
    "    if false_premise is None:\n",
    "        sub = [True for pid in prompt_ids]\n",
    "    else:\n",
    "        sub = [(not false_premise[pid]) for pid in prompt_ids]\n",
    "\n",
    "    ax.scatter(\n",
    "        x[sub],\n",
    "        W_obs[sub],\n",
    "        c=colors[sub],\n",
    "        edgecolor=\"black\",\n",
    "        marker='o',\n",
    "        linewidth=0.5,\n",
    "        zorder=3,\n",
    "        label=\"$W(D_{GG}, D_{HH})$\"\n",
    "    )\n",
    "    if False in sub:\n",
    "        ax.scatter(\n",
    "            x[[not s for s in sub]],\n",
    "            W_obs[[not s for s in sub]],\n",
    "            c=colors[[not s for s in sub]],\n",
    "            edgecolor=\"black\",\n",
    "            marker='*',\n",
    "            linewidth=0.5,\n",
    "            zorder=3,\n",
    "            label=\"$W(D_{GG}, D_{HH})$ (False premise)\"\n",
    "        )\n",
    "\n",
    "    # ---- styling ----\n",
    "    ax.set_ylabel(\"Wasserstein distance\")\n",
    "    ax.set_xlim(-0.5, len(x) - 0.5)\n",
    "    ax.set_ylim(bottom=0)\n",
    "\n",
    "    if title is None:\n",
    "        title = f\"{model_names[model_id]}\"\n",
    "    ax.set_title(title)\n",
    "\n",
    "    ax.legend(loc=\"upper left\")\n",
    "    ax.grid(True, axis=\"y\", alpha=0.3)\n",
    "    ax.set_xticks([])\n",
    "\n",
    "    # ---- fraction panel ----\n",
    "    if axc is not None:\n",
    "        axc.scatter(\n",
    "            x[sub],\n",
    "            genuine[sub],\n",
    "            c=colors[sub],\n",
    "            edgecolor=\"black\",\n",
    "            linewidth=0.5\n",
    "        )\n",
    "        if False in sub:\n",
    "            axc.scatter(\n",
    "                x[[not s for s in sub]],\n",
    "                genuine[[not s for s in sub]],\n",
    "                c=colors[[not s for s in sub]],\n",
    "                marker='*',\n",
    "                edgecolor=\"black\",\n",
    "                linewidth=0.5\n",
    "            )\n",
    "        axc.axhline(0.5, color=\"black\", linestyle=\"--\", alpha=0.6)\n",
    "        axc.set_ylim(0, 1)\n",
    "        axc.set_ylabel(\"Fraction genuine\")\n",
    "        axc.set_xlabel(\"Prompt (ordered by observed Wasserstein distance)\")\n",
    "\n",
    "    # ---- colourbar ----\n",
    "    sm = cm.ScalarMappable(norm=norm, cmap=cmap)\n",
    "    sm.set_array([])\n",
    "    cbar = fig.colorbar(\n",
    "        sm,\n",
    "        ax=[ax, axc] if axc is not None else ax,\n",
    "        fraction=0.025,\n",
    "        pad=0.02\n",
    "    )\n",
    "    cbar.set_label(\"Fraction genuine\")\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "for tmp_model_id in range(len(model_names)):\n",
    "    df_model = results_df[\n",
    "        (results_df[\"model_id\"] == tmp_model_id) &\n",
    "        (results_df[\"valid_geom\"])\n",
    "    ]\n",
    "\n",
    "    fig, ax = plot_wasserstein_ordered_with_null(\n",
    "        df_model=df_model,\n",
    "        null_store=null_store,\n",
    "        model_id=tmp_model_id,\n",
    "        show_fliers=False,\n",
    "        false_premise=false_premise\n",
    "    )\n",
    "\n",
    "    if model_names[tmp_model_id] == \"Phi-14B\":\n",
    "        fig.savefig(\"img/S_WassVsNull_Phi.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79ac3116",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_violin_box_distance_panels_v0(\n",
    "    geometry_store,\n",
    "    selected_keys,\n",
    "    model_names,\n",
    "    figsize=(14, 8),\n",
    "    violin_width=0.8,\n",
    "    box_width=0.25,\n",
    "    alpha_violin=0.5,\n",
    "    title=None\n",
    "):\n",
    "    \"\"\"\n",
    "    Two-panel plot (embedding / Fisher space) with:\n",
    "      - left half violin: intra-genuine\n",
    "      - boxplot: inter-class\n",
    "      - right half violin: intra-hallucinated\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    geometry_store : dict\n",
    "        geometry_store[(model_id, prompt_id)] -> dict of distance arrays\n",
    "\n",
    "    selected_keys : list of tuples\n",
    "        [(model_id, prompt_id), ...]\n",
    "\n",
    "    model_names : dict\n",
    "        model_id -> model name\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, axes\n",
    "    \"\"\"\n",
    "\n",
    "    # ---- key mapping (adapt here if needed)\n",
    "    KEYMAP = {\n",
    "        \"emb\": (\"D_GG\", \"D_HH\", \"D_GH\"),\n",
    "        \"fisher\": (\"D_GG_z\", \"D_HH_z\", \"D_GH_z\")\n",
    "    }\n",
    "\n",
    "    n = len(selected_keys)\n",
    "    x = np.arange(n)\n",
    "\n",
    "    # ---- colour palette (model-based)\n",
    "    cmap = plt.get_cmap(\"tab10\")\n",
    "    model_colours = {\n",
    "        m: cmap(i % cmap.N)\n",
    "        for i, m in enumerate(sorted({k[0] for k in selected_keys}))\n",
    "    }\n",
    "\n",
    "    fig, axes = plt.subplots(\n",
    "        2, 1,\n",
    "        figsize=figsize,\n",
    "        sharex=True,\n",
    "        gridspec_kw={\"height_ratios\": [2, 1]}\n",
    "    )\n",
    "\n",
    "    def _plot_panel(ax, space_key, ylabel):\n",
    "        key_GG, key_HH, key_GH = KEYMAP[space_key]\n",
    "\n",
    "        # ---- collect values for limits\n",
    "        all_vals = []\n",
    "        for k in selected_keys:\n",
    "            d = geometry_store[k]\n",
    "            all_vals.extend(d[key_GG])\n",
    "            all_vals.extend(d[key_HH])\n",
    "            all_vals.extend(d[key_GH])\n",
    "\n",
    "        all_vals = np.asarray(all_vals)\n",
    "        all_vals = all_vals[np.isfinite(all_vals)]\n",
    "\n",
    "        ymin, ymax = all_vals.min(), all_vals.max()\n",
    "        pad = 0.05 * (ymax - ymin)\n",
    "        ax.set_ylim(ymin - pad, ymax + pad)\n",
    "\n",
    "        y0, y1 = ax.get_ylim()\n",
    "        height = y1 - y0\n",
    "\n",
    "        for i, (m, p) in enumerate(selected_keys):\n",
    "            d = geometry_store[(m, p)]\n",
    "            colour = model_colours[m]\n",
    "\n",
    "            # ---- genuine half violin (left)\n",
    "            vp = ax.violinplot(\n",
    "                d[key_GG],\n",
    "                positions=[x[i]],\n",
    "                widths=violin_width,\n",
    "                showmeans=False,\n",
    "                showmedians=False,\n",
    "                showextrema=False\n",
    "            )\n",
    "            for pc in vp[\"bodies\"]:\n",
    "                pc.set_facecolor(colour)\n",
    "                pc.set_alpha(alpha_violin)\n",
    "                pc.set_clip_path(\n",
    "                    Rectangle(\n",
    "                        (x[i] - violin_width / 2, y0),\n",
    "                        violin_width / 2,\n",
    "                        height,\n",
    "                        transform=ax.transData\n",
    "                    )\n",
    "                )\n",
    "\n",
    "            # ---- hallucinated half violin (right)\n",
    "            vp = ax.violinplot(\n",
    "                d[key_HH],\n",
    "                positions=[x[i]],\n",
    "                widths=violin_width,\n",
    "                showmeans=False,\n",
    "                showmedians=False,\n",
    "                showextrema=False\n",
    "            )\n",
    "            for pc in vp[\"bodies\"]:\n",
    "                pc.set_facecolor(colour)\n",
    "                pc.set_alpha(alpha_violin)\n",
    "                pc.set_clip_path(\n",
    "                    Rectangle(\n",
    "                        (x[i], y0),\n",
    "                        violin_width / 2,\n",
    "                        height,\n",
    "                        transform=ax.transData\n",
    "                    )\n",
    "                )\n",
    "\n",
    "            # ---- inter-class boxplot\n",
    "            bp = ax.boxplot(\n",
    "                d[key_GH],\n",
    "                positions=[x[i]],\n",
    "                widths=box_width,\n",
    "                patch_artist=True,\n",
    "                showfliers=False\n",
    "            )\n",
    "            for box in bp[\"boxes\"]:\n",
    "                box.set_facecolor(colour)\n",
    "                box.set_alpha(0.9)\n",
    "\n",
    "        ax.set_ylabel(ylabel)\n",
    "        ax.grid(True, axis=\"y\", alpha=0.3)\n",
    "\n",
    "    # ---- plot panels\n",
    "    _plot_panel(axes[0], \"emb\", \"Distance (embedding space)\")\n",
    "    _plot_panel(axes[1], \"fisher\", \"Distance (Fisher space)\")\n",
    "\n",
    "    # ---- x-axis labels\n",
    "    labels = [\n",
    "        f\"{model_names[m]}\\n(p={p})\"\n",
    "        for m, p in selected_keys\n",
    "    ]\n",
    "    axes[1].set_xticks(x)\n",
    "    axes[1].set_xticklabels(labels, rotation=30, ha=\"right\")\n",
    "\n",
    "    if title is None:\n",
    "        title = \"Distance distributions across models and prompts\"\n",
    "    axes[0].set_title(title)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    return fig, axes\n",
    "\n",
    "def plot_violin_box_distance_panels_fisher(\n",
    "    geometry_store,\n",
    "    selected_keys,\n",
    "    model_names,\n",
    "    figsize=(14, 8),\n",
    "    violin_width=0.8,\n",
    "    box_width=0.25,\n",
    "    alpha_violin=0.5,\n",
    "    title=None\n",
    "):\n",
    "    \"\"\"\n",
    "    Two-panel plot (embedding / Fisher space) with:\n",
    "      - left half violin: intra-genuine (GG)\n",
    "      - boxplot: inter-class (GH)\n",
    "      - right half violin: intra-hallucinated (HH)\n",
    "    \"\"\"\n",
    "\n",
    "    KEYMAP = {\n",
    "        \"emb\": (\"D_GG\", \"D_HH\", \"D_GH\"),\n",
    "        \"fisher\": (\"D_GG_z\", \"D_HH_z\", \"D_GH_z\")\n",
    "    }\n",
    "\n",
    "    COLORS = {\n",
    "        \"GG\": \"#6E9B34\",  # green\n",
    "        \"HH\": \"#AA4D39\",  # red\n",
    "        \"GH\": \"#27586B\",  # blue\n",
    "    }\n",
    "\n",
    "    n = len(selected_keys)\n",
    "    x = np.arange(n)\n",
    "\n",
    "    fig, axes = plt.subplots(\n",
    "        2, 1,\n",
    "        figsize=figsize,\n",
    "        sharex=True,\n",
    "        gridspec_kw={\"height_ratios\": [2, 1]}\n",
    "    )\n",
    "\n",
    "    def _plot_panel(ax, space_key, ylabel):\n",
    "        key_GG, key_HH, key_GH = KEYMAP[space_key]\n",
    "\n",
    "        # ---- global y-limits\n",
    "        all_vals = []\n",
    "        for k in selected_keys:\n",
    "            d = geometry_store[k]\n",
    "            all_vals.extend(d[key_GG])\n",
    "            all_vals.extend(d[key_HH])\n",
    "            all_vals.extend(d[key_GH])\n",
    "\n",
    "        all_vals = np.asarray(all_vals)\n",
    "        all_vals = all_vals[np.isfinite(all_vals)]\n",
    "\n",
    "        ymin, ymax = all_vals.min(), all_vals.max()\n",
    "        pad = 0.05 * (ymax - ymin)\n",
    "        ax.set_ylim(ymin - pad, ymax + pad)\n",
    "\n",
    "        y0, y1 = ax.get_ylim()\n",
    "        height = y1 - y0\n",
    "\n",
    "        for i, (m, p) in enumerate(selected_keys):\n",
    "            d = geometry_store[(m, p)]\n",
    "\n",
    "            # ---- GG: left half violin\n",
    "            vp = ax.violinplot(\n",
    "                d[key_GG],\n",
    "                positions=[x[i]],\n",
    "                widths=violin_width,\n",
    "                showmeans=False,\n",
    "                showmedians=False,\n",
    "                showextrema=False\n",
    "            )\n",
    "            for pc in vp[\"bodies\"]:\n",
    "                pc.set_facecolor(COLORS[\"GG\"])\n",
    "                pc.set_edgecolor(\"black\")\n",
    "                pc.set_linewidth(1.0)\n",
    "                pc.set_alpha(alpha_violin)\n",
    "                pc.set_clip_path(\n",
    "                    Rectangle(\n",
    "                        (x[i] - violin_width / 2, y0),\n",
    "                        violin_width / 2,\n",
    "                        height,\n",
    "                        transform=ax.transData\n",
    "                    )\n",
    "                )\n",
    "\n",
    "            # ---- HH: right half violin\n",
    "            vp = ax.violinplot(\n",
    "                d[key_HH],\n",
    "                positions=[x[i]],\n",
    "                widths=violin_width,\n",
    "                showmeans=False,\n",
    "                showmedians=False,\n",
    "                showextrema=False\n",
    "            )\n",
    "            for pc in vp[\"bodies\"]:\n",
    "                pc.set_facecolor(COLORS[\"HH\"])\n",
    "                pc.set_edgecolor(\"black\")\n",
    "                pc.set_linewidth(1.0)\n",
    "                pc.set_alpha(alpha_violin)\n",
    "                pc.set_clip_path(\n",
    "                    Rectangle(\n",
    "                        (x[i], y0),\n",
    "                        violin_width / 2,\n",
    "                        height,\n",
    "                        transform=ax.transData\n",
    "                    )\n",
    "                )\n",
    "\n",
    "            # ---- GH: boxplot\n",
    "            bp = ax.boxplot(\n",
    "                d[key_GH],\n",
    "                positions=[x[i]],\n",
    "                widths=box_width,\n",
    "                patch_artist=True,\n",
    "                showfliers=False\n",
    "            )\n",
    "            for box in bp[\"boxes\"]:\n",
    "                box.set_facecolor(COLORS[\"GH\"])\n",
    "                box.set_edgecolor(\"black\")\n",
    "                box.set_alpha(0.9)\n",
    "\n",
    "            for elem in [\"whiskers\", \"caps\", \"medians\"]:\n",
    "                for artist in bp[elem]:\n",
    "                    artist.set_color(\"black\")\n",
    "\n",
    "        ax.set_ylabel(ylabel)\n",
    "        ax.grid(True, axis=\"y\", alpha=0.3)\n",
    "\n",
    "    # ---- plot panels\n",
    "    _plot_panel(axes[0], \"emb\", \"Distance (embedding space)\")\n",
    "    _plot_panel(axes[1], \"fisher\", \"Distance (Fisher space)\")\n",
    "\n",
    "    # ---- x-axis labels\n",
    "    labels = [\n",
    "        f\"{model_names[m]}\\n(p={p})\"\n",
    "        for m, p in selected_keys\n",
    "    ]\n",
    "    axes[1].set_xticks(x)\n",
    "    axes[1].set_xticklabels(labels)#, rotation=30, ha=\"right\")\n",
    "\n",
    "    if title is None:\n",
    "        title = \"Distance distributions across models and prompts\"\n",
    "    axes[0].set_title(title)\n",
    "\n",
    "    # ---- legend\n",
    "    handles = [\n",
    "        plt.Line2D([0], [0], color=COLORS[\"GG\"], lw=6, label=\"GG (intra-genuine)\"),\n",
    "        plt.Line2D([0], [0], color=COLORS[\"GH\"], lw=6, label=\"GH (inter-class)\"),\n",
    "        plt.Line2D([0], [0], color=COLORS[\"HH\"], lw=6, label=\"HH (intra-hallucinated)\"),\n",
    "    ]\n",
    "    axes[0].legend(handles=handles, loc=\"upper left\")\n",
    "\n",
    "    fig.tight_layout()\n",
    "    return fig, axes\n",
    "\n",
    "sks = [select_representative_prompts(results_df, model_id=mid) for mid in range(len(model_names))]\n",
    "\n",
    "for kk in sks[0].keys():\n",
    "    selk = [(dd[kk]['model_id'], dd[kk]['prompt_id']) for dd in sks]\n",
    "\n",
    "    fig, axes = plot_violin_box_distance_panels_fisher(\n",
    "        geometry_store=geometry_store,\n",
    "        selected_keys=selk,\n",
    "        model_names=model_names,\n",
    "        figsize=(16, 9),\n",
    "        title=kk\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9d01fb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "tmp_model_id = 0\n",
    "\n",
    "selkeys = [(r['model_id'], r['prompt_id']) for _, r in results_df[(results_df['model_id'] == tmp_model_id) & (results_df['valid_plot'])].iterrows() ]\n",
    "\n",
    "fig, axes = plot_violin_box_distance_panels_fisher(\n",
    "    geometry_store=geometry_store,\n",
    "    selected_keys=selkeys[:21],\n",
    "    model_names=model_names,\n",
    "    figsize=(21, 7)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2ddb04b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# row = selected[\"balanced\"]\n",
    "# key = (row[\"model_id\"], row[\"prompt_id\"])\n",
    "# geom = geometry_store[key]\n",
    "\n",
    "# D_GG = geom[\"D_GG\"]\n",
    "# D_GH = geom[\"D_GH\"]\n",
    "# D_HH = geom[\"D_HH\"]\n",
    "\n",
    "# D_GG_z = geom[\"D_GG_z\"]\n",
    "# D_GH_z = geom[\"D_GH_z\"]\n",
    "# D_HH_z = geom[\"D_HH_z\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10784517",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_distance_boxpanels(\n",
    "    geometry_store,\n",
    "    selected_keys,\n",
    "    model_names,\n",
    "    figsize=(16, 9),\n",
    "    show_fliers=False\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot GG / GH / HH distance distributions as boxplots\n",
    "    for selected (model, prompt) pairs.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    geometry_store : dict\n",
    "        (model_id, prompt_id) -> geometry payload\n",
    "    selected_keys : list of tuple\n",
    "        [(model_id, prompt_id), ...]\n",
    "    model_names : dict\n",
    "        model_id -> model name\n",
    "    figsize : tuple\n",
    "    show_fliers : bool\n",
    "    \"\"\"\n",
    "\n",
    "    fig, axes = plt.subplots(\n",
    "        2, 1,\n",
    "        sharex=True,\n",
    "        figsize=figsize,\n",
    "        height_ratios=[2, 1]\n",
    "    )\n",
    "\n",
    "    panels = [\n",
    "        (\"embedding\", axes[0], \"Distances in embedding space\"),\n",
    "        (\"fisher\", axes[1], \"Distances in Fisher space\"),\n",
    "    ]\n",
    "\n",
    "    colors = {\n",
    "        \"GG\": \"#2ca02c\",  # green\n",
    "        \"GH\": \"#1f77b4\",  # blue\n",
    "        \"HH\": \"#d62728\",  # red\n",
    "    }\n",
    "\n",
    "    width = 0.25\n",
    "    x_centres = np.arange(len(selected_keys))\n",
    "\n",
    "    for space, ax, title in panels:\n",
    "        for i, key in enumerate(selected_keys):\n",
    "            gs = geometry_store[key]\n",
    "\n",
    "            if space == \"embedding\":\n",
    "                D_GG = gs[\"D_GG\"]\n",
    "                D_GH = gs[\"D_GH\"]\n",
    "                D_HH = gs[\"D_HH\"]\n",
    "            else:\n",
    "                D_GG = gs[\"D_GG_z\"]\n",
    "                D_GH = gs[\"D_GH_z\"]\n",
    "                D_HH = gs[\"D_HH_z\"]\n",
    "\n",
    "            data = [D_GG, D_GH, D_HH]\n",
    "            positions = [\n",
    "                x_centres[i] - width,\n",
    "                x_centres[i],\n",
    "                x_centres[i] + width,\n",
    "            ]\n",
    "\n",
    "            bp = ax.boxplot(\n",
    "                data,\n",
    "                positions=positions,\n",
    "                widths=width * 0.9,\n",
    "                patch_artist=True,\n",
    "                showfliers=show_fliers,\n",
    "                medianprops=dict(color=\"black\"),\n",
    "            )\n",
    "\n",
    "            for patch, key_col in zip(bp[\"boxes\"], [\"GG\", \"GH\", \"HH\"]):\n",
    "                patch.set_facecolor(colors[key_col])\n",
    "                patch.set_alpha(0.6)\n",
    "\n",
    "        ax.set_ylabel(\"Distance\")\n",
    "        ax.set_title(title)\n",
    "        ax.grid(True, axis=\"y\", alpha=0.3)\n",
    "\n",
    "    # ---- x labels ----\n",
    "    labels = [\n",
    "        f\"{model_names[m]}\\nP{p}\"\n",
    "        for (m, p) in selected_keys\n",
    "    ]\n",
    "\n",
    "    axes[1].set_xticks(x_centres)\n",
    "    axes[1].set_xticklabels(labels)#, rotation=45, ha=\"right\")\n",
    "    axes[1].set_xlabel(\"Model / Prompt\")\n",
    "\n",
    "    # ---- legend ----\n",
    "    handles = [\n",
    "        plt.Line2D([0], [0], color=colors[\"GG\"], lw=6, label=\"GG (intra-genuine)\"),\n",
    "        plt.Line2D([0], [0], color=colors[\"GH\"], lw=6, label=\"GH (inter-class)\"),\n",
    "        plt.Line2D([0], [0], color=colors[\"HH\"], lw=6, label=\"HH (intra-hallucinated)\"),\n",
    "    ]\n",
    "    axes[0].legend(handles=handles, loc=\"upper left\")\n",
    "\n",
    "    fig.tight_layout()\n",
    "    return fig, axes\n",
    "\n",
    "tmp_model_id = 0\n",
    "\n",
    "selkeys = [(r['model_id'], r['prompt_id']) for _, r in results_df[(results_df['model_id'] == tmp_model_id) & (results_df['valid_plot'])].iterrows() ]\n",
    "\n",
    "fig, axes = plot_distance_boxpanels(\n",
    "    geometry_store=geometry_store,\n",
    "    selected_keys=selkeys[:10],\n",
    "    model_names=model_names,\n",
    "    show_fliers=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a6e275b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import mannwhitneyu\n",
    "\n",
    "def gg_hh_pvalue(d_GG, d_HH):\n",
    "    d_GG = np.asarray(d_GG)\n",
    "    d_HH = np.asarray(d_HH)\n",
    "\n",
    "    d_GG = d_GG[np.isfinite(d_GG)]\n",
    "    d_HH = d_HH[np.isfinite(d_HH)]\n",
    "\n",
    "    if len(d_GG) < 2 or len(d_HH) < 2:\n",
    "        return np.nan\n",
    "\n",
    "    _, p = mannwhitneyu(d_GG, d_HH, alternative=\"two-sided\")\n",
    "    return p\n",
    "\n",
    "def pvalue_to_stars(p):\n",
    "    if not np.isfinite(p):\n",
    "        return \"\"\n",
    "    if p < 1e-3:\n",
    "        return \"***\"\n",
    "    elif p < 1e-2:\n",
    "        return \"**\"\n",
    "    elif p < 5e-2:\n",
    "        return \"*\"\n",
    "    else:\n",
    "        return \"ns\"\n",
    "    \n",
    "def reorder_selected_keys_by_model_size(selected_keys_dict, model_rank):\n",
    "    \"\"\"\n",
    "    Reorders the (model_id, prompt_id) tuples in each panel\n",
    "    according to model size ordering.\n",
    "    \"\"\"\n",
    "    reordered = {}\n",
    "\n",
    "    for panel, keys in selected_keys_dict.items():\n",
    "        reordered[panel] = sorted(\n",
    "            keys,\n",
    "            key=lambda k: model_rank[k[0]]\n",
    "        )\n",
    "\n",
    "    return reordered\n",
    "\n",
    "def plot_violin_box_distance_panels(\n",
    "    geometry_store,\n",
    "    selected_keys_dict,\n",
    "    model_names,\n",
    "    figsize=(16, 10),\n",
    "    violin_width=0.8,\n",
    "    box_width=0.25,\n",
    "    alpha_violin=0.5,\n",
    "    title=None\n",
    "):\n",
    "    \"\"\"\n",
    "    Three-panel plot (Balanced / Most genuine / Most hallucinated)\n",
    "    using embedding-space distances only.\n",
    "    \"\"\"\n",
    "\n",
    "    COLORS = {\n",
    "        \"GG\": \"#6E9B34\",  # green\n",
    "        \"HH\": \"#AA4D39\",  # red\n",
    "        \"GH\": \"#27586B\",  # blue\n",
    "    }\n",
    "\n",
    "    panels = [\"balanced\", \"most_genuine\", \"most_hallucinated\"]\n",
    "    panel_titles = {\n",
    "        \"balanced\": \"Balanced prompts\",\n",
    "        \"most_genuine\": \"Mostly genuine prompts\",\n",
    "        \"most_hallucinated\": \"Mostly hallucinated prompts\",\n",
    "    }\n",
    "\n",
    "    fig, axes = plt.subplots(\n",
    "        3, 1,\n",
    "        figsize=figsize,\n",
    "        sharex=True,\n",
    "        gridspec_kw={\"hspace\": 0.15}\n",
    "    )\n",
    "\n",
    "    def _plot_panel(ax, selected_keys, ylabel=None):\n",
    "        n = len(selected_keys)\n",
    "        x = np.arange(n)\n",
    "\n",
    "        # ---- global y-limits (FORCED FROM 0)\n",
    "        all_vals = []\n",
    "        for k in selected_keys:\n",
    "            d = geometry_store[k]\n",
    "            all_vals.extend(d[\"D_GG\"])\n",
    "            all_vals.extend(d[\"D_HH\"])\n",
    "            all_vals.extend(d[\"D_GH\"])\n",
    "\n",
    "        all_vals = np.asarray(all_vals)\n",
    "        all_vals = all_vals[np.isfinite(all_vals)]\n",
    "\n",
    "        ymax = all_vals.max()\n",
    "        ax.set_ylim(0, ymax * 1.12)\n",
    "\n",
    "        y0, y1 = ax.get_ylim()\n",
    "        height = y1 - y0\n",
    "\n",
    "        for i, (m, p) in enumerate(selected_keys):\n",
    "            d = geometry_store[(m, p)]\n",
    "\n",
    "            # ---- GG: left half violin\n",
    "            vp = ax.violinplot(\n",
    "                d[\"D_GG\"],\n",
    "                positions=[x[i]],\n",
    "                widths=violin_width,\n",
    "                showextrema=False\n",
    "            )\n",
    "            for pc in vp[\"bodies\"]:\n",
    "                pc.set_facecolor(COLORS[\"GG\"])\n",
    "                pc.set_edgecolor(\"black\")\n",
    "                pc.set_alpha(alpha_violin)\n",
    "                pc.set_clip_path(\n",
    "                    Rectangle(\n",
    "                        (x[i] - violin_width / 2, y0),\n",
    "                        violin_width / 2,\n",
    "                        height,\n",
    "                        transform=ax.transData\n",
    "                    )\n",
    "                )\n",
    "\n",
    "            # ---- HH: right half violin\n",
    "            vp = ax.violinplot(\n",
    "                d[\"D_HH\"],\n",
    "                positions=[x[i]],\n",
    "                widths=violin_width,\n",
    "                showextrema=False\n",
    "            )\n",
    "            for pc in vp[\"bodies\"]:\n",
    "                pc.set_facecolor(COLORS[\"HH\"])\n",
    "                pc.set_edgecolor(\"black\")\n",
    "                pc.set_alpha(alpha_violin)\n",
    "                pc.set_clip_path(\n",
    "                    Rectangle(\n",
    "                        (x[i], y0),\n",
    "                        violin_width / 2,\n",
    "                        height,\n",
    "                        transform=ax.transData\n",
    "                    )\n",
    "                )\n",
    "\n",
    "            # ---- GH: boxplot\n",
    "            bp = ax.boxplot(\n",
    "                d[\"D_GH\"],\n",
    "                positions=[x[i]],\n",
    "                widths=box_width,\n",
    "                patch_artist=True,\n",
    "                showfliers=False\n",
    "            )\n",
    "            for box in bp[\"boxes\"]:\n",
    "                box.set_facecolor(COLORS[\"GH\"])\n",
    "                box.set_edgecolor(\"black\")\n",
    "\n",
    "            # ---- significance stars (GG vs HH)\n",
    "            pval = gg_hh_pvalue(d[\"D_GG\"], d[\"D_HH\"])\n",
    "            stars = pvalue_to_stars(pval)\n",
    "            if stars:\n",
    "                ax.text(\n",
    "                    x[i],\n",
    "                    y1 - 0.03 * height,\n",
    "                    stars,\n",
    "                    ha=\"center\",\n",
    "                    va=\"top\",\n",
    "                    fontsize=14,\n",
    "                    fontweight=\"bold\"\n",
    "                )\n",
    "\n",
    "        if ylabel is not None:\n",
    "            ax.set_ylabel(ylabel)\n",
    "\n",
    "        ax.grid(True, axis=\"y\", alpha=0.3)\n",
    "\n",
    "    # ---- plot panels\n",
    "    for ax, key in zip(axes, panels):\n",
    "        _plot_panel(\n",
    "            ax,\n",
    "            selected_keys_dict[key],\n",
    "            ylabel=\"Distance\" if ax is axes[1] else None\n",
    "        )\n",
    "        ax.set_title(panel_titles[key])\n",
    "\n",
    "    # ---- x labels: MODEL NAMES ONLY\n",
    "    all_keys = selected_keys_dict[panels[0]]\n",
    "    labels = [model_names[m] for m, _ in all_keys]\n",
    "\n",
    "    axes[-1].set_xticks(np.arange(len(all_keys)))\n",
    "    axes[-1].set_xticklabels(labels)\n",
    "\n",
    "    # ---- legend: bottom, single row, outside\n",
    "    handles = [\n",
    "        plt.Line2D([0], [0], color=COLORS[\"GG\"], lw=6, label=\"GG (intra-genuine)\"),\n",
    "        plt.Line2D([0], [0], color=COLORS[\"GH\"], lw=6, label=\"GH (inter-class)\"),\n",
    "        plt.Line2D([0], [0], color=COLORS[\"HH\"], lw=6, label=\"HH (intra-hallucinated)\"),\n",
    "    ]\n",
    "\n",
    "    fig.legend(\n",
    "        handles=handles,\n",
    "        loc=\"lower center\",\n",
    "        ncol=3,\n",
    "        frameon=False,\n",
    "        bbox_to_anchor=(0.5, +0.03)\n",
    "    )\n",
    "\n",
    "    if title is not None:\n",
    "        fig.suptitle(title, y=0.98, fontsize=14)\n",
    "\n",
    "    # fig.tight_layout(rect=[0, 0.05, 1, 0.95])\n",
    "    return fig, axes\n",
    "\n",
    "\n",
    "# Representative (model_id, prompt_id) pairs, grouped by selection category.\n",
    "selected = {\n",
    "    \"balanced\": [],\n",
    "    \"most_genuine\": [],\n",
    "    \"most_hallucinated\": [],\n",
    "}\n",
    "\n",
    "# Every key returned by select_representative_prompts must be one of the\n",
    "# categories above, otherwise the append below raises KeyError.\n",
    "for mid in model_names:\n",
    "    ans = select_representative_prompts(results_df, model_id=mid)\n",
    "    for kk, dd in ans.items():\n",
    "        selected[kk].append((dd['model_id'], dd['prompt_id']))\n",
    "\n",
    "# Reorder the pairs of each category using model_rank (presumably a ranking\n",
    "# by model size, judging by the helper's name -- TODO confirm).\n",
    "selected_ordered = reorder_selected_keys_by_model_size(\n",
    "    selected,\n",
    "    model_rank\n",
    ")\n",
    "\n",
    "# title=\"\" is not None, so the plotting function still calls fig.suptitle,\n",
    "# just with an empty string.\n",
    "fig, axes = plot_violin_box_distance_panels(\n",
    "    geometry_store=geometry_store,\n",
    "    selected_keys_dict=selected_ordered,\n",
    "    model_names=model_names,\n",
    "    figsize=(16, 9),\n",
    "    title=\"\"\n",
    ")\n",
    "\n",
    "# NOTE(review): assumes an img/ directory already exists in the working directory.\n",
    "fig.savefig(\"img/S_promptViolins.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1576a909",
   "metadata": {},
   "outputs": [],
   "source": [
    "# def plot(df):\n",
    "\n",
    "#     df_filter = df[df['valid_geom']]\n",
    "\n",
    "#     W_null_list = [\n",
    "#         [x for pid in df_filter[df_filter['model_id']==mid]['prompt_id'].values for x in null_store[(mid, pid)]]\n",
    "#         for mid in model_names.keys()\n",
    "#     ]\n",
    "\n",
    "#     W_list = [list(x) for _, x in results_df.groupby('model_id')['W_GG_HH']]\n",
    "\n",
    "#     fig, ax = plt.subplots(figsize=(8,6))\n",
    "\n",
    "#     ax.boxplot()\n",
    "\n",
    "\n",
    "def plot_wass_vs_null_per_model(\n",
    "    results_df,\n",
    "    null_store,\n",
    "    model_names,\n",
    "    ratio=(4, 3),\n",
    "    scale=2,\n",
    "    show_fliers=False,\n",
    "    max_prompt_lines=None,\n",
    "    random_state=42\n",
    "):\n",
    "    \"\"\"\n",
    "    Boxplot observed vs null Wasserstein distances per model,\n",
    "    with optional prompt-level paired lines.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    results_df : pd.DataFrame\n",
    "        Needs the columns \"valid_plot\", \"model_id\", \"prompt_id\" and\n",
    "        \"W_GG_HH\"; only rows with a truthy \"valid_plot\" are plotted.\n",
    "        Each (model_id, prompt_id) pair is assumed to occur at most once.\n",
    "    null_store : mapping\n",
    "        (model_id, prompt_id) -> sequence of null Wasserstein distances.\n",
    "    model_names : dict\n",
    "        model_id -> tick label; the sorted keys define the x positions.\n",
    "    ratio : sequence of two numbers\n",
    "        Width/height ratio of the figure.\n",
    "    scale : number\n",
    "        Multiplier applied to ``ratio`` to obtain the figure size.\n",
    "    show_fliers : bool\n",
    "        Forwarded to ``Axes.boxplot`` as ``showfliers``.\n",
    "    max_prompt_lines : None | int | float\n",
    "        - None: no prompt-level lines\n",
    "        - int: max number of prompts per model\n",
    "        - float in (0,1]: fraction of prompts per model\n",
    "        Requests exceeding the number of available prompts are clipped.\n",
    "    random_state : int\n",
    "        Seed of the generator choosing which prompts get a paired line.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "        The created figure and its single axes.\n",
    "    \"\"\"\n",
    "\n",
    "    rng = np.random.default_rng(random_state)\n",
    "\n",
    "    # ---- keep only rows flagged as plottable ----\n",
    "    df = results_df[results_df[\"valid_plot\"]]\n",
    "\n",
    "    model_ids = sorted(model_names.keys())\n",
    "\n",
    "    W_obs = []\n",
    "    W_null = []\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=[scale * x for x in ratio])\n",
    "\n",
    "    n = len(model_ids)\n",
    "    x = np.arange(n)\n",
    "    dx = 0.18\n",
    "\n",
    "    # Style shared by the paired lines and by their legend proxy.\n",
    "    pairing_style = dict(color=\"black\", alpha=0.25, linewidth=0.8)\n",
    "    drew_pairing = False\n",
    "\n",
    "    for i, mid in enumerate(model_ids):\n",
    "        df_m = df[df[\"model_id\"] == mid]\n",
    "\n",
    "        # ---- observed ----\n",
    "        W_obs.append(df_m[\"W_GG_HH\"].dropna().values)\n",
    "\n",
    "        # ---- null (flattened) ----\n",
    "        null_vals = []\n",
    "        per_prompt = []\n",
    "\n",
    "        # Walk prompt ids and observed values together instead of\n",
    "        # re-filtering df_m once per prompt.\n",
    "        for pid, w_obs in zip(df_m[\"prompt_id\"].values, df_m[\"W_GG_HH\"].values):\n",
    "            key = (mid, pid)\n",
    "            if key in null_store and len(null_store[key]) > 0:\n",
    "                vals = np.asarray(null_store[key])\n",
    "                null_vals.extend(vals)\n",
    "                per_prompt.append((w_obs, vals.mean()))\n",
    "\n",
    "        W_null.append(np.asarray(null_vals))\n",
    "\n",
    "        # ---- paired prompt lines (optional) ----\n",
    "        if max_prompt_lines is None or not per_prompt:\n",
    "            continue\n",
    "\n",
    "        if isinstance(max_prompt_lines, float):\n",
    "            k = max(1, int(len(per_prompt) * max_prompt_lines))\n",
    "        else:\n",
    "            k = int(max_prompt_lines)\n",
    "        # Sampling without replacement requires 0 <= k <= len(per_prompt).\n",
    "        k = min(len(per_prompt), max(0, k))\n",
    "\n",
    "        for j in rng.choice(len(per_prompt), size=k, replace=False):\n",
    "            w_obs, w_null_mean = per_prompt[j]\n",
    "            # Observed value sits on the right box, null mean on the left one.\n",
    "            ax.plot(\n",
    "                [x[i] + dx, x[i] - dx],\n",
    "                [w_obs, w_null_mean],\n",
    "                zorder=1,\n",
    "                **pairing_style,\n",
    "            )\n",
    "            drew_pairing = True\n",
    "\n",
    "    # ---- null boxplots ----\n",
    "    ax.boxplot(\n",
    "        W_null,\n",
    "        positions=x - dx,\n",
    "        widths=0.3,\n",
    "        showfliers=show_fliers,\n",
    "        patch_artist=True,\n",
    "        boxprops=dict(facecolor=\"#BBBBBB\", alpha=0.5),\n",
    "        medianprops=dict(color=\"black\"),\n",
    "    )\n",
    "\n",
    "    # ---- observed boxplots ----\n",
    "    ax.boxplot(\n",
    "        W_obs,\n",
    "        positions=x + dx,\n",
    "        widths=0.3,\n",
    "        showfliers=show_fliers,\n",
    "        patch_artist=True,\n",
    "        boxprops=dict(facecolor=\"#4477AA\", alpha=0.8),\n",
    "        medianprops=dict(color=\"black\"),\n",
    "    )\n",
    "\n",
    "    # ---- axis styling ----\n",
    "    ax.set_xticks(x)\n",
    "    ax.set_xticklabels([model_names[mid] for mid in model_ids])\n",
    "    ax.set_ylabel(r\"$W(D_{GG}, D_{HH})$\")\n",
    "    ax.grid(True, axis=\"y\", alpha=0.5)\n",
    "\n",
    "    # ---- legend ----\n",
    "    legend_handles = [\n",
    "        Patch(facecolor=\"#4477AA\", alpha=0.8, label=\"Observed\"),\n",
    "        Patch(facecolor=\"#BBBBBB\", alpha=0.5, label=\"Null\"),\n",
    "    ]\n",
    "    if drew_pairing:\n",
    "        # Proxy artist, so the entry does not depend on a handle that only\n",
    "        # exists when some model actually had a paired line drawn.\n",
    "        legend_handles.append(\n",
    "            plt.Line2D([0], [0], label=\"Prompt pairing\", **pairing_style)\n",
    "        )\n",
    "    ax.legend(handles=legend_handles, loc=\"upper left\")\n",
    "\n",
    "    fig.tight_layout()\n",
    "    return fig, ax\n",
    "\n",
    "# 9x3 figure (ratio [3, 1] times scale 3); max_prompt_lines is a float, so\n",
    "# paired lines are drawn for a 0.1 fraction of the prompts of each model.\n",
    "fig, ax = plot_wass_vs_null_per_model(\n",
    "    results_df=results_df,\n",
    "    null_store=null_store,\n",
    "    model_names=model_names,\n",
    "    show_fliers=False,\n",
    "    max_prompt_lines=.1,\n",
    "    ratio=[3, 1],\n",
    "    scale=3,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53e64995",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58f40cd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_separability(gs, agg=\"mean\"):\n",
    "    \"\"\"\n",
    "    Inter-class / intra-class distance ratio of one geometry record.\n",
    "\n",
    "    ``gs`` must provide the distance collections \"D_GG\", \"D_HH\", \"D_GH\"\n",
    "    and their \"_z\" counterparts. Non-finite entries are dropped before\n",
    "    aggregating with the mean (``agg == \"mean\"``) or, for any other value\n",
    "    of ``agg``, the median.\n",
    "\n",
    "    Returns a dict with \"sep\" and \"sep_z\"; a ratio is NaN whenever the\n",
    "    averaged intra-class distance is not strictly positive.\n",
    "    \"\"\"\n",
    "    reducer = np.mean if agg == \"mean\" else np.median\n",
    "\n",
    "    def _summary(key):\n",
    "        values = np.asarray(gs[key])\n",
    "        finite = values[np.isfinite(values)]\n",
    "        if len(finite) == 0:\n",
    "            return np.nan\n",
    "        return reducer(finite)\n",
    "\n",
    "    def _ratio(inter, intra_a, intra_b):\n",
    "        intra = 0.5 * (intra_a + intra_b)\n",
    "        if intra > 0:\n",
    "            return inter / intra\n",
    "        # Also reached when intra is NaN, since NaN > 0 is False.\n",
    "        return np.nan\n",
    "\n",
    "    gg, hh, gh = (_summary(k) for k in (\"D_GG\", \"D_HH\", \"D_GH\"))\n",
    "    gg_z, hh_z, gh_z = (_summary(k) for k in (\"D_GG_z\", \"D_HH_z\", \"D_GH_z\"))\n",
    "\n",
    "    return {\n",
    "        \"sep\": _ratio(gh, gg, hh),\n",
    "        \"sep_z\": _ratio(gh_z, gg_z, hh_z),\n",
    "    }\n",
    "\n",
    "def plot_prompt_separability_for_model(\n",
    "    results_df,\n",
    "    geometry_store,\n",
    "    model_id,\n",
    "    ratio=(4, 3),\n",
    "    scale=2,\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot the per-prompt separability of one model in both spaces.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    results_df : pd.DataFrame\n",
    "        Needs the columns \"model_id\", \"prompt_id\" and \"valid_geom\".\n",
    "    geometry_store : mapping\n",
    "        (model_id, prompt_id) -> record accepted by compute_separability.\n",
    "    model_id\n",
    "        Model whose prompts are plotted.\n",
    "    ratio, scale\n",
    "        The figure size is ``scale * ratio``.\n",
    "\n",
    "    Prompts whose \"valid_geom\" flag is falsy are skipped, and so are\n",
    "    prompts that have no row in ``results_df``. If nothing is left, an\n",
    "    empty plot is returned.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "    \"\"\"\n",
    "    # Build the validity lookup once instead of masking results_df for every\n",
    "    # key; drop_duplicates keeps the first row of a prompt.\n",
    "    model_rows = results_df[results_df[\"model_id\"] == model_id]\n",
    "    valid_by_prompt = (\n",
    "        model_rows.drop_duplicates(\"prompt_id\")\n",
    "        .set_index(\"prompt_id\")[\"valid_geom\"]\n",
    "        .to_dict()\n",
    "    )\n",
    "\n",
    "    rows = []\n",
    "    for (m, p), gs in geometry_store.items():\n",
    "        if m != model_id:\n",
    "            continue\n",
    "        # .get(..., False): an unknown prompt counts as invalid rather than\n",
    "        # raising IndexError.\n",
    "        if not valid_by_prompt.get(p, False):\n",
    "            continue\n",
    "\n",
    "        s = compute_separability(gs)\n",
    "        rows.append({\n",
    "            \"prompt_id\": p,\n",
    "            \"sep\": s[\"sep\"],\n",
    "            \"sep_z\": s[\"sep_z\"],\n",
    "        })\n",
    "\n",
    "    # Explicit columns keep sort_values working when no prompt qualified.\n",
    "    df = pd.DataFrame(rows, columns=[\"prompt_id\", \"sep\", \"sep_z\"]).sort_values(\"sep_z\")\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=[scale * x for x in ratio])\n",
    "\n",
    "    x = np.arange(len(df))\n",
    "\n",
    "    ax.plot(x, df[\"sep\"], marker=\"o\", label=\"Original\")\n",
    "    ax.plot(x, df[\"sep_z\"], marker=\"*\", label=\"Fisher\")\n",
    "\n",
    "    # Faint connector between the two values of the same prompt.\n",
    "    for xi, sep, sep_z in zip(x, df[\"sep\"], df[\"sep_z\"]):\n",
    "        ax.plot([xi, xi], [sep, sep_z], color=\"gray\", alpha=0.3)\n",
    "\n",
    "    ax.set_xlabel(\"Prompts (ordered by Fisher separability)\")\n",
    "    ax.set_ylabel(\"Separability ratio\")\n",
    "    ax.set_yscale('log')\n",
    "    ax.legend()\n",
    "    ax.grid(True, alpha=0.4)\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# Model id 3 is hard-coded; the returned (fig, ax) is the cell's last expression.\n",
    "plot_prompt_separability_for_model(results_df, geometry_store, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f768e685",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dd5ba0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_separability_df(results_df, geometry_store):\n",
    "    \"\"\"\n",
    "    Tabulate the separability ratios of every valid (model, prompt) pair.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    results_df : pd.DataFrame\n",
    "        Needs the columns \"model_id\", \"prompt_id\" and \"valid_geom\".\n",
    "    geometry_store : mapping\n",
    "        (model_id, prompt_id) -> record accepted by compute_separability.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pd.DataFrame\n",
    "        Columns model_id | prompt_id | sep | sep_z, one row per pair whose\n",
    "        \"valid_geom\" flag is truthy. Pairs without a row in ``results_df``\n",
    "        are skipped. The columns are present even when no pair qualifies.\n",
    "    \"\"\"\n",
    "    # One lookup table instead of masking results_df for every key;\n",
    "    # drop_duplicates keeps the first row of a pair.\n",
    "    key_cols = [\"model_id\", \"prompt_id\"]\n",
    "    valid_by_key = (\n",
    "        results_df.drop_duplicates(key_cols)\n",
    "        .set_index(key_cols)[\"valid_geom\"]\n",
    "        .to_dict()\n",
    "    )\n",
    "\n",
    "    rows = []\n",
    "    for (m, p), gs in geometry_store.items():\n",
    "        # Unknown pairs count as invalid rather than raising IndexError.\n",
    "        if not valid_by_key.get((m, p), False):\n",
    "            continue\n",
    "        s = compute_separability(gs)\n",
    "        rows.append({\n",
    "            \"model_id\": m,\n",
    "            \"prompt_id\": p,\n",
    "            \"sep\": s[\"sep\"],\n",
    "            \"sep_z\": s[\"sep_z\"],\n",
    "        })\n",
    "    return pd.DataFrame(rows, columns=key_cols + [\"sep\", \"sep_z\"])\n",
    "\n",
    "def prepare_separability_violin_df(df_ans):\n",
    "    \"\"\"\n",
    "    Returns a long-form DataFrame with columns:\n",
    "      model_id | prompt_id | space | separability\n",
    "\n",
    "    Every input row yields two consecutive output rows: first the\n",
    "    \"original\" space (taken from column \"sep\"), then the \"fisher\" space\n",
    "    (taken from column \"sep_z\"). The id columns keep their input dtype,\n",
    "    and an empty input gives an empty frame that still has the four columns.\n",
    "    \"\"\"\n",
    "    id_cols = [\"model_id\", \"prompt_id\"]\n",
    "    out_cols = id_cols + [\"space\", \"separability\"]\n",
    "\n",
    "    if len(df_ans) == 0:\n",
    "        return pd.DataFrame(columns=out_cols)\n",
    "\n",
    "    # melt() instead of iterrows(): row-wise iteration upcasts integer ids\n",
    "    # to float when they share a row with the float ratios.\n",
    "    stacked = df_ans.rename(columns={\"sep\": \"original\", \"sep_z\": \"fisher\"}).melt(\n",
    "        id_vars=id_cols,\n",
    "        value_vars=[\"original\", \"fisher\"],\n",
    "        var_name=\"space\",\n",
    "        value_name=\"separability\",\n",
    "    )\n",
    "\n",
    "    # melt() emits all \"original\" rows followed by all \"fisher\" rows;\n",
    "    # interleave them so that the two rows of a prompt stay adjacent.\n",
    "    n = len(df_ans)\n",
    "    interleaved = np.arange(2 * n).reshape(2, n).T.ravel()\n",
    "    return stacked.iloc[interleaved].reset_index(drop=True)\n",
    "\n",
    "\n",
    "def plot_separability_violins_per_model(\n",
    "    df_sep_long,\n",
    "    model_names,\n",
    "    model_order,\n",
    "    ratio=(5, 3),\n",
    "    scale=2,\n",
    "    cmap_name=\"tab10\",\n",
    "    log_scale=True,\n",
    "):\n",
    "    \"\"\"\n",
    "    Horizontal violins of the separability ratio, one row per model.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df_sep_long : pd.DataFrame\n",
    "        Long-form table with the columns \"model_id\", \"space\" (\"original\"\n",
    "        or \"fisher\") and \"separability\".\n",
    "    model_names : dict\n",
    "        model_id -> tick label.\n",
    "    model_order : sequence\n",
    "        Model ids; the sequence is reversed before rows are assigned, so\n",
    "        the first entry receives the largest y position.\n",
    "    ratio, scale\n",
    "        The figure size is ``scale * ratio``.\n",
    "    cmap_name : str\n",
    "        Colormap providing one colour per row.\n",
    "    log_scale : bool\n",
    "        Use a logarithmic x axis.\n",
    "\n",
    "    Both spaces are drawn on the same row in the same colour and differ in\n",
    "    opacity only. Non-finite ratios are ignored.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=[scale * x for x in ratio])\n",
    "\n",
    "    model_order = list(reversed(model_order))\n",
    "\n",
    "    cmap = plt.get_cmap(cmap_name)\n",
    "\n",
    "    y_positions = np.arange(len(model_order))\n",
    "    width = 0.35\n",
    "\n",
    "    # (space, opacity); the same opacities are reused by the legend below.\n",
    "    space_alpha = [(\"original\", 0.30), (\"fisher\", 0.7)]\n",
    "\n",
    "    for i, mid in enumerate(model_order):\n",
    "        color = cmap(i % cmap.N)\n",
    "        of_model = df_sep_long[\"model_id\"] == mid\n",
    "\n",
    "        for space, alpha in space_alpha:\n",
    "            vals = np.asarray(\n",
    "                df_sep_long.loc[\n",
    "                    of_model & (df_sep_long[\"space\"] == space), \"separability\"\n",
    "                ],\n",
    "                dtype=float,\n",
    "            )\n",
    "            # NaN/inf ratios would poison the kernel density estimate.\n",
    "            vals = vals[np.isfinite(vals)]\n",
    "\n",
    "            if len(vals) == 0:\n",
    "                continue\n",
    "\n",
    "            vp = ax.violinplot(\n",
    "                vals,\n",
    "                positions=[y_positions[i]],\n",
    "                vert=False,\n",
    "                widths=width,\n",
    "                showmeans=False,\n",
    "                showextrema=False,\n",
    "                showmedians=False,\n",
    "            )\n",
    "\n",
    "            for body in vp[\"bodies\"]:\n",
    "                body.set_facecolor(color)\n",
    "                body.set_edgecolor(\"black\")\n",
    "                body.set_alpha(alpha)\n",
    "\n",
    "    ax.set_yticks(y_positions)\n",
    "    ax.set_yticklabels([model_names[m] for m in model_order])\n",
    "\n",
    "    if log_scale:\n",
    "        ax.set_xscale(\"log\")\n",
    "\n",
    "    ax.set_xlabel(\"Separability ratio  (inter / intra)\")\n",
    "\n",
    "    # Legend proxies: gray patches carrying the two opacities used above.\n",
    "    legend_elems = [\n",
    "        Patch(facecolor=\"gray\", alpha=0.30, label=\"Original space\"),\n",
    "        Patch(facecolor=\"gray\", alpha=0.7, label=\"Fisher space\"),\n",
    "    ]\n",
    "    ax.legend(handles=legend_elems, bbox_to_anchor=(0.85, 0.93), loc='center')\n",
    "\n",
    "    ax.grid(True, axis=\"x\", alpha=0.4)\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# Wide table (one row per valid model/prompt pair) -> long table (one row per space).\n",
    "df_ans = build_separability_df(results_df, geometry_store)\n",
    "df_sep_long = prepare_separability_violin_df(df_ans)\n",
    "\n",
    "# model_order is not defined in this cell; it must already exist in the kernel.\n",
    "fig, ax = plot_separability_violins_per_model(\n",
    "    df_sep_long,\n",
    "    model_names=model_names,\n",
    "    model_order=model_order,\n",
    ")\n",
    "\n",
    "# NOTE(review): assumes an img/ directory already exists in the working directory.\n",
    "fig.savefig(\"img/S_FisherSeparability.pdf\", bbox_inches='tight')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dba71c32",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da580db9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "a33c91cd",
   "metadata": {},
   "source": [
    "## Move to cluster descriptors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9f2200b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def dispersion_ratio(D_GG, D_HH):\n",
    "    \"\"\"\n",
    "    Median HH distance divided by median GG distance.\n",
    "\n",
    "    A value above 1 means the hallucinated responses are more spread out\n",
    "    than the genuine ones. NaN is returned when either input is empty.\n",
    "    \"\"\"\n",
    "    both_non_empty = len(D_GG) and len(D_HH)\n",
    "    if not both_non_empty:\n",
    "        return np.nan\n",
    "    median_hh = np.median(D_HH)\n",
    "    median_gg = np.median(D_GG)\n",
    "    return median_hh / median_gg\n",
    "\n",
    "def separation_ratio(D_GG, D_HH, D_GH):\n",
    "    \"\"\"\n",
    "    Median inter-class distance over the mean of the two intra-class medians.\n",
    "\n",
    "    Lower values indicate degradation. NaN is returned when any of the\n",
    "    three inputs is empty.\n",
    "    \"\"\"\n",
    "    if any(len(d) == 0 for d in (D_GG, D_HH, D_GH)):\n",
    "        return np.nan\n",
    "    mean_intra_median = (np.median(D_GG) + np.median(D_HH)) * 0.5\n",
    "    return np.median(D_GH) / mean_intra_median\n",
    "\n",
    "\n",
    "def intra_shape_divergence(D_GG, D_HH):\n",
    "    \"\"\"\n",
    "    Wasserstein distance between the GG and HH distance distributions,\n",
    "    expressed in units of the median GG distance.\n",
    "\n",
    "    NaN is returned when either input is empty.\n",
    "    \"\"\"\n",
    "    if min(len(D_GG), len(D_HH)) == 0:\n",
    "        return np.nan\n",
    "    transport_cost = wasserstein_distance(D_GG, D_HH)\n",
    "    gg_scale = np.median(D_GG)\n",
    "    return transport_cost / gg_scale\n",
    "\n",
    "\n",
    "def covariance_trace(X):\n",
    "    \"\"\"\n",
    "    Trace of the biased (1/n) covariance of X, with one sample per row.\n",
    "\n",
    "    NaN is returned for fewer than two samples.\n",
    "    \"\"\"\n",
    "    n_samples = len(X)\n",
    "    if n_samples < 2:\n",
    "        return np.nan\n",
    "    return np.trace(np.cov(X, rowvar=False, bias=True))\n",
    "\n",
    "def effective_rank(X, eps=1e-12):\n",
    "    \"\"\"\n",
    "    exp( Shannon entropy of normalised eigenvalues )\n",
    "\n",
    "    The eigenvalues are those of the biased covariance of X (one sample\n",
    "    per row), floored at ``eps`` before normalisation so the logarithm\n",
    "    stays finite. NaN is returned for fewer than two samples.\n",
    "    \"\"\"\n",
    "    if len(X) < 2:\n",
    "        return np.nan\n",
    "    cov = np.cov(X, rowvar=False, bias=True)\n",
    "    spectrum = np.maximum(np.linalg.eigvalsh(cov), eps)\n",
    "    weights = spectrum / spectrum.sum()\n",
    "    entropy = -(weights * np.log(weights)).sum()\n",
    "    return np.exp(entropy)\n",
    "\n",
    "\n",
    "def normalised_dispersion(X):\n",
    "    \"\"\"\n",
    "    Trace of the biased covariance of X divided by the size of the\n",
    "    covariance matrix (samples are rows).\n",
    "\n",
    "    NaN is returned for fewer than two samples.\n",
    "    \"\"\"\n",
    "    if len(X) < 2:\n",
    "        return np.nan\n",
    "    cov = np.cov(X, rowvar=False, bias=True)\n",
    "    n_features = cov.shape[0]\n",
    "    return np.trace(cov) / n_features\n",
    "\n",
    "def dispersion_inflation(X_G, X_H):\n",
    "    \"\"\"Normalised dispersion of X_H divided by that of X_G.\"\"\"\n",
    "    dispersion_h = normalised_dispersion(X_H)\n",
    "    dispersion_g = normalised_dispersion(X_G)\n",
    "    return dispersion_h / dispersion_g\n",
    "\n",
    "def shape_degradation(X_G, X_H):\n",
    "    \"\"\"Effective rank of X_H divided by the effective rank of X_G.\"\"\"\n",
    "    rank_h = effective_rank(X_H)\n",
    "    rank_g = effective_rank(X_G)\n",
    "    return rank_h / rank_g\n",
    "\n",
    "def covariance_spectral_divergence(X_G, X_H, eps=1e-12):\n",
    "    \"\"\"\n",
    "    Kullback-Leibler divergence sum(pG * log(pG / pH)) between the\n",
    "    normalised covariance spectra of X_G and X_H.\n",
    "\n",
    "    Each spectrum holds the eigenvalues of the biased covariance (samples\n",
    "    are rows) in the ascending order returned by eigvalsh, floored at\n",
    "    ``eps`` and rescaled to sum to one. NaN is returned when either set\n",
    "    has fewer than two samples.\n",
    "    \"\"\"\n",
    "    if min(len(X_G), len(X_H)) < 2:\n",
    "        return np.nan\n",
    "\n",
    "    def _normalised_spectrum(X):\n",
    "        eigenvalues = np.linalg.eigvalsh(np.cov(X, rowvar=False, bias=True))\n",
    "        eigenvalues = np.maximum(eigenvalues, eps)\n",
    "        return eigenvalues / eigenvalues.sum()\n",
    "\n",
    "    pG = _normalised_spectrum(X_G)\n",
    "    pH = _normalised_spectrum(X_H)\n",
    "\n",
    "    return np.sum(pG * np.log(pG / pH))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d44d2167",
   "metadata": {},
   "outputs": [],
   "source": [
    "def balanced_subsample_prompt(X, y, n_total, rng):\n",
    "    \"\"\"\n",
    "    Draw ``n_total`` samples whose GG/HH proportion follows the full set.\n",
    "\n",
    "    Label 0 is treated as GG and label 1 as HH. The GG share of the\n",
    "    subsample is the rounded GG fraction of the full set, and HH receives\n",
    "    the remainder. Sampling is without replacement, GG first and then HH,\n",
    "    using ``rng``.\n",
    "\n",
    "    Returns ``(X_sub, y_sub)`` with the GG rows first, or None when\n",
    "    ``n_total`` exceeds the available samples, when a class has fewer than\n",
    "    3 members, or when the requested per-class counts are infeasible.\n",
    "    \"\"\"\n",
    "    genuine_idx = np.nonzero(y == 0)[0]\n",
    "    halluc_idx = np.nonzero(y == 1)[0]\n",
    "\n",
    "    n_genuine_all, n_halluc_all = len(genuine_idx), len(halluc_idx)\n",
    "    pool_size = n_genuine_all + n_halluc_all\n",
    "\n",
    "    class_too_small = n_genuine_all < 3 or n_halluc_all < 3\n",
    "    if n_total > pool_size or class_too_small:\n",
    "        return None\n",
    "\n",
    "    n_genuine = int(round(n_genuine_all / pool_size * n_total))\n",
    "    n_halluc = n_total - n_genuine\n",
    "\n",
    "    # Each class must contribute at least one and at most all of its samples.\n",
    "    feasible = (\n",
    "        1 <= n_genuine <= n_genuine_all\n",
    "        and 1 <= n_halluc <= n_halluc_all\n",
    "    )\n",
    "    if not feasible:\n",
    "        return None\n",
    "\n",
    "    picked = np.concatenate([\n",
    "        rng.choice(genuine_idx, size=n_genuine, replace=False),\n",
    "        rng.choice(halluc_idx, size=n_halluc, replace=False),\n",
    "    ])\n",
    "    return X[picked], y[picked]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d69890b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_prompt_degradation_subsampled(\n",
    "    df,\n",
    "    model_id,\n",
    "    prompt_id,\n",
    "    n_total,\n",
    "    rng,\n",
    "):\n",
    "    \"\"\"\n",
    "    Compute the degradation descriptors of one prompt on a random subsample.\n",
    "\n",
    "    The data of (model_id, prompt_id) is subsampled to ``n_total`` points\n",
    "    with balanced_subsample_prompt, split by label, and summarised with the\n",
    "    covariance-based and distance-based indices.\n",
    "\n",
    "    Returns a flat dict (identifiers, sample counts, descriptors), or None\n",
    "    when the subsample cannot be drawn or a class ends up with fewer than\n",
    "    two samples.\n",
    "    \"\"\"\n",
    "    X, y = extract_prompt_data(df, model_id, prompt_id)\n",
    "    subsample = balanced_subsample_prompt(X, y, n_total, rng)\n",
    "    if subsample is None:\n",
    "        return None\n",
    "\n",
    "    X_G, X_H = split_by_label(*subsample)\n",
    "    n_G, n_H = len(X_G), len(X_H)\n",
    "    if min(n_G, n_H) < 2:\n",
    "        return None\n",
    "\n",
    "    res = {\n",
    "        # identifiers and sample counts\n",
    "        \"model_id\": model_id,\n",
    "        \"prompt_id\": prompt_id,\n",
    "        \"n_total\": n_total,\n",
    "        \"n_G\": n_G,\n",
    "        \"n_H\": n_H,\n",
    "        \"frac_G\": n_G / n_total,\n",
    "        # absolute, per class\n",
    "        \"trace_G\": covariance_trace(X_G),\n",
    "        \"trace_H\": covariance_trace(X_H),\n",
    "        \"effrank_G\": effective_rank(X_G),\n",
    "        \"effrank_H\": effective_rank(X_H),\n",
    "        # relative, H against G\n",
    "        \"dispersion_inflation\": dispersion_inflation(X_G, X_H),\n",
    "        \"shape_degradation\": shape_degradation(X_G, X_H),\n",
    "        \"spectral_divergence\": covariance_spectral_divergence(X_G, X_H),\n",
    "    }\n",
    "\n",
    "    # descriptors based on the pairwise-distance distributions\n",
    "    D_GG, D_HH, D_GH = compute_distance_distributions(X_G, X_H)\n",
    "    res.update(\n",
    "        dispersion_ratio=dispersion_ratio(D_GG, D_HH),\n",
    "        separation_ratio=separation_ratio(D_GG, D_HH, D_GH),\n",
    "        intra_shape_divergence=intra_shape_divergence(D_GG, D_HH),\n",
    "    )\n",
    "\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f08d807c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_subsampling_protocol(\n",
    "    df,\n",
    "    model_ids,\n",
    "    prompt_ids,\n",
    "    n_sizes,\n",
    "    n_boot=10,\n",
    "    seed=0,\n",
    "    use_cache=False,\n",
    "    cache_dir=\"cache\",\n",
    "    overwrite_cache=False,\n",
    "):\n",
    "    \"\"\"\n",
    "    Run subsampling protocol with filesystem caching.\n",
    "\n",
    "    For every (model, prompt, subsample size) combination the degradation\n",
    "    descriptors are evaluated on ``n_boot`` random subsamples, all drawn\n",
    "    from a single generator seeded with ``seed``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pd.DataFrame\n",
    "        Dataset handed to evaluate_prompt_degradation_subsampled.\n",
    "    model_ids, prompt_ids, n_sizes : iterable\n",
    "        Values to loop over; each is materialised into a list once.\n",
    "    n_boot : int\n",
    "        Number of subsamples per combination.\n",
    "    seed : int\n",
    "        Seed of the random generator.\n",
    "    use_cache : bool\n",
    "        Read/write the cache files listed below in ``cache_dir``.\n",
    "    cache_dir : str\n",
    "        Cache directory (created if missing); required when use_cache=True.\n",
    "    overwrite_cache : bool\n",
    "        Recompute and overwrite even if a cache exists.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pd.DataFrame\n",
    "        One row per successful subsample, with a \"boot\" column. When a cache\n",
    "        is reused, the cached table is returned as is; a warning is issued\n",
    "        if its stored parameters differ from the requested ones.\n",
    "\n",
    "    Cache files\n",
    "    ----------\n",
    "    subsampling_results.parquet\n",
    "    meta.json\n",
    "    \"\"\"\n",
    "    import warnings  # local import: only needed by this function\n",
    "\n",
    "    def _native(value):\n",
    "        # json.dump rejects numpy scalars (e.g. the items of Series.unique()).\n",
    "        return value.item() if isinstance(value, np.generic) else value\n",
    "\n",
    "    # Materialise once: the inputs are needed for the nested loops and for\n",
    "    # the metadata, and a one-shot iterable would otherwise be exhausted.\n",
    "    model_ids = list(model_ids)\n",
    "    prompt_ids = list(prompt_ids)\n",
    "    n_sizes = list(n_sizes)\n",
    "\n",
    "    run_meta = {\n",
    "        \"model_ids\": [_native(v) for v in model_ids],\n",
    "        \"prompt_ids\": [_native(v) for v in prompt_ids],\n",
    "        \"n_sizes\": [_native(v) for v in n_sizes],\n",
    "        \"n_boot\": _native(n_boot),\n",
    "        \"seed\": _native(seed),\n",
    "    }\n",
    "\n",
    "    # ---- cache paths ----\n",
    "    if use_cache:\n",
    "        if cache_dir is None:\n",
    "            raise ValueError(\"cache_dir must be provided when use_cache=True\")\n",
    "\n",
    "        os.makedirs(cache_dir, exist_ok=True)\n",
    "\n",
    "        results_path = os.path.join(cache_dir, \"subsampling_results.parquet\")\n",
    "        meta_path = os.path.join(cache_dir, \"meta.json\")\n",
    "\n",
    "        cache_exists = (\n",
    "            os.path.exists(results_path)\n",
    "            and os.path.exists(meta_path)\n",
    "        )\n",
    "\n",
    "        if cache_exists and not overwrite_cache:\n",
    "            try:\n",
    "                with open(meta_path) as f:\n",
    "                    cached_meta = json.load(f)\n",
    "            except (OSError, ValueError):\n",
    "                cached_meta = None\n",
    "\n",
    "            matches = isinstance(cached_meta, dict) and all(\n",
    "                cached_meta.get(key) == value for key, value in run_meta.items()\n",
    "            )\n",
    "            if not matches:\n",
    "                # The cache is still returned (as before); the caller is told\n",
    "                # that it may not correspond to the requested run.\n",
    "                warnings.warn(\n",
    "                    f\"Cached subsampling results in {cache_dir!r} were produced with \"\n",
    "                    \"different or unreadable parameters; pass overwrite_cache=True \"\n",
    "                    \"to recompute.\"\n",
    "                )\n",
    "            return pd.read_parquet(results_path)\n",
    "\n",
    "    # ---- computation ----\n",
    "    rng = np.random.default_rng(seed)\n",
    "    rows = []\n",
    "\n",
    "    for model_id in tqdm(model_ids, desc=\"Model\"):\n",
    "        for prompt_id in tqdm(prompt_ids, desc=\"Prompt\"):\n",
    "            for n_total in n_sizes:\n",
    "                for b in range(n_boot):\n",
    "\n",
    "                    res = evaluate_prompt_degradation_subsampled(\n",
    "                        df, model_id, prompt_id, n_total, rng\n",
    "                    )\n",
    "\n",
    "                    if res is not None:\n",
    "                        res[\"boot\"] = b\n",
    "                        rows.append(res)\n",
    "\n",
    "    results_df = pd.DataFrame(rows)\n",
    "\n",
    "    # ---- save cache ----\n",
    "    if use_cache:\n",
    "        results_df.to_parquet(results_path, index=False)\n",
    "\n",
    "        meta = dict(run_meta, n_rows=len(results_df))\n",
    "\n",
    "        with open(meta_path, \"w\") as f:\n",
    "            json.dump(meta, f, indent=2)\n",
    "\n",
    "    return results_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8673fe09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Subsample sizes 10, 20, ..., 100.\n",
    "n_sizes = range(10, 101, 10)\n",
    "\n",
    "# The protocol is skipped entirely (df_deg = None) unless do_descriptors is truthy.\n",
    "df_deg = run_subsampling_protocol(\n",
    "    df=df,\n",
    "    model_ids=model_names.keys(),\n",
    "    prompt_ids=results_df[\"prompt_id\"].unique(),\n",
    "    n_sizes=n_sizes,\n",
    "    n_boot=10,\n",
    "    use_cache=True,\n",
    "    cache_dir=cache_dir,\n",
    ") if do_descriptors else None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d30bcc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_degradation_indices(df_deg, indices=None, model_names=None, figsize=(16, 10)):\n",
    "    \"\"\"\n",
    "    Draw one stacked panel per degradation index, showing how the index\n",
    "    evolves with the subsample size for every model.\n",
    "\n",
    "    For each (model_id, n_total) group the median is drawn as a line with\n",
    "    markers and the 25th-75th percentile range as a shaded band.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df_deg : pd.DataFrame\n",
    "        Output of run_subsampling_protocol\n",
    "    indices : list of str, optional\n",
    "        Columns of df_deg to plot; defaults to trace_H, effrank_H,\n",
    "        dispersion_inflation, shape_degradation and spectral_divergence\n",
    "    model_names : dict, optional\n",
    "        model_id -> legend label; str(model_id) is used when omitted or empty\n",
    "    figsize : tuple\n",
    "        Figure size\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, axes\n",
    "        axes is indexable also when a single index is plotted\n",
    "    \"\"\"\n",
    "    if indices is None:\n",
    "        indices = [\"trace_H\", \"effrank_H\", \"dispersion_inflation\", \n",
    "                   \"shape_degradation\", \"spectral_divergence\"]\n",
    "\n",
    "    fig, axes = plt.subplots(len(indices), 1, figsize=figsize, sharex=True)\n",
    "\n",
    "    # With a single row plt.subplots returns a bare Axes; wrap it so the\n",
    "    # zip and the axes[0] / axes[-1] accesses below work unchanged.\n",
    "    if len(indices) == 1:\n",
    "        axes = [axes]\n",
    "\n",
    "    def _percentile(q):\n",
    "        return lambda values: np.percentile(values, q)\n",
    "\n",
    "    for ax, idx in zip(axes, indices):\n",
    "        # median and inter-quartile range per model & n_total\n",
    "        grouped = df_deg.groupby([\"model_id\", \"n_total\"])[idx]\n",
    "        df_plot = grouped.agg(\n",
    "            median=\"median\",\n",
    "            q25=_percentile(25),\n",
    "            q75=_percentile(75),\n",
    "        ).reset_index()\n",
    "\n",
    "        for mid in df_plot[\"model_id\"].unique():\n",
    "            df_m = df_plot.loc[df_plot[\"model_id\"] == mid]\n",
    "            if model_names:\n",
    "                label = model_names[mid]\n",
    "            else:\n",
    "                label = str(mid)\n",
    "\n",
    "            sizes = df_m[\"n_total\"]\n",
    "            ax.plot(sizes, df_m[\"median\"], label=label, marker=\"o\")\n",
    "            ax.fill_between(sizes, df_m[\"q25\"], df_m[\"q75\"], alpha=0.2)\n",
    "\n",
    "        ax.set_ylabel(idx)\n",
    "        ax.grid(True, axis=\"y\", alpha=0.3)\n",
    "\n",
    "    first_ax, last_ax = axes[0], axes[-1]\n",
    "    last_ax.set_xlabel(\"Subsample size (n_total)\")\n",
    "    first_ax.set_title(\"Degradation indices vs dataset size\")\n",
    "    first_ax.legend(loc=\"upper left\")\n",
    "    fig.tight_layout()\n",
    "    return fig, axes\n",
    "\n",
    "# Nothing is plotted when df_deg is None.\n",
    "if df_deg is not None:\n",
    "    # Two separate figures; fig/axes end up referring to the second one.\n",
    "    fig, axes = plot_degradation_indices(df_deg, indices=[\"dispersion_inflation\", \"shape_degradation\", \"spectral_divergence\"], model_names=model_names)\n",
    "    fig, axes = plot_degradation_indices(df_deg, indices=[\"dispersion_ratio\", \"separation_ratio\", \"intra_shape_divergence\"], model_names=model_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8f14cdc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "a2fde7dd",
   "metadata": {},
   "source": [
    "## Move to detectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e3ae1c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class FisherWassersteinDetector:\n",
    "    \"\"\"\n",
    "    Two-class detector operating on a one-dimensional projection.\n",
    "\n",
    "    fit() obtains a projection vector from fisher_direction() and stores\n",
    "    the projected training points of class G (label 0) and class H\n",
    "    (label 1). A test point is projected, its distances to the training\n",
    "    points of a class are compared (Wasserstein distance) with the\n",
    "    pairwise distances inside that class, and the class with the smaller\n",
    "    discrepancy wins. Ties go to class 0, and a class with fewer than two\n",
    "    training points gets an infinite discrepancy.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, lambda_reg=1e-3, normalise=True, normalise_by_trace=True):\n",
    "        # Settings forwarded unchanged to fisher_direction() by fit().\n",
    "        self.lambda_reg = lambda_reg\n",
    "        self.normalise = normalise\n",
    "        self.normalise_by_trace = normalise_by_trace\n",
    "        # State set by fit(): projection vector and projected training sets.\n",
    "        self.v = None\n",
    "        self.Z_G = None\n",
    "        self.Z_H = None\n",
    "\n",
    "    def fit(self, X_G, X_H):\n",
    "        \"\"\"Compute Fisher direction from training embeddings\"\"\"\n",
    "        self.v = fisher_direction(X_G, X_H, lambda_reg=self.lambda_reg, normalise=self.normalise, normalise_by_trace=self.normalise_by_trace)\n",
    "        self.Z_G = (X_G @ self.v)[:, None]  # shape (n_G, 1)\n",
    "        self.Z_H = (X_H @ self.v)[:, None]  # shape (n_H, 1)\n",
    "        assert self.Z_G.ndim == 2 and self.Z_G.shape[1] == 1\n",
    "        assert self.Z_H.ndim == 2 and self.Z_H.shape[1] == 1\n",
    "\n",
    "    def _check_fitted(self):\n",
    "        \"\"\"Raise RuntimeError if the projection or training sets are missing.\"\"\"\n",
    "        if self.v is None or self.Z_G is None or self.Z_H is None:\n",
    "            raise RuntimeError(\n",
    "                \"FisherWassersteinDetector must be fitted before predicting\"\n",
    "            )\n",
    "\n",
    "    @staticmethod\n",
    "    def _intra_distances(Z):\n",
    "        \"\"\"Pairwise distances inside one projected class, or None below 2 points.\"\"\"\n",
    "        return pdist(Z) if len(Z) > 1 else None\n",
    "\n",
    "    def _classify(self, x, D_GG, D_HH):\n",
    "        \"\"\"Classify one point given the precomputed intra-class distances.\"\"\"\n",
    "        z = (x @ self.v).reshape(1, 1)\n",
    "        assert z.shape == (1, 1)\n",
    "\n",
    "        W_G = np.inf if D_GG is None else wasserstein_distance(cdist(z, self.Z_G).ravel(), D_GG)\n",
    "        W_H = np.inf if D_HH is None else wasserstein_distance(cdist(z, self.Z_H).ravel(), D_HH)\n",
    "\n",
    "        return 0 if W_G <= W_H else 1\n",
    "\n",
    "    def predict_point(self, x):\n",
    "        \"\"\"Assign a single point x to class 0 (G) or 1 (H)\"\"\"\n",
    "        self._check_fitted()\n",
    "        return self._classify(\n",
    "            x,\n",
    "            self._intra_distances(self.Z_G),\n",
    "            self._intra_distances(self.Z_H),\n",
    "        )\n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"Assign multiple points\"\"\"\n",
    "        self._check_fitted()\n",
    "        # The intra-class distances do not depend on the test point, so they\n",
    "        # are computed once for the whole batch.\n",
    "        D_GG = self._intra_distances(self.Z_G)\n",
    "        D_HH = self._intra_distances(self.Z_H)\n",
    "        return np.array([self._classify(x, D_GG, D_HH) for x in X])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1432cc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LabelPropEvaluator:\n",
    "    \"\"\"Scores a detector exposing ``predict`` on a fixed test set.\"\"\"\n",
    "\n",
    "    def __init__(self, detector, X_test, y_test):\n",
    "        self.detector, self.X_test, self.y_test = detector, X_test, y_test\n",
    "\n",
    "    def evaluate(self):\n",
    "        \"\"\"\n",
    "        Predict the test set and return a dict with the test size, accuracy,\n",
    "        F1, precision, recall (all with zero_division=0) and the four\n",
    "        confusion-matrix counts tn, fp, fn, tp.\n",
    "        \"\"\"\n",
    "        y_true = self.y_test\n",
    "        y_pred = self.detector.predict(self.X_test)\n",
    "\n",
    "        # labels=[0, 1] pins the matrix to 2x2, so ravel() yields tn, fp, fn, tp.\n",
    "        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()\n",
    "\n",
    "        metrics = {\n",
    "            \"n_test\": len(y_true),\n",
    "            \"accuracy\": accuracy_score(y_true, y_pred),\n",
    "        }\n",
    "        for key, scorer in (\n",
    "            (\"f1\", f1_score),\n",
    "            (\"precision\", precision_score),\n",
    "            (\"recall\", recall_score),\n",
    "        ):\n",
    "            metrics[key] = scorer(y_true, y_pred, zero_division=0)\n",
    "\n",
    "        # confusion matrix entries\n",
    "        metrics.update(tn=tn, fp=fp, fn=fn, tp=tp)\n",
    "        return metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29ab70d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_fixed_test_sets(X, y, n_splits=5, test_fraction=0.2, random_state=42):\n",
    "    \"\"\"\n",
    "    Create ``n_splits`` stratified train/test partitions of (X, y).\n",
    "\n",
    "    Returns two lists of equal length, ``trn_splits`` and ``tst_splits``,\n",
    "    whose i-th items are the (X, y) pairs of the i-th partition.\n",
    "    \"\"\"\n",
    "    splitter = StratifiedShuffleSplit(\n",
    "        n_splits=n_splits,\n",
    "        test_size=test_fraction,\n",
    "        random_state=random_state,\n",
    "    )\n",
    "    index_pairs = list(splitter.split(X, y))\n",
    "    trn_splits = [(X[trn_idx], y[trn_idx]) for trn_idx, _ in index_pairs]\n",
    "    tst_splits = [(X[tst_idx], y[tst_idx]) for _, tst_idx in index_pairs]\n",
    "    return trn_splits, tst_splits\n",
    "\n",
    "def subsample_training_set(X_train_full, y_train_full, fraction, rng=None):\n",
    "    \"\"\"\n",
    "    Stratified subsample: keep the same fraction of each class.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X_train_full, y_train_full\n",
    "        Training data, split into classes G and H with split_by_label.\n",
    "    fraction : float\n",
    "        Share of each class to keep. At least one sample per class is kept\n",
    "        when the class is non-empty, and never more than the class holds.\n",
    "    rng : None | int | np.random.Generator, optional\n",
    "        Source of randomness, passed to np.random.default_rng. The default\n",
    "        None draws fresh entropy on every call (not reproducible); pass a\n",
    "        seed or a shared Generator for reproducible subsamples.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    X_sub, y_sub\n",
    "        Selected rows (class G first, then class H) and boolean labels\n",
    "        (False for G, True for H).\n",
    "    \"\"\"\n",
    "    X_G, X_H = split_by_label(X_train_full, y_train_full)\n",
    "\n",
    "    def _n_keep(n_available):\n",
    "        # Clamp to the class size: sampling without replacement cannot\n",
    "        # return more rows than exist (fraction > 1, or an empty class).\n",
    "        return min(n_available, max(1, int(fraction * n_available)))\n",
    "\n",
    "    n_G = _n_keep(len(X_G))\n",
    "    n_H = _n_keep(len(X_H))\n",
    "\n",
    "    # default_rng returns an existing Generator unchanged, so a caller can\n",
    "    # thread one generator through repeated calls.\n",
    "    rng = np.random.default_rng(rng)\n",
    "    idx_G = rng.choice(len(X_G), size=n_G, replace=False)\n",
    "    idx_H = rng.choice(len(X_H), size=n_H, replace=False)\n",
    "\n",
    "    X_sub = np.vstack([X_G[idx_G], X_H[idx_H]])\n",
    "    y_sub = np.concatenate([np.zeros(n_G, dtype=bool), np.ones(n_H, dtype=bool)])\n",
    "\n",
    "    return X_sub, y_sub\n",
    "\n",
    "def run_label_propagation_experiment(\n",
    "    df,\n",
    "    model_id,\n",
    "    prompt_id,\n",
    "    train_fractions=[0.1, 0.3, 0.5, 0.7, 1.0],\n",
    "    n_iter=10,\n",
    "    test_fraction=0.2,\n",
    "    n_splits=5,\n",
    "    lambda_reg=1,\n",
    "    logskip=False,\n",
    "    random_state=42\n",
    "):\n",
    "    \"\"\"\n",
    "    Run label propagation experiment with fixed test sets and multiple\n",
    "    random subsamples of the training set per fraction.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pd.DataFrame\n",
    "        Full dataset\n",
    "    model_id : int\n",
    "        Model index\n",
    "    prompt_id : int\n",
    "        Prompt index\n",
    "    train_fractions : list of float\n",
    "        Fractions of the training set to subsample\n",
    "    n_iter : int\n",
    "        Number of random subsamples per fraction\n",
    "    test_fraction : float\n",
    "        Fraction of data for test\n",
    "    n_splits : int\n",
    "        Number of fixed test splits\n",
    "    random_state : int\n",
    "        RNG seed\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pd.DataFrame\n",
    "        Metrics for all splits, fractions, and iterations\n",
    "    \"\"\"\n",
    "\n",
    "    # ---- extract full dataset ----\n",
    "    X, y = extract_prompt_data(df, model_id, prompt_id)\n",
    "    n_H = sum(y)\n",
    "    n_G = len(y) - n_H\n",
    "    if n_H < 5 or n_G < 5:\n",
    "        if not logskip:\n",
    "            print(f\"Skipping model {model_id}, prompt {prompt_id} due to unbalancedeness (n_G = {n_G}, n_H = {n_H})\")\n",
    "        return None \n",
    "\n",
    "    # ---- fixed stratified splits ----\n",
    "    trn_sets, tst_sets = generate_fixed_test_sets(\n",
    "        X, y, n_splits=n_splits, test_fraction=test_fraction, random_state=random_state\n",
    "    )\n",
    "\n",
    "    results = []\n",
    "\n",
    "    # ---- loop over fixed test sets ----\n",
    "    for test_id, ((X_train_full, y_train_full), (X_test, y_test)) in enumerate(zip(trn_sets, tst_sets)):\n",
    "\n",
    "        for tf in train_fractions:\n",
    "            for iter_id in range(n_iter):\n",
    "                # balanced random subsample of training set\n",
    "                X_sub, y_sub = subsample_training_set(X_train_full, y_train_full, tf)\n",
    "\n",
    "                X_G_sub, X_H_sub = split_by_label(X_sub, y_sub)\n",
    "                if len(X_G_sub) < 2 or len(X_H_sub) < 2:\n",
    "                    continue\n",
    "\n",
    "                # fit detector\n",
    "                detector = FisherWassersteinDetector(lambda_reg=lambda_reg)\n",
    "                detector.fit(X_G_sub, X_H_sub)\n",
    "\n",
    "                # evaluate on fixed test set\n",
    "                evaluator = LabelPropEvaluator(detector, X_test, y_test)\n",
    "                metrics = evaluator.evaluate()\n",
    "\n",
    "                # add metadata\n",
    "                metrics.update({\n",
    "                    \"train_fraction\": tf,\n",
    "                    \"iter_id\": iter_id,\n",
    "                    \"test_set_id\": test_id,\n",
    "                    \"model_id\": model_id,\n",
    "                    \"prompt_id\": prompt_id,\n",
    "                    \"n_train\": len(X_sub)\n",
    "                })\n",
    "\n",
    "                results.append(metrics)\n",
    "\n",
    "    return pd.DataFrame(results)\n",
    "\n",
    "def run_full_label_propagation_study(\n",
    "    df,\n",
    "    model_ids,\n",
    "    prompt_ids_by_model,\n",
    "    train_fractions,\n",
    "    n_iter=10,\n",
    "    test_fraction=0.2,\n",
    "    n_splits=5,\n",
    "    lambda_reg=1,\n",
    "    random_state=42,\n",
    "    use_cache=False,\n",
    "    cache_dir=\"cache/label_propagation\",\n",
    "    overwrite_cache=False,\n",
    "    logskip=False\n",
    "):\n",
    "    \"\"\"\n",
    "    Run run_label_propagation_experiment for every (model, prompt) pair and\n",
    "    concatenate the results.\n",
    "\n",
    "    With use_cache=True each pair is stored as its own parquet file in\n",
    "    cache_dir and reloaded on later calls unless overwrite_cache is set; a\n",
    "    meta.json with the run parameters is written next to them. The cache\n",
    "    file name only encodes model and prompt, not the other parameters.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pd.DataFrame\n",
    "        Full dataset, forwarded to run_label_propagation_experiment\n",
    "    model_ids : iterable\n",
    "        Model ids to process\n",
    "    prompt_ids_by_model : mapping or sequence\n",
    "        Indexed with each model id to obtain its prompt ids\n",
    "    train_fractions, n_iter, test_fraction, n_splits, lambda_reg, random_state, logskip\n",
    "        Forwarded unchanged to run_label_propagation_experiment\n",
    "    use_cache : bool\n",
    "        Enable the per-pair parquet cache\n",
    "    cache_dir : str\n",
    "        Cache directory (created if missing)\n",
    "    overwrite_cache : bool\n",
    "        Recompute and overwrite pairs that are already cached\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pd.DataFrame\n",
    "        All per-pair results stacked; empty if no pair produced results\n",
    "    \"\"\"\n",
    "\n",
    "    if use_cache:\n",
    "        os.makedirs(cache_dir, exist_ok=True)\n",
    "\n",
    "    frames = []\n",
    "\n",
    "    for mid in tqdm(model_ids, desc=\"Model\"):\n",
    "        for pid in tqdm(prompt_ids_by_model[mid], desc=\"Prompt\"):\n",
    "\n",
    "            # cache_path stays None when caching is off\n",
    "            cache_path = (\n",
    "                os.path.join(cache_dir, f\"model={mid}__prompt={pid}.parquet\")\n",
    "                if use_cache\n",
    "                else None\n",
    "            )\n",
    "\n",
    "            # ---- reuse cached pair ----\n",
    "            if cache_path is not None and not overwrite_cache and os.path.exists(cache_path):\n",
    "                frames.append(pd.read_parquet(cache_path))\n",
    "                continue\n",
    "\n",
    "            # ---- compute ----\n",
    "            pair_df = run_label_propagation_experiment(\n",
    "                df=df,\n",
    "                model_id=mid,\n",
    "                prompt_id=pid,\n",
    "                train_fractions=train_fractions,\n",
    "                n_iter=n_iter,\n",
    "                test_fraction=test_fraction,\n",
    "                n_splits=n_splits,\n",
    "                lambda_reg=lambda_reg,\n",
    "                logskip=logskip,\n",
    "                random_state=random_state,\n",
    "            )\n",
    "\n",
    "            # skipped pairs (None) and empty results are neither kept nor cached\n",
    "            if pair_df is None or len(pair_df) == 0:\n",
    "                continue\n",
    "\n",
    "            frames.append(pair_df)\n",
    "\n",
    "            # ---- save subcache ----\n",
    "            if use_cache:\n",
    "                pair_df.to_parquet(cache_path, index=False)\n",
    "\n",
    "    if not frames:\n",
    "        return pd.DataFrame()\n",
    "\n",
    "    results_lp = pd.concat(frames, ignore_index=True)\n",
    "\n",
    "    # ---- global metadata (optional) ----\n",
    "    if use_cache:\n",
    "        meta = {\n",
    "            \"model_ids\": list(model_ids),\n",
    "            \"train_fractions\": train_fractions,\n",
    "            \"n_iter\": n_iter,\n",
    "            \"test_fraction\": test_fraction,\n",
    "            \"n_splits\": n_splits,\n",
    "            \"random_state\": random_state,\n",
    "            \"n_cached_pairs\": len(frames),\n",
    "            \"lambda_reg\": lambda_reg,\n",
    "        }\n",
    "        meta_path = os.path.join(cache_dir, \"meta.json\")\n",
    "\n",
    "        with open(meta_path, \"w\") as f:\n",
    "            json.dump(meta, f, indent=2)\n",
    "\n",
    "    return results_lp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dd570d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 0.05, 0.10, ..., 1.00\n",
    "train_fractions = [x/100 for x in range(5, 101, 5)]\n",
    "\n",
    "results_lp = run_full_label_propagation_study(\n",
    "    df=df,\n",
    "    model_ids=model_names.keys(),\n",
    "    # keyed by model id: a positional list only lines up when the ids are exactly 0..n-1\n",
    "    prompt_ids_by_model={\n",
    "        mid: list(df.loc[df['model_id'] == mid, 'prompt_id'].unique())\n",
    "        for mid in model_names\n",
    "    },\n",
    "    train_fractions=train_fractions,\n",
    "    n_iter=20,\n",
    "    lambda_reg=best_reg_lambda,\n",
    "    use_cache=True,\n",
    "    cache_dir=f\"{cache_dir}/label-propagation-{best_reg_lambda}\",\n",
    "    logskip=True\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d46656b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# def aggregate_metric_over_prompts(\n",
    "#     df,\n",
    "#     metric=\"f1\",\n",
    "#     agg_fn=\"mean\",\n",
    "# ):\n",
    "#     \"\"\"\n",
    "#     Aggregate metric across prompts, test splits and iterations.\n",
    "\n",
    "#     Returns\n",
    "#     -------\n",
    "#     pd.DataFrame with columns:\n",
    "#     [model_id, train_fraction, metric_mean, metric_std]\n",
    "#     \"\"\"\n",
    "\n",
    "#     grouped = df.groupby([\"model_id\", \"train_fraction\"])[metric]\n",
    "\n",
    "#     out = grouped.agg([\"mean\", \"std\"]).reset_index()\n",
    "#     out = out.rename(columns={\n",
    "#         \"mean\": f\"{metric}_mean\",\n",
    "#         \"std\": f\"{metric}_std\",\n",
    "#     })\n",
    "\n",
    "#     return out\n",
    "\n",
    "def aggregate_metric_over_prompts(\n",
    "    df,\n",
    "    metric=\"f1\",\n",
    "    score_metric='accuracy',\n",
    "    agg_prompts=True,\n",
    "    agg_train_frac=False,\n",
    "    agg_models=False\n",
    "):\n",
    "    \"\"\"\n",
    "    Aggregate two metrics of the label propagation runs.\n",
    "\n",
    "    Rows are grouped by every key whose ``agg_*`` flag is False; all other\n",
    "    rows (test splits, subsamples and the flagged keys) are pooled. With the\n",
    "    default flags the result has one row per (model_id, train_fraction).\n",
    "    If all three flags are True a single global row is returned.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pd.DataFrame\n",
    "        Output of label propagation experiments\n",
    "    metric : str\n",
    "        Column reported as metric_mean / metric_std (e.g. 'f1')\n",
    "    score_metric : str\n",
    "        Column reported as score_mean / score_std (e.g. 'accuracy')\n",
    "    agg_prompts : bool\n",
    "        If True, pool over prompt_id\n",
    "    agg_train_frac : bool\n",
    "        If True, pool over train_fraction\n",
    "    agg_models : bool\n",
    "        If True, pool over model_id\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    agg_df : pd.DataFrame\n",
    "        Columns:\n",
    "        - the grouping keys that were not pooled\n",
    "          (model_id, prompt_id, train_fraction, in that order)\n",
    "        - metric_mean\n",
    "        - metric_std\n",
    "        - score_mean\n",
    "        - score_std\n",
    "        - mean_n_train\n",
    "        - std_n_train\n",
    "        - n_runs\n",
    "    \"\"\"\n",
    "\n",
    "    group_cols = [\n",
    "        col\n",
    "        for col, pooled in (\n",
    "            (\"model_id\", agg_models),\n",
    "            (\"prompt_id\", agg_prompts),\n",
    "            (\"train_fraction\", agg_train_frac),\n",
    "        )\n",
    "        if not pooled\n",
    "    ]\n",
    "\n",
    "    # groupby([]) raises, so pool everything through one constant key instead\n",
    "    keys = group_cols if group_cols else np.zeros(len(df), dtype=int)\n",
    "\n",
    "    agg_df = (\n",
    "        df\n",
    "        .groupby(keys)\n",
    "        .agg(\n",
    "            metric_mean=(metric, \"mean\"),\n",
    "            metric_std=(metric, \"std\"),\n",
    "            score_mean=(score_metric, \"mean\"),\n",
    "            score_std=(score_metric, \"std\"),\n",
    "            mean_n_train=(\"n_train\", \"mean\"),\n",
    "            std_n_train=(\"n_train\", \"std\"),\n",
    "            n_runs=(metric, \"count\"),\n",
    "        )\n",
    "        # the constant key of the global case carries no information\n",
    "        .reset_index(drop=not group_cols)\n",
    "    )\n",
    "\n",
    "    return agg_df\n",
    "\n",
    "def plot_learning_curves(\n",
    "    agg_df,\n",
    "    model_names,\n",
    "    metric=\"f1\",\n",
    "    ratio=(3, 2),\n",
    "    scale=3\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot metric_mean against mean_n_train, one curve per model, with a\n",
    "    +/- metric_std band around each curve.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    agg_df : pd.DataFrame\n",
    "        Needs columns model_id, mean_n_train, metric_mean and metric_std\n",
    "    model_names : dict\n",
    "        Maps model_id to legend label; models without rows are skipped\n",
    "    metric : str\n",
    "        Only used (upper-cased) for the y label and the title\n",
    "    ratio, scale : tuple, float\n",
    "        Figure size is scale * ratio\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=[scale * side for side in ratio])\n",
    "\n",
    "    for mid in model_names:\n",
    "        curve = agg_df[agg_df[\"model_id\"] == mid]\n",
    "        if curve.empty:\n",
    "            continue\n",
    "\n",
    "        sizes = curve[\"mean_n_train\"]\n",
    "        centre = curve[\"metric_mean\"]\n",
    "        spread = curve[\"metric_std\"]\n",
    "\n",
    "        ax.plot(sizes, centre, marker=\"o\", label=model_names[mid])\n",
    "        ax.fill_between(sizes, centre - spread, centre + spread, alpha=0.2)\n",
    "\n",
    "    metric_label = metric.upper()\n",
    "\n",
    "    ax.set_xlabel(\"Training set size (mean)\")\n",
    "    ax.set_ylabel(metric_label)\n",
    "    ax.grid(True, alpha=0.4)\n",
    "    ax.legend()\n",
    "    ax.set_title(f\"{metric_label} vs training size\")\n",
    "\n",
    "    return fig, ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8358648e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# Aggregate F1 with the default grouping flags, i.e. one row per\n",
    "# (model_id, train_fraction), then draw one learning curve per model.\n",
    "agg_f1 = aggregate_metric_over_prompts(results_lp, metric=\"f1\")\n",
    "\n",
    "fig, ax = plot_learning_curves(\n",
    "    agg_f1,\n",
    "    model_names=model_names,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2369a3ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick look at how n_train varies within each train_fraction group (one grey line per group).\n",
    "results_lp.groupby('train_fraction')['n_train'].plot(color='k', alpha=.4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd704284",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_prompt_level_boxplot_df(df, metric=\"f1\"):\n",
    "    \"\"\"\n",
    "    Average the runs of every (model_id, prompt_id, train_fraction) triple.\n",
    "\n",
    "    Returns a DataFrame with those three key columns plus metric_mean\n",
    "    (mean of ``metric``) and mean_n_train (mean of n_train).\n",
    "    \"\"\"\n",
    "    keys = [\"model_id\", \"prompt_id\", \"train_fraction\"]\n",
    "    per_prompt = df.groupby(keys).agg(\n",
    "        metric_mean=(metric, \"mean\"),\n",
    "        mean_n_train=(\"n_train\", \"mean\"),\n",
    "    )\n",
    "    return per_prompt.reset_index()\n",
    "\n",
    "def plot_metric_boxplots(\n",
    "    df_prompt,\n",
    "    model_names,\n",
    "    train_fraction,\n",
    "    metric=\"f1\",\n",
    "    ratio=(4, 3),\n",
    "    scale=2,\n",
    "    cmap_name=\"tab10\"\n",
    "):\n",
    "    \"\"\"\n",
    "    Box plot of the prompt-level metric, one coloured box per model.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df_prompt : pd.DataFrame\n",
    "        Needs columns model_id, train_fraction and metric_mean\n",
    "        (e.g. the output of prepare_prompt_level_boxplot_df)\n",
    "    model_names : dict\n",
    "        Maps model_id to tick label; iteration order sets the box order\n",
    "    train_fraction : float\n",
    "        Only rows with exactly this train_fraction are plotted\n",
    "    metric : str\n",
    "        Only used (upper-cased) for the y label and the title\n",
    "    ratio, scale : tuple, float\n",
    "        Figure size is scale * ratio\n",
    "    cmap_name : str\n",
    "        Matplotlib colormap providing one colour per model\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=[scale * x for x in ratio])\n",
    "\n",
    "    cmap = plt.get_cmap(cmap_name)\n",
    "\n",
    "    data = []\n",
    "    labels = []\n",
    "    colors = []\n",
    "\n",
    "    for i, (mid, name) in enumerate(model_names.items()):\n",
    "        vals = df_prompt[\n",
    "            (df_prompt[\"model_id\"] == mid) &\n",
    "            (df_prompt[\"train_fraction\"] == train_fraction)\n",
    "        ][\"metric_mean\"].values\n",
    "\n",
    "        # models without data get no box, but keep their colour slot\n",
    "        if len(vals) > 0:\n",
    "            data.append(vals)\n",
    "            labels.append(name)\n",
    "            colors.append(cmap(i % cmap.N))\n",
    "\n",
    "    bp = ax.boxplot(\n",
    "        data,\n",
    "        patch_artist=True,\n",
    "        showfliers=True,\n",
    "        medianprops=dict(color=\"black\", linewidth=1.5),\n",
    "        boxprops=dict(linewidth=1.2),\n",
    "        whiskerprops=dict(linewidth=1.0),\n",
    "        capprops=dict(linewidth=1.0),\n",
    "    )\n",
    "\n",
    "    # ---- apply colours\n",
    "    for box, c in zip(bp[\"boxes\"], colors):\n",
    "        box.set_facecolor(c)\n",
    "        box.set_alpha(0.6)\n",
    "\n",
    "    title = (\n",
    "        f\"{metric.upper()} distribution across prompts \"\n",
    "        f\"(train_fraction={train_fraction})\"\n",
    "    )\n",
    "    # with text.usetex enabled a bare underscore aborts the LaTeX run\n",
    "    if plt.rcParams[\"text.usetex\"]:\n",
    "        title = title.replace(\"_\", r\"\\_\")\n",
    "\n",
    "    ax.set_xticklabels(labels, rotation=30, ha=\"right\")\n",
    "    ax.set_ylabel(metric.upper())\n",
    "    ax.set_title(title)\n",
    "\n",
    "    ax.grid(True, axis=\"y\", alpha=0.4)\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# Prompt-level F1 means (default metric), shown for the largest training fraction.\n",
    "agg_box = prepare_prompt_level_boxplot_df(results_lp)\n",
    "\n",
    "plot_metric_boxplots(agg_box, model_names, train_fractions[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01c5b30a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_metric_boxplots_two_panels(\n",
    "    results_lp,\n",
    "    model_names,\n",
    "    model_order,\n",
    "    train_fraction,\n",
    "    metrics=(\"accuracy\", \"f1\"),\n",
    "    ratio=(6, 3),\n",
    "    scale=2,\n",
    "    width_ratios=[1, 1],\n",
    "    xlims=None,\n",
    "    cmap_name=\"tab10\"\n",
    "):\n",
    "    \"\"\"\n",
    "    Side-by-side horizontal violin + box plots of two prompt-level metrics.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    results_lp : pd.DataFrame\n",
    "        Raw runs; needs model_id, prompt_id, train_fraction, n_train and\n",
    "        one column per entry of metrics\n",
    "    model_names : dict\n",
    "        Maps model_id to tick label\n",
    "    model_order : sequence\n",
    "        Model ids, drawn from top to bottom\n",
    "    train_fraction : float\n",
    "        Only rows with exactly this train_fraction are plotted\n",
    "    metrics : sequence of two str\n",
    "        Columns shown in the left and right panel\n",
    "    ratio, scale : tuple, float\n",
    "        Figure size is scale * ratio\n",
    "    width_ratios : sequence of two float\n",
    "        Relative panel widths\n",
    "    xlims : sequence or None\n",
    "        Per-panel x limits; None (or a None entry) keeps the automatic limits\n",
    "    cmap_name : str\n",
    "        Matplotlib colormap providing one colour per model\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, axes\n",
    "    \"\"\"\n",
    "    fig, axes = plt.subplots(\n",
    "        1, 2,\n",
    "        figsize=[scale * x for x in ratio],\n",
    "        sharey=True,\n",
    "        width_ratios=width_ratios\n",
    "    )\n",
    "\n",
    "    cmap = plt.get_cmap(cmap_name)\n",
    "\n",
    "    # one <metric>_mean column per requested metric, not just f1 / accuracy\n",
    "    agg_spec = {f\"{m}_mean\": (m, \"mean\") for m in metrics}\n",
    "    agg_spec[\"mean_n_train\"] = (\"n_train\", \"mean\")\n",
    "\n",
    "    df_prompt = (\n",
    "        results_lp\n",
    "        .groupby([\"model_id\", \"prompt_id\", \"train_fraction\"])\n",
    "        .agg(**agg_spec)\n",
    "        .reset_index()\n",
    "    )\n",
    "\n",
    "    for it, metric in enumerate(metrics):\n",
    "        ax = axes[it]\n",
    "        data = []\n",
    "        labels = []\n",
    "        colors = []\n",
    "\n",
    "        # reversed: position 1 is the bottom row, so the first model ends up on top\n",
    "        for i, mid in enumerate(reversed(model_order)):\n",
    "            name = model_names[mid]\n",
    "\n",
    "            vals = df_prompt[\n",
    "                (df_prompt[\"model_id\"] == mid) &\n",
    "                (df_prompt[\"train_fraction\"] == train_fraction)\n",
    "            ][f\"{metric}_mean\"].values\n",
    "\n",
    "            if len(vals) > 0:\n",
    "                data.append(vals)\n",
    "                labels.append(name)\n",
    "                colors.append(cmap(i % cmap.N))\n",
    "\n",
    "        positions = np.arange(1, len(data) + 1)\n",
    "\n",
    "        # ---- violins (horizontal)\n",
    "        vp = ax.violinplot(\n",
    "            data,\n",
    "            widths=.75,\n",
    "            positions=positions,\n",
    "            vert=False,\n",
    "            showmeans=False,\n",
    "            showextrema=False,\n",
    "        )\n",
    "\n",
    "        for body, c in zip(vp[\"bodies\"], colors):\n",
    "            body.set_facecolor(c)\n",
    "            body.set_alpha(0.35)\n",
    "            body.set_edgecolor(\"none\")\n",
    "\n",
    "        # ---- boxes (horizontal)\n",
    "        bp = ax.boxplot(\n",
    "            data,\n",
    "            positions=positions,\n",
    "            vert=False,\n",
    "            widths=0.25,\n",
    "            showfliers=False,\n",
    "            patch_artist=True,\n",
    "            medianprops=dict(color=\"black\", linewidth=1.5),\n",
    "        )\n",
    "\n",
    "        for box, c in zip(bp[\"boxes\"], colors):\n",
    "            box.set_facecolor(c)\n",
    "            box.set_alpha(0.6)\n",
    "\n",
    "        ax.set_yticks(positions)\n",
    "        ax.set_yticklabels(labels)\n",
    "        ax.set_title(metric.capitalize())\n",
    "        ax.grid(True, axis=\"x\", alpha=0.4)\n",
    "\n",
    "        # the default xlims=None used to crash on xlims[it]\n",
    "        if xlims is not None and xlims[it] is not None:\n",
    "            ax.set_xlim(xlims[it])\n",
    "\n",
    "    fig.tight_layout()\n",
    "    return fig, axes\n",
    "\n",
    "# Accuracy / F1 distributions at the largest training fraction, saved for the paper.\n",
    "fig, _ = plot_metric_boxplots_two_panels(\n",
    "    results_lp,\n",
    "    model_names=model_names,\n",
    "    model_order=model_order,\n",
    "    train_fraction=train_fractions[-1],\n",
    "    width_ratios=[5, 9],\n",
    "    xlims=[(0.5, 1.0), (.1, 1.0)],\n",
    "    ratio=(3,1),\n",
    "    scale=3\n",
    ")\n",
    "\n",
    "# savefig does not create missing directories\n",
    "os.makedirs(\"img\", exist_ok=True)\n",
    "fig.savefig(\"img/LP_statistics.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c5a78fb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd094226",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_heatmap_df(agg_df):\n",
    "    \"\"\"\n",
    "    Reshape an aggregated frame into a model_id x train_fraction grid\n",
    "    whose cells hold metric_mean.\n",
    "    \"\"\"\n",
    "    layout = {\n",
    "        \"index\": \"model_id\",\n",
    "        \"columns\": \"train_fraction\",\n",
    "        \"values\": \"metric_mean\",\n",
    "    }\n",
    "    return agg_df.pivot(**layout)\n",
    "\n",
    "def plot_metric_heatmap(\n",
    "    heat_df,\n",
    "    model_names,\n",
    "    metric=\"f1\",\n",
    "    ratio=(5, 3),\n",
    "    scale=2\n",
    "):\n",
    "    \"\"\"\n",
    "    Heatmap of a model x train_fraction grid.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    heat_df : pd.DataFrame\n",
    "        Index holds model ids, columns hold training fractions (as floats\n",
    "        in [0, 1]), e.g. the output of prepare_heatmap_df\n",
    "    model_names : dict\n",
    "        Maps model_id to y tick label\n",
    "    metric : str\n",
    "        Only used (upper-cased) for the title and colourbar label\n",
    "    ratio, scale : tuple, float\n",
    "        Figure size is scale * ratio\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=[scale * x for x in ratio])\n",
    "\n",
    "    im = ax.imshow(\n",
    "        heat_df.values,\n",
    "        aspect=\"auto\",\n",
    "        origin=\"lower\"\n",
    "    )\n",
    "\n",
    "    ax.set_yticks(range(len(heat_df.index)))\n",
    "    ax.set_yticklabels([model_names[mid] for mid in heat_df.index])\n",
    "\n",
    "    # a bare % starts a TeX comment, so it must be escaped when usetex is on\n",
    "    pct = r\"\\%\" if plt.rcParams[\"text.usetex\"] else \"%\"\n",
    "\n",
    "    ax.set_xticks(range(len(heat_df.columns)))\n",
    "    ax.set_xticklabels(\n",
    "        # round, not truncate: 100 * 0.57 is 56.99999999999999 in floating point\n",
    "        [f\"{int(round(100 * float(c)))}{pct}\" for c in heat_df.columns]\n",
    "    )\n",
    "\n",
    "    ax.set_xlabel(\"Training fraction\")\n",
    "    ax.set_title(metric.upper())\n",
    "\n",
    "    fig.colorbar(im, ax=ax, label=metric.upper())\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# agg_f1 (built in an earlier cell) reshaped to models x training fractions.\n",
    "df_heat = prepare_heatmap_df(agg_f1)\n",
    "plot_metric_heatmap(df_heat, model_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "484fb394",
   "metadata": {},
   "source": [
    "## How prompts behave within a model in terms of accuracy, ordered by F1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2f99ddf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_prompt_heatmap_per_model(\n",
    "    df_prompt,\n",
    "    model_id,\n",
    "    model_name,\n",
    "    sort_by=\"metric_mean\",\n",
    "    value_col=\"score_mean\",\n",
    "    metric=\"f1\",\n",
    "    score='accuracy',\n",
    "    ratio=(3, 5),\n",
    "    scale=2\n",
    "):\n",
    "    \"\"\"\n",
    "    Heatmap of prompts for one model, ordered by sort_by.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df_prompt : pd.DataFrame\n",
    "        Needs columns model_id, prompt_id, sort_by and value_col\n",
    "    model_id : int\n",
    "        Model whose rows are shown\n",
    "    model_name : str\n",
    "        First line of the title\n",
    "    sort_by : str\n",
    "        Column used to order the prompts (ascending, bottom to top)\n",
    "    value_col : str\n",
    "        Column shown as colour\n",
    "    metric : str\n",
    "        Only used in the title (\\\"Prompts ordered by ...\\\")\n",
    "    score : str\n",
    "        Currently unused\n",
    "    ratio, scale : tuple, float\n",
    "        Figure size is scale * ratio\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "    \"\"\"\n",
    "\n",
    "    def _tex_safe(text):\n",
    "        # with text.usetex enabled a bare underscore aborts the LaTeX run\n",
    "        text = str(text)\n",
    "        return text.replace(\"_\", r\"\\_\") if plt.rcParams[\"text.usetex\"] else text\n",
    "\n",
    "    df_m = (\n",
    "        df_prompt[df_prompt[\"model_id\"] == model_id]\n",
    "        .sort_values(sort_by)\n",
    "    )\n",
    "\n",
    "    values = df_m[value_col].values[:, None]  # (n_prompts, 1)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=[scale * x for x in ratio])\n",
    "\n",
    "    im = ax.imshow(\n",
    "        values,\n",
    "        aspect=\"auto\",\n",
    "        origin=\"lower\"\n",
    "    )\n",
    "\n",
    "    value_label = _tex_safe(value_col)\n",
    "\n",
    "    ax.set_yticks(range(len(df_m)))\n",
    "    ax.set_yticklabels([_tex_safe(pid) for pid in df_m[\"prompt_id\"]])\n",
    "    ax.set_xticks([0])\n",
    "    ax.set_xticklabels([value_label])\n",
    "\n",
    "    ax.set_title(f\"{model_name}\\nPrompts ordered by {_tex_safe(metric)}\")\n",
    "\n",
    "    fig.colorbar(im, ax=ax, label=value_label)\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# Keep only the runs with train_fraction == 1, then aggregate per (model_id, prompt_id).\n",
    "results_lp_max = results_lp[results_lp[\"train_fraction\"]==1]\n",
    "agg_per_prompt = aggregate_metric_over_prompts(results_lp_max, metric=\"f1\", agg_prompts=False, agg_models=False, agg_train_frac=True)\n",
    "# Model 0 only: prompts sorted by mean F1 (metric_mean), coloured by mean accuracy (score_mean).\n",
    "plot_prompt_heatmap_per_model(agg_per_prompt, 0, model_names[0], metric=\"f1\", score='accuracy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "237042c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_prompt_heatmap_per_model(\n",
    "    agg_df,\n",
    "    model_id,\n",
    "    model_name,\n",
    "    ratio=(5, 4),\n",
    "    scale=2\n",
    "):\n",
    "    \"\"\"\n",
    "    Prompt x train_fraction heatmap of metric_mean for one model.\n",
    "\n",
    "    The x tick labels are train_fraction multiplied by the largest\n",
    "    mean_n_train found for this model, truncated to an integer.\n",
    "    NOTE: this redefines the function of the same name from the previous cell.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    agg_df : pd.DataFrame\n",
    "        Needs columns model_id, prompt_id, train_fraction, metric_mean\n",
    "        and mean_n_train\n",
    "    model_id : int\n",
    "        Model whose rows are shown\n",
    "    model_name : str\n",
    "        First line of the title\n",
    "    ratio, scale : tuple, float\n",
    "        Figure size is scale * ratio\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "    \"\"\"\n",
    "    rows = agg_df[agg_df[\"model_id\"] == model_id]\n",
    "\n",
    "    largest_n = rows['mean_n_train'].max()\n",
    "\n",
    "    grid = rows.pivot(index=\"prompt_id\", columns=\"train_fraction\", values=\"metric_mean\")\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=[scale * side for side in ratio])\n",
    "\n",
    "    im = ax.imshow(grid.values, aspect=\"auto\", origin=\"lower\")\n",
    "\n",
    "    ax.set_yticks(range(len(grid.index)))\n",
    "    ax.set_yticklabels(grid.index)\n",
    "\n",
    "    size_labels = []\n",
    "    for frac in grid.columns:\n",
    "        size_labels.append(f\"{int(frac*largest_n)}\")\n",
    "\n",
    "    ax.set_xticks(range(len(grid.columns)))\n",
    "    ax.set_xticklabels(size_labels)\n",
    "\n",
    "    ax.set_xlabel(\"Training set size\")\n",
    "    ax.set_title(f\"{model_name}\\nPrompt-level F1 vs training size\")\n",
    "\n",
    "    fig.colorbar(im, ax=ax, label=\"Mean F1\")\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# No pooling over model, prompt or training fraction: one row per\n",
    "# (model_id, prompt_id, train_fraction), still averaged over splits and subsamples.\n",
    "agg = aggregate_metric_over_prompts(\n",
    "    results_lp,\n",
    "    metric=\"f1\",\n",
    "    agg_prompts=False,\n",
    "    agg_models=False,\n",
    "    agg_train_frac=False\n",
    ")\n",
    "\n",
    "plot_prompt_heatmap_per_model(agg, 0, model_names[0])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14a7a6f3",
   "metadata": {},
   "source": [
    "## Which prompts are intrinsically hard, regardless of model?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d1ab800",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_prompt_difficulty(df, metric=\"f1\"):\n",
    "    \"\"\"\n",
    "    Mean and standard deviation of ``metric`` per prompt_id, pooled over\n",
    "    all other columns and sorted from lowest to highest mean.\n",
    "\n",
    "    The output columns are always named prompt_id, mean_f1 and std_f1,\n",
    "    whatever ``metric`` is.\n",
    "    \"\"\"\n",
    "    per_prompt = df.groupby(\"prompt_id\").agg(\n",
    "        mean_f1=(metric, \"mean\"),\n",
    "        std_f1=(metric, \"std\"),\n",
    "    )\n",
    "    ranked = per_prompt.sort_values(\"mean_f1\")\n",
    "    return ranked.reset_index()\n",
    "\n",
    "def plot_prompt_difficulty(df_prompt, ratio=(5, 3), scale=2):\n",
    "    \"\"\"\n",
    "    Horizontal bar chart of mean_f1 (with std_f1 error bars), one bar per\n",
    "    prompt, drawn in the row order of df_prompt from bottom to top.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df_prompt : pd.DataFrame\n",
    "        Needs columns prompt_id, mean_f1 and std_f1, already sorted the\n",
    "        way the ranking should appear (e.g. by compute_prompt_difficulty)\n",
    "    ratio, scale : tuple, float\n",
    "        Figure size is scale * ratio\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=[scale * x for x in ratio])\n",
    "\n",
    "    # bars go at 0..n-1: numeric prompt ids used as coordinates would place\n",
    "    # each bar at its id and discard the sorted ranking\n",
    "    positions = np.arange(len(df_prompt))\n",
    "\n",
    "    ax.barh(\n",
    "        positions,\n",
    "        df_prompt[\"mean_f1\"].to_numpy(),\n",
    "        xerr=df_prompt[\"std_f1\"].to_numpy()\n",
    "    )\n",
    "\n",
    "    ax.set_yticks(positions)\n",
    "    ax.set_yticklabels([str(pid) for pid in df_prompt[\"prompt_id\"]])\n",
    "\n",
    "    ax.set_xlabel(\"Mean F1 across models\")\n",
    "    ax.set_title(\"Prompt difficulty ranking\")\n",
    "    ax.grid(True, axis=\"x\", alpha=0.4)\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# Rank prompts by mean F1 over all runs in results_lp_max (train_fraction == 1, set in an earlier cell).\n",
    "df_hard = compute_prompt_difficulty(results_lp_max)\n",
    "plot_prompt_difficulty(df_hard)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5169ede0",
   "metadata": {},
   "source": [
    "## Are models failing on the same prompts?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80c5deeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_model_agreement(df_prompt, metric=\"metric_mean\"):\n",
    "    \"\"\"\n",
    "    Correlation matrix between models, computed across prompts.\n",
    "\n",
    "    df_prompt is reshaped to prompt_id x model_id with ``metric`` as cell\n",
    "    value (one row per (prompt_id, model_id) is required), then the\n",
    "    columns are correlated with DataFrame.corr().\n",
    "    \"\"\"\n",
    "    per_model = df_prompt.pivot(index=\"prompt_id\", columns=\"model_id\", values=metric)\n",
    "    agreement = per_model.corr()\n",
    "    return agreement\n",
    "\n",
    "def plot_model_agreement_heatmap(\n",
    "    corr_df,\n",
    "    model_names,\n",
    "    ratio=(4, 3),\n",
    "    scale=2\n",
    "):\n",
    "    \"\"\"\n",
    "    Heatmap of a model x model correlation matrix on a fixed [-1, 1] scale.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    corr_df : pd.DataFrame\n",
    "        Square matrix with model ids as index and columns\n",
    "    model_names : dict\n",
    "        Maps model_id to tick label\n",
    "    ratio, scale : tuple, float\n",
    "        Figure size is scale * ratio\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=[scale * side for side in ratio])\n",
    "\n",
    "    im = ax.imshow(corr_df.values, vmin=-1, vmax=1)\n",
    "\n",
    "    ticks = range(len(corr_df))\n",
    "    col_labels = [model_names[m] for m in corr_df.columns]\n",
    "    row_labels = [model_names[m] for m in corr_df.index]\n",
    "\n",
    "    ax.set_xticks(ticks)\n",
    "    ax.set_yticks(ticks)\n",
    "\n",
    "    ax.set_xticklabels(col_labels, rotation=30)\n",
    "    ax.set_yticklabels(row_labels)\n",
    "\n",
    "    fig.colorbar(im, ax=ax, label=\"Prompt-level F1 correlation\")\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# Correlate models on prompt-level metric_mean (F1 here, as agg_per_prompt was built with metric=\"f1\").\n",
    "df_agree = compute_model_agreement(agg_per_prompt)\n",
    "plot_model_agreement_heatmap(df_agree, model_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfa54736",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_model_agreement(df_prompt, value_col):\n",
    "    \"\"\"\n",
    "    Correlation matrix between models, computed across prompts, for the\n",
    "    aggregated column ``value_col``.\n",
    "\n",
    "    df_prompt must hold one row per (prompt_id, model_id).\n",
    "    NOTE: this redefines the function of the same name from the previous\n",
    "    cell; here value_col has no default.\n",
    "    \"\"\"\n",
    "    per_model = df_prompt.pivot(index=\"prompt_id\", columns=\"model_id\", values=value_col)\n",
    "    agreement = per_model.corr()\n",
    "    return agreement\n",
    "\n",
    "# Two agreement matrices from the same per-prompt table: one on metric_mean, one on score_mean.\n",
    "corr_metric = compute_model_agreement(agg_per_prompt, \"metric_mean\")\n",
    "corr_score  = compute_model_agreement(agg_per_prompt, \"score_mean\")\n",
    "\n",
    "def plot_model_agreement_bivariate_heatmap(\n",
    "    corr_metric,\n",
    "    corr_score,\n",
    "    model_names,\n",
    "    ratio=(4, 4),\n",
    "    scale=2,\n",
    "    cmap_metric=\"coolwarm\",\n",
    "    cmap_score=\"PiYG\"\n",
    "):\n",
    "    \"\"\"\n",
    "    Draw two correlation matrices in one square grid: corr_metric above\n",
    "    the diagonal, corr_score below it. The diagonal is left empty.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    corr_metric, corr_score : pd.DataFrame\n",
    "        Square model x model matrices, read by position; the index of\n",
    "        corr_metric provides the model ids for the tick labels\n",
    "    model_names : dict\n",
    "        Maps model_id to tick label\n",
    "    ratio, scale : tuple, float\n",
    "        Figure size is scale * ratio\n",
    "    cmap_metric, cmap_score : str\n",
    "        Colormaps of the upper and lower triangle (both on [-1, 1])\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "    \"\"\"\n",
    "    models = corr_metric.index.tolist()\n",
    "    n = len(models)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=[scale * side for side in ratio])\n",
    "\n",
    "    # every off-diagonal cell is its own 1x1 image placed at (col, row)\n",
    "    for row in range(n):\n",
    "        for col in range(n):\n",
    "            if row == col:\n",
    "                continue\n",
    "\n",
    "            in_upper = row < col\n",
    "            source = corr_metric if in_upper else corr_score\n",
    "            cell_cmap = plt.get_cmap(cmap_metric if in_upper else cmap_score)\n",
    "\n",
    "            ax.imshow(\n",
    "                [[source.iloc[row, col]]],\n",
    "                extent=(col, col+1, row+1, row),\n",
    "                vmin=-1,\n",
    "                vmax=1,\n",
    "                cmap=cell_cmap\n",
    "            )\n",
    "\n",
    "    centres = np.arange(n) + 0.5\n",
    "    tick_names = [model_names[m] for m in models]\n",
    "\n",
    "    ax.set_xticks(centres)\n",
    "    ax.set_yticks(centres)\n",
    "\n",
    "    ax.set_xticklabels(tick_names, rotation=30)\n",
    "    ax.set_yticklabels(tick_names)\n",
    "\n",
    "    # imshow resets the limits on every call; restore the full grid, row 0 on top\n",
    "    ax.set_xlim(0, n)\n",
    "    ax.set_ylim(n, 0)\n",
    "\n",
    "    ax.set_title(\"Model agreement on prompts\\nUpper: F1, Lower: Accuracy\")\n",
    "\n",
    "    # colourbars\n",
    "    for bar_cmap, bar_pad, bar_label in (\n",
    "        (cmap_metric, 0.04, \"F1 correlation\"),\n",
    "        (cmap_score, 0.12, \"Accuracy correlation\"),\n",
    "    ):\n",
    "        mappable = plt.cm.ScalarMappable(cmap=bar_cmap, norm=plt.Normalize(-1, 1))\n",
    "        mappable.set_array([])\n",
    "        cbar = fig.colorbar(mappable, ax=ax, fraction=0.046, pad=bar_pad)\n",
    "        cbar.set_label(bar_label)\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# Upper triangle: corr_metric, lower triangle: corr_score.\n",
    "plot_model_agreement_bivariate_heatmap(\n",
    "    corr_metric,\n",
    "    corr_score,\n",
    "    model_names\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fd0e466",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_pair_curve(df, model_i, model_j, value_col, min_points=2):\n",
    "    \"\"\"\n",
    "    Average learning curve of two models over their shared training fractions.\n",
    "\n",
    "    For each model, value_col is averaged per train_fraction; the two\n",
    "    curves are then restricted to the fractions both models have and\n",
    "    averaged point by point.\n",
    "\n",
    "    Returns (xs, ys): the shared train_fraction values and the pairwise\n",
    "    mean. If fewer than min_points fractions are shared, returns (None, None).\n",
    "    \"\"\"\n",
    "    def _mean_curve(model_id):\n",
    "        rows = df[df[\"model_id\"] == model_id]\n",
    "        return rows.groupby(\"train_fraction\")[value_col].mean()\n",
    "\n",
    "    # inner join keeps only the train_fraction values present for both models\n",
    "    curve_i, curve_j = _mean_curve(model_i).align(_mean_curve(model_j), join=\"inner\")\n",
    "\n",
    "    if len(curve_i) < min_points:\n",
    "        return None, None\n",
    "\n",
    "    pair_mean = 0.5 * (curve_i.values + curve_j.values)\n",
    "\n",
    "    return curve_i.index.values, pair_mean\n",
    "\n",
    "def plot_model_agreement_bivariate_heatmap_with_internals(\n",
    "    corr_metric,\n",
    "    corr_score,\n",
    "    agg_pf,\n",
    "    model_names,\n",
    "    ratio=(4, 4),\n",
    "    scale=2,\n",
    "    cmap_metric=\"coolwarm\",\n",
    "    cmap_score=\"PiYG\"\n",
    "):\n",
    "    \"\"\"\n",
    "    Two-triangle correlation grid (corr_metric above the diagonal,\n",
    "    corr_score below) with a small learning curve drawn inside each cell.\n",
    "\n",
    "    The curve of cell (i, j) comes from extract_pair_curve for the two\n",
    "    models, using metric_mean in the upper triangle and score_mean in the\n",
    "    lower one. Both axes of the curve are rescaled to [0, 1] so that it fills\n",
    "    the cell; only its shape is meaningful. Cells whose model pair has no\n",
    "    usable curve show the colour only.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    corr_metric, corr_score : pd.DataFrame\n",
    "        Square model x model matrices, read by position; the index of\n",
    "        corr_metric provides the model ids\n",
    "    agg_pf : pd.DataFrame\n",
    "        Needs columns model_id, train_fraction, metric_mean and score_mean\n",
    "    model_names : dict\n",
    "        Maps model_id to tick label\n",
    "    ratio, scale : tuple, float\n",
    "        Figure size is scale * ratio\n",
    "    cmap_metric, cmap_score : str\n",
    "        Colormaps of the upper and lower triangle (both on [-1, 1])\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "    \"\"\"\n",
    "    models = corr_metric.index.tolist()\n",
    "    n = len(models)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=[scale * x for x in ratio])\n",
    "\n",
    "    for i in range(n):\n",
    "        for j in range(n):\n",
    "\n",
    "            if i == j:\n",
    "                continue\n",
    "\n",
    "            if i < j:\n",
    "                val = corr_metric.iloc[i, j]\n",
    "                cmap = plt.get_cmap(cmap_metric)\n",
    "            else:\n",
    "                val = corr_score.iloc[i, j]\n",
    "                cmap = plt.get_cmap(cmap_score)\n",
    "\n",
    "            ax.imshow(\n",
    "                [[val]],\n",
    "                extent=(j, j+1, i+1, i),\n",
    "                vmin=-1,\n",
    "                vmax=1,\n",
    "                cmap=cmap\n",
    "            )\n",
    "            # draw qualitative curve\n",
    "            xs, ys = extract_pair_curve(\n",
    "                agg_pf,\n",
    "                models[i],\n",
    "                models[j],\n",
    "                \"metric_mean\" if i < j else \"score_mean\"\n",
    "            )\n",
    "\n",
    "            # extract_pair_curve returns (None, None) when the two models\n",
    "            # share too few training fractions; keep the coloured cell only\n",
    "            if xs is None or ys is None:\n",
    "                continue\n",
    "\n",
    "            # rescale to the unit square; the epsilon guards a flat curve\n",
    "            xs = (xs - xs.min()) / (xs.max() - xs.min() + 1e-12)\n",
    "            ys = (ys - ys.min()) / (ys.max() - ys.min() + 1e-12)\n",
    "\n",
    "            # the y axis is inverted (row 0 on top), hence i + 1 - ys\n",
    "            ax.plot(\n",
    "                j + xs,\n",
    "                i + 1 - ys,\n",
    "                color=\"black\",\n",
    "                linewidth=1,\n",
    "                alpha=0.7\n",
    "            )\n",
    "\n",
    "    ax.set_xticks(np.arange(n) + 0.5)\n",
    "    ax.set_yticks(np.arange(n) + 0.5)\n",
    "\n",
    "    ax.set_xticklabels([model_names[m] for m in models], rotation=30)\n",
    "    ax.set_yticklabels([model_names[m] for m in models])\n",
    "\n",
    "    ax.set_xlim(0, n)\n",
    "    ax.set_ylim(n, 0)\n",
    "\n",
    "    ax.set_title(\"Model agreement on prompts\\nUpper: F1, Lower: Accuracy\")\n",
    "\n",
    "    # colourbars\n",
    "    sm1 = plt.cm.ScalarMappable(cmap=cmap_metric, norm=plt.Normalize(-1, 1))\n",
    "    sm2 = plt.cm.ScalarMappable(cmap=cmap_score,  norm=plt.Normalize(-1, 1))\n",
    "    sm1.set_array([])\n",
    "    sm2.set_array([])\n",
    "\n",
    "    cbar1 = fig.colorbar(sm1, ax=ax, fraction=0.046, pad=0.04)\n",
    "    cbar1.set_label(\"F1 correlation\")\n",
    "\n",
    "    cbar2 = fig.colorbar(sm2, ax=ax, fraction=0.046, pad=0.12)\n",
    "    cbar2.set_label(\"Accuracy correlation\")\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# Build the F1 table with every agg_* flag switched off, i.e. without\n",
    "# aggregating over prompts, models or train fractions.\n",
    "agg_pf = aggregate_metric_over_prompts(\n",
    "    results_lp,\n",
    "    metric=\"f1\",\n",
    "    agg_prompts=False,\n",
    "    agg_models=False,\n",
    "    agg_train_frac=False\n",
    ")\n",
    "\n",
    "# NOTE(review): corr_metric / corr_score are not defined in this cell --\n",
    "# presumably computed in an earlier cell; confirm before running in isolation.\n",
    "plot_model_agreement_bivariate_heatmap_with_internals(\n",
    "    corr_metric,\n",
    "    corr_score,\n",
    "    agg_pf,\n",
    "    model_names\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40a21193",
   "metadata": {},
   "source": [
    "## Compare mean accuracy and std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64188067",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_mean_vs_variability(df_prompt, model_names, ratio=(4, 3), scale=2):\n",
    "    \"\"\"\n",
    "    Scatter one point per model: the mean of its ``metric_mean`` values (x)\n",
    "    against their standard deviation (y), both taken over that model's rows\n",
    "    of ``df_prompt``.\n",
    "\n",
    "    Returns the ``(fig, ax)`` pair.\n",
    "    \"\"\"\n",
    "    figure_size = [scale * side for side in ratio]\n",
    "    fig, ax = plt.subplots(figsize=figure_size)\n",
    "\n",
    "    for model_id, display_name in model_names.items():\n",
    "        belongs_to_model = df_prompt[\"model_id\"] == model_id\n",
    "        per_prompt = df_prompt[belongs_to_model][\"metric_mean\"]\n",
    "        ax.scatter(per_prompt.mean(), per_prompt.std(), label=display_name)\n",
    "\n",
    "    ax.set_title(\"Performance vs stability\")\n",
    "    ax.set_xlabel(\"Mean F1 across prompts\")\n",
    "    ax.set_ylabel(\"Std F1 across prompts\")\n",
    "    ax.grid(True, alpha=0.4)\n",
    "    ax.legend()\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# NOTE(review): agg_per_prompt is not defined in this cell -- presumably built\n",
    "# in an earlier cell; it must provide 'model_id' and 'metric_mean' columns.\n",
    "plot_mean_vs_variability(agg_per_prompt, model_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7a60fa3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_mean_vs_variability_trajectories(\n",
    "    agg_df,\n",
    "    model_names,\n",
    "    metric_label='metric',\n",
    "    metric_name='F1',\n",
    "    ratio=(4, 3),\n",
    "    scale=2\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot, for each model, the trajectory of (mean, std) of one metric as the\n",
    "    training size grows.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    agg_df : pd.DataFrame\n",
    "        Needs the columns ``model_id``, ``mean_n_train``,\n",
    "        ``<metric_label>_mean`` and ``<metric_label>_std``.\n",
    "    model_names : dict\n",
    "        Model id -> name used for the line label and the annotation.\n",
    "        Models without rows in ``agg_df`` are skipped.\n",
    "    metric_label : str\n",
    "        Column prefix to plot.\n",
    "    metric_name : str\n",
    "        Human-readable metric name used in the axis labels.\n",
    "    ratio, scale\n",
    "        Figure size is ``scale * ratio``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=[scale * x for x in ratio])\n",
    "\n",
    "    mean_col = f\"{metric_label}_mean\"\n",
    "    std_col = f\"{metric_label}_std\"\n",
    "\n",
    "    for mid, name in model_names.items():\n",
    "        df_m = agg_df[agg_df[\"model_id\"] == mid].sort_values(\"mean_n_train\")\n",
    "\n",
    "        # A model without rows has no last point to annotate: iloc[-1] below\n",
    "        # would raise IndexError, so skip it (as the two-panel variant does)\n",
    "        if df_m.empty:\n",
    "            continue\n",
    "\n",
    "        ax.plot(\n",
    "            df_m[mean_col],\n",
    "            df_m[std_col],\n",
    "            marker=\"o\",\n",
    "            label=name\n",
    "        )\n",
    "\n",
    "        # Annotate the point with the largest training size\n",
    "        ax.annotate(\n",
    "            name,\n",
    "            (df_m[mean_col].iloc[-1], df_m[std_col].iloc[-1]),\n",
    "            fontsize=8\n",
    "        )\n",
    "\n",
    "    ax.set_xlabel(f\"Mean {metric_name} across prompts\")\n",
    "    ax.set_ylabel(f\"{metric_name} standard deviation across prompts\")\n",
    "    # ax.set_title(\"Performance-stability trajectories\")\n",
    "    ax.grid(True, alpha=0.4)\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# Aggregate over prompts only (agg_prompts=True); the flags for models and\n",
    "# train fractions are off.\n",
    "agg_lc = aggregate_metric_over_prompts(\n",
    "    results_lp,\n",
    "    metric=\"f1\",\n",
    "    agg_prompts=True,\n",
    "    agg_models=False,\n",
    "    agg_train_frac=False\n",
    ")\n",
    "\n",
    "# First call: default 'metric_*' columns labelled F1.\n",
    "# Second call: 'score_*' columns labelled Accuracy.\n",
    "plot_mean_vs_variability_trajectories(agg_lc, model_names)\n",
    "plot_mean_vs_variability_trajectories(agg_lc, model_names, metric_name='Accuracy', metric_label=\"score\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "823813cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def alpha_from_fraction(tf, min_alpha=0.15, max_alpha=0.95):\n",
    "    return min_alpha + (max_alpha - min_alpha) * tf\n",
    "\n",
    "def label_offset_from_angle(angle_deg, radius):\n",
    "    angle_rad = np.deg2rad(angle_deg)\n",
    "    dx = radius * np.cos(angle_rad)\n",
    "    dy = radius * np.sin(angle_rad)\n",
    "\n",
    "    ha = \"left\" if -90 <= angle_deg <= 90 else \"right\"\n",
    "    return dx, dy, ha\n",
    "\n",
    "def text_with_display_offset(ax, x, y, text, angle_deg, radius_pts,\n",
    "                             color, fontsize=9):\n",
    "    \"\"\"\n",
    "    Write ``text`` next to the data point ``(x, y)``, shifted by a fixed\n",
    "    offset expressed in points rather than in data units.\n",
    "\n",
    "    The offset has length ``radius_pts`` and direction ``angle_deg``\n",
    "    (degrees); both the offset and the horizontal alignment come from\n",
    "    ``label_offset_from_angle`` so that the two helpers cannot drift apart.\n",
    "    The text is vertically centred on the offset point.\n",
    "    \"\"\"\n",
    "    dx, dy, ha = label_offset_from_angle(angle_deg, radius_pts)\n",
    "\n",
    "    # data transform followed by a constant shift of (dx, dy) points\n",
    "    text_transform = offset_copy(\n",
    "        ax.transData,\n",
    "        fig=ax.figure,\n",
    "        x=dx,\n",
    "        y=dy,\n",
    "        units=\"points\"\n",
    "    )\n",
    "\n",
    "    ax.text(\n",
    "        x, y, text,\n",
    "        transform=text_transform,\n",
    "        fontsize=fontsize,\n",
    "        ha=ha,\n",
    "        va=\"center\",\n",
    "        color=color,\n",
    "    )\n",
    "\n",
    "def plot_mean_vs_variability_trajectories_two_panels(\n",
    "    agg_df,\n",
    "    model_names,\n",
    "    metrics=(\n",
    "        (\"metric\", \"F1\"),\n",
    "        (\"score\", \"Accuracy\"),\n",
    "    ),\n",
    "    label_angles_top=None,\n",
    "    label_angles_bottom=None,\n",
    "    label_radius=10,\n",
    "    ratio=(4, 6),\n",
    "    scale=2,\n",
    "    cmap_name=\"tab10\",\n",
    "):\n",
    "    \"\"\"\n",
    "    Two stacked mean-vs-std trajectory panels sharing the x axis, one per\n",
    "    entry of ``metrics``.\n",
    "\n",
    "    In each panel every model is drawn as a faint line through its\n",
    "    (mean, std) points ordered by ``mean_n_train``; point opacity grows with\n",
    "    the training size, and the point with the highest mean is marked with a\n",
    "    star and labelled with the model name.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    agg_df : pd.DataFrame\n",
    "        Needs ``model_id``, ``mean_n_train`` and, for each metric prefix,\n",
    "        ``<prefix>_mean`` / ``<prefix>_std``.\n",
    "    model_names : dict\n",
    "        Model id -> display name. Models without rows are skipped.\n",
    "    metrics : sequence of (column_prefix, display_name)\n",
    "        One pair per panel, top to bottom.\n",
    "    label_angles_top, label_angles_bottom : dict or None\n",
    "        Display name -> angle (degrees) at which the label is placed around\n",
    "        the starred point, for the top and the bottom panel respectively.\n",
    "        Names missing from the dict, or ``None``, fall back to 0 degrees\n",
    "        (label to the right of the point).\n",
    "    label_radius : float\n",
    "        Distance between the starred point and its label, in points.\n",
    "    ratio, scale\n",
    "        Figure size is ``scale * ratio``.\n",
    "    cmap_name : str\n",
    "        Colormap from which model colours are taken in order.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, axes\n",
    "    \"\"\"\n",
    "    fig, axes = plt.subplots(\n",
    "        2, 1,\n",
    "        figsize=[scale * x for x in ratio],\n",
    "        sharex=True,\n",
    "        gridspec_kw={\"hspace\": 0.15},\n",
    "    )\n",
    "\n",
    "    cmap = plt.get_cmap(cmap_name)\n",
    "\n",
    "    for panel_idx, (ax, (metric_label, metric_name)) in enumerate(zip(axes, metrics)):\n",
    "        # Pick the angle table by panel position (top = first panel) and\n",
    "        # tolerate the None default\n",
    "        label_angles = label_angles_top if panel_idx == 0 else label_angles_bottom\n",
    "        if label_angles is None:\n",
    "            label_angles = {}\n",
    "\n",
    "        for i, (mid, name) in enumerate(model_names.items()):\n",
    "            df_m = (\n",
    "                agg_df[agg_df[\"model_id\"] == mid]\n",
    "                .sort_values(\"mean_n_train\")\n",
    "            )\n",
    "            if df_m.empty:\n",
    "                continue\n",
    "\n",
    "            # colour depends on the position in model_names, so a model keeps\n",
    "            # its colour in both panels even when another model is skipped\n",
    "            color = cmap(i % cmap.N)\n",
    "\n",
    "            x = df_m[f\"{metric_label}_mean\"].values\n",
    "            y = df_m[f\"{metric_label}_std\"].values\n",
    "            tfs = df_m[\"mean_n_train\"].values\n",
    "            tfs = tfs / tfs.max()  # relative training size, largest = 1\n",
    "\n",
    "            # ---- faint trajectory line\n",
    "            ax.plot(\n",
    "                x, y,\n",
    "                color=color,\n",
    "                lw=1,\n",
    "                alpha=0.35,\n",
    "            )\n",
    "\n",
    "            # ---- points with fading alpha\n",
    "            for xi, yi, tf in zip(x, y, tfs):\n",
    "                ax.scatter(\n",
    "                    xi, yi,\n",
    "                    color=color,\n",
    "                    s=35,\n",
    "                    alpha=alpha_from_fraction(tf),\n",
    "                    zorder=3,\n",
    "                )\n",
    "\n",
    "            # ---- best point (star): highest mean along the trajectory\n",
    "            best_idx = np.argmax(x)\n",
    "            ax.scatter(\n",
    "                x[best_idx],\n",
    "                y[best_idx],\n",
    "                color=color,\n",
    "                s=140,\n",
    "                marker=\"*\",\n",
    "                edgecolor=\"black\",\n",
    "                linewidth=0.8,\n",
    "                zorder=5,\n",
    "            )\n",
    "\n",
    "            # ---- label around the best point, at the configured angle\n",
    "            angle = label_angles.get(name, 0.0)\n",
    "\n",
    "            text_with_display_offset(\n",
    "                ax,\n",
    "                x[best_idx],\n",
    "                y[best_idx],\n",
    "                name,\n",
    "                angle_deg=angle,\n",
    "                radius_pts=label_radius,\n",
    "                color=color,\n",
    "                fontsize=12,\n",
    "            )\n",
    "\n",
    "        ax.set_ylabel(f\"{metric_name} std across prompts\")\n",
    "        ax.grid(True, alpha=0.4)\n",
    "\n",
    "    axes[-1].set_xlabel(\"Mean performance across prompts\")\n",
    "\n",
    "    # fig.suptitle(\n",
    "    #     \"Performance–stability trajectories\\n(fading = increasing training size)\",\n",
    "    #     y=0.97,\n",
    "    #     fontsize=13,\n",
    "    # )\n",
    "\n",
    "    return fig, axes\n",
    "\n",
    "\n",
    "# Label placement angles (degrees) around each model's starred point in the\n",
    "# top (F1) panel, keyed by display name. Names missing from the dict fall\n",
    "# back to 0 degrees inside the plotting function.\n",
    "label_angles_f1 = {\n",
    "    'Mistral-7B' : -150,\n",
    "    'Gemma-9B' : -150,\n",
    "    'Solar-11B' : 0,\n",
    "    'Phi-14B' : 0,\n",
    "    'Qwen-32B' : 0,\n",
    "    'Gemma-27B' : 30,\n",
    "    'Qwen-15B' : 0,\n",
    "}\n",
    "\n",
    "# Same, for the bottom (accuracy) panel.\n",
    "label_angles_acc = {\n",
    "    'Mistral-7B' : 0,\n",
    "    'Gemma-9B' : -150,\n",
    "    'Solar-11B' : 180,\n",
    "    'Phi-14B' : 0,\n",
    "    'Qwen-32B' : -90,\n",
    "    'Gemma-27B' : -60,\n",
    "    'Qwen-15B' : 0,\n",
    "}\n",
    "\n",
    "fig, axes = plot_mean_vs_variability_trajectories_two_panels(\n",
    "    agg_lc,\n",
    "    model_names,\n",
    "    label_angles_top=label_angles_f1,\n",
    "    label_angles_bottom=label_angles_acc,\n",
    "    ratio=(5, 6),\n",
    "    scale=1.5,\n",
    ")\n",
    "\n",
    "# Written into img/; this cell does not create the directory.\n",
    "fig.savefig(\"img/LP_performances.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00ea0f95",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "1b1c72e6",
   "metadata": {},
   "source": [
    "## Lambda sensitivity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40bf0753",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_lambda_sensitivity_experiment(\n",
    "    df,\n",
    "    model_id,\n",
    "    prompt_id,\n",
    "    lambda_values,\n",
    "    test_fraction=0.2,\n",
    "    n_splits=5,\n",
    "    random_state=42,\n",
    "    logskip=False\n",
    "):\n",
    "    \"\"\"\n",
    "    Evaluate the Fisher-Wasserstein detector for every value of lambda on a\n",
    "    single (model, prompt) task.\n",
    "\n",
    "    The same fixed train/test splits are reused for all lambdas, and each\n",
    "    detector is fitted on the whole training part of a split (no\n",
    "    subsampling).\n",
    "\n",
    "    Returns a DataFrame with one row per (split, lambda) holding the\n",
    "    evaluator's metrics plus bookkeeping columns, or ``None`` when either\n",
    "    class has fewer than 5 samples or no split could be used. The skip\n",
    "    message is printed only when ``logskip`` is False.\n",
    "    \"\"\"\n",
    "\n",
    "    # ---- full data of the task; sum(y) counts the H-labelled samples ----\n",
    "    X, y = extract_prompt_data(df, model_id, prompt_id)\n",
    "    n_H = sum(y)\n",
    "    n_G = len(y) - n_H\n",
    "\n",
    "    too_small = n_H < 5 or n_G < 5\n",
    "    if too_small and not logskip:\n",
    "        print(\n",
    "            f\"Skipping model {model_id}, prompt {prompt_id} \"\n",
    "            f\"(n_G={n_G}, n_H={n_H})\"\n",
    "        )\n",
    "    if too_small:\n",
    "        return None\n",
    "\n",
    "    # ---- splits shared by every lambda ----\n",
    "    trn_sets, tst_sets = generate_fixed_test_sets(\n",
    "        X,\n",
    "        y,\n",
    "        n_splits=n_splits,\n",
    "        test_fraction=test_fraction,\n",
    "        random_state=random_state,\n",
    "    )\n",
    "\n",
    "    rows = []\n",
    "    split_pairs = zip(trn_sets, tst_sets)\n",
    "\n",
    "    for split_id, (train_part, test_part) in enumerate(split_pairs):\n",
    "        X_train, y_train = train_part\n",
    "        X_test, y_test = test_part\n",
    "\n",
    "        X_G, X_H = split_by_label(X_train, y_train)\n",
    "        # the detector is only fitted when both classes have >= 2 samples\n",
    "        if len(X_G) < 2 or len(X_H) < 2:\n",
    "            continue\n",
    "\n",
    "        for lambda_reg in lambda_values:\n",
    "            detector = FisherWassersteinDetector(lambda_reg=lambda_reg)\n",
    "            detector.fit(X_G, X_H)\n",
    "\n",
    "            row = LabelPropEvaluator(detector, X_test, y_test).evaluate()\n",
    "            row.update(\n",
    "                lambda_reg=lambda_reg,\n",
    "                split_id=split_id,\n",
    "                model_id=model_id,\n",
    "                prompt_id=prompt_id,\n",
    "                n_train=len(X_train),\n",
    "                n_test=len(X_test),\n",
    "            )\n",
    "            rows.append(row)\n",
    "\n",
    "    return pd.DataFrame(rows) if rows else None\n",
    "\n",
    "def run_full_lambda_sensitivity_study(\n",
    "    df,\n",
    "    model_ids,\n",
    "    prompt_ids_by_model,\n",
    "    lambda_values,\n",
    "    test_fraction=0.2,\n",
    "    n_splits=5,\n",
    "    random_state=42,\n",
    "    use_cache=False,\n",
    "    cache_dir=\"cache/lambda_sensitivity\",\n",
    "    overwrite_cache=False,\n",
    "    logskip=False\n",
    "):\n",
    "    \"\"\"\n",
    "    Run the lambda sensitivity experiment on every (model, prompt) pair and\n",
    "    stack the per-task tables.\n",
    "\n",
    "    ``prompt_ids_by_model`` is looked up as ``prompt_ids_by_model[mid]`` for\n",
    "    each ``mid`` in ``model_ids``.\n",
    "\n",
    "    With ``use_cache`` each task's table is stored as\n",
    "    ``model=<mid>__prompt=<pid>.parquet`` in ``cache_dir`` and read back on\n",
    "    later calls unless ``overwrite_cache`` is set. Cached files are reused\n",
    "    as they are: they are not checked against the current ``lambda_values``\n",
    "    or split settings. A ``meta.json`` describing the run is written to the\n",
    "    same directory.\n",
    "\n",
    "    Returns the concatenated DataFrame, or an empty DataFrame when no task\n",
    "    produced a result.\n",
    "    \"\"\"\n",
    "\n",
    "    if use_cache:\n",
    "        os.makedirs(cache_dir, exist_ok=True)\n",
    "\n",
    "    frames = []\n",
    "\n",
    "    for mid in tqdm(model_ids, desc=\"Model\"):\n",
    "        for pid in tqdm(prompt_ids_by_model[mid], desc=\"Prompt\"):\n",
    "\n",
    "            cache_path = (\n",
    "                os.path.join(cache_dir, f\"model={mid}__prompt={pid}.parquet\")\n",
    "                if use_cache\n",
    "                else None\n",
    "            )\n",
    "\n",
    "            # cache hit: reuse the stored table instead of recomputing\n",
    "            if use_cache and not overwrite_cache and os.path.exists(cache_path):\n",
    "                frames.append(pd.read_parquet(cache_path))\n",
    "                continue\n",
    "\n",
    "            task_df = run_lambda_sensitivity_experiment(\n",
    "                df=df,\n",
    "                model_id=mid,\n",
    "                prompt_id=pid,\n",
    "                lambda_values=lambda_values,\n",
    "                test_fraction=test_fraction,\n",
    "                n_splits=n_splits,\n",
    "                random_state=random_state,\n",
    "                logskip=logskip\n",
    "            )\n",
    "\n",
    "            # skipped tasks leave no cache file behind\n",
    "            if task_df is None or len(task_df) == 0:\n",
    "                continue\n",
    "\n",
    "            frames.append(task_df)\n",
    "\n",
    "            if use_cache:\n",
    "                task_df.to_parquet(cache_path, index=False)\n",
    "\n",
    "    if not frames:\n",
    "        return pd.DataFrame()\n",
    "\n",
    "    results = pd.concat(frames, ignore_index=True)\n",
    "\n",
    "    if use_cache:\n",
    "        meta = dict(\n",
    "            model_ids=list(model_ids),\n",
    "            lambda_values=list(lambda_values),\n",
    "            test_fraction=test_fraction,\n",
    "            n_splits=n_splits,\n",
    "            random_state=random_state,\n",
    "            n_cached_pairs=len(frames),\n",
    "        )\n",
    "        meta_path = os.path.join(cache_dir, \"meta.json\")\n",
    "\n",
    "        with open(meta_path, \"w\") as f:\n",
    "            json.dump(meta, f, indent=2)\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47a3b737",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coarse, roughly logarithmic lambda grid, denser between 0.1 and 5\n",
    "lambda_values = [\n",
    "    1e-4, 3e-4,\n",
    "    1e-3, 3e-3,\n",
    "    1e-2, 3e-2,\n",
    "    1e-1, 2e-1, 3e-1, 5e-1, 7e-1,\n",
    "    1e+0, 1.4e+0, 2e+0, 3e+0, 5e+0,\n",
    "    1e+1, 3e+1,\n",
    "    1e+2, 3e+2,\n",
    "]\n",
    "\n",
    "results_lambda = run_full_lambda_sensitivity_study(\n",
    "    df=df,\n",
    "    model_ids=model_names.keys(),\n",
    "    # The study looks prompts up as prompt_ids_by_model[mid], so key the\n",
    "    # lookup by model id: a positional list is only correct when the ids\n",
    "    # happen to be 0..n-1 in order.\n",
    "    prompt_ids_by_model={\n",
    "        mid: list(df[df[\"model_id\"] == mid][\"prompt_id\"].unique())\n",
    "        for mid in model_names\n",
    "    },\n",
    "    lambda_values=lambda_values,\n",
    "    test_fraction=0.2,\n",
    "    n_splits=5,\n",
    "    random_state=42,\n",
    "    use_cache=True,\n",
    "    cache_dir=cache_dir + \"/lambda-sensitivity\",\n",
    "    overwrite_cache=False,\n",
    "    logskip=True\n",
    ")\n",
    "\n",
    "# The regularisation parameter \\lambda was varied over a roughly logarithmically spaced grid ranging from 10^{-4} to 3 \\times 10^{2}, covering both weakly and strongly regularised regimes. This range was found sufficient to capture the transition from unstable Fisher directions to over-regularised, mean-difference–dominated projections."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1484c68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine linear grid (step 0.1 from 0.5 up to, but excluding, 4.5), cached in\n",
    "# its own directory so it does not mix with the coarse grid.\n",
    "results_lambda_fine = run_full_lambda_sensitivity_study(\n",
    "    df=df,\n",
    "    model_ids=model_names.keys(),\n",
    "    # The study looks prompts up as prompt_ids_by_model[mid], so key the\n",
    "    # lookup by model id: a positional list is only correct when the ids\n",
    "    # happen to be 0..n-1 in order.\n",
    "    prompt_ids_by_model={\n",
    "        mid: list(df[df[\"model_id\"] == mid][\"prompt_id\"].unique())\n",
    "        for mid in model_names\n",
    "    },\n",
    "    lambda_values=np.arange(0.5, 4.5, .1),\n",
    "    test_fraction=0.2,\n",
    "    n_splits=5,\n",
    "    random_state=42,\n",
    "    use_cache=True,\n",
    "    cache_dir=cache_dir + \"/lambda-sensitivity-fine\",\n",
    "    overwrite_cache=False,\n",
    "    logskip=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6af4b9d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def aggregate_metric_over_lambda(\n",
    "    df,\n",
    "    metric=\"f1\",\n",
    "    score_metric=\"accuracy\",\n",
    "    agg_prompts=True,\n",
    "    agg_models=False,\n",
    "):\n",
    "    \"\"\"\n",
    "    Aggregate metrics over prompts, test splits and repetitions\n",
    "    as a function of the Fisher regularisation parameter lambda.\n",
    "\n",
    "    Returns one row per (model_id, lambda_reg) unless aggregation is requested.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pd.DataFrame\n",
    "        Output of lambda sensitivity experiments\n",
    "    metric : str\n",
    "        Primary metric to aggregate (e.g. 'f1')\n",
    "    score_metric : str\n",
    "        Secondary metric (e.g. 'accuracy')\n",
    "    agg_prompts : bool\n",
    "        If True, aggregate over prompts\n",
    "    agg_models : bool\n",
    "        If True, aggregate over models as well\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    agg_df : pd.DataFrame\n",
    "        Columns:\n",
    "        - model_id (optional)\n",
    "        - lambda_reg\n",
    "        - metric_mean\n",
    "        - metric_std\n",
    "        - score_mean\n",
    "        - score_std\n",
    "        - n_runs\n",
    "    \"\"\"\n",
    "\n",
    "    group_cols = [\"lambda_reg\"]\n",
    "\n",
    "    if not agg_models:\n",
    "        group_cols.insert(0, \"model_id\")\n",
    "    if not agg_prompts:\n",
    "        group_cols.insert(1, \"prompt_id\")\n",
    "\n",
    "    agg_df = (\n",
    "        df\n",
    "        .groupby(group_cols)\n",
    "        .agg(\n",
    "            metric_mean=(metric, \"mean\"),\n",
    "            metric_std=(metric, \"std\"),\n",
    "            score_mean=(score_metric, \"mean\"),\n",
    "            score_std=(score_metric, \"std\"),\n",
    "            n_runs=(metric, \"count\"),\n",
    "        )\n",
    "        .reset_index()\n",
    "        .sort_values(\"lambda_reg\")\n",
    "    )\n",
    "\n",
    "    return agg_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1338b04",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_lambda_sensitivity(\n",
    "    agg_df,\n",
    "    model_names,\n",
    "    metric=\"f1\",\n",
    "    metric_label='metric',\n",
    "    ratio=(3, 2),\n",
    "    scale=3,\n",
    "    best_reg_lambda=None,\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot mean +/- std of a metric against log10(lambda), one line per model.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    agg_df : pd.DataFrame\n",
    "        Needs ``model_id``, ``lambda_reg``, ``<metric_label>_mean`` and\n",
    "        ``<metric_label>_std``. Rows may come in any order.\n",
    "    model_names : dict\n",
    "        Model id -> legend label. Models without rows are skipped.\n",
    "    metric : str\n",
    "        Only used, upper-cased, as the y-axis label.\n",
    "    metric_label : str\n",
    "        Column prefix to plot.\n",
    "    ratio, scale\n",
    "        Figure size is ``scale * ratio``.\n",
    "    best_reg_lambda : float or None\n",
    "        If given, a dashed vertical line is drawn at its log10.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=[scale * x for x in ratio])\n",
    "\n",
    "    mean_col = f\"{metric_label}_mean\"\n",
    "    std_col = f\"{metric_label}_std\"\n",
    "\n",
    "    for mid, name in model_names.items():\n",
    "        df_m = agg_df[agg_df[\"model_id\"] == mid]\n",
    "        if df_m.empty:\n",
    "            continue\n",
    "\n",
    "        # Sort by lambda so the line and the band run left to right whatever\n",
    "        # the row order of agg_df (same as the 2D variant)\n",
    "        df_m = df_m.sort_values(\"lambda_reg\")\n",
    "\n",
    "        x = np.log10(df_m[\"lambda_reg\"].values)\n",
    "        mean = df_m[mean_col].values\n",
    "        std = df_m[std_col].values\n",
    "\n",
    "        (line,) = ax.plot(\n",
    "            x,\n",
    "            mean,\n",
    "            marker=\"o\",\n",
    "            label=name\n",
    "        )\n",
    "\n",
    "        # band in the colour of its own line, set explicitly\n",
    "        ax.fill_between(\n",
    "            x,\n",
    "            mean - std,\n",
    "            mean + std,\n",
    "            alpha=0.2,\n",
    "            facecolor=line.get_color()\n",
    "        )\n",
    "\n",
    "    if best_reg_lambda is not None:\n",
    "        ax.axvline(np.log10(best_reg_lambda), ls='--', color='k', label='Best parameter')\n",
    "\n",
    "    # raw string: same text as before, without the invalid \"\\l\" escape\n",
    "    ax.set_xlabel(r\"$\\log_{10}(\\lambda)$\")\n",
    "    ax.set_ylabel(metric.upper())\n",
    "    ax.grid(True, alpha=0.4)\n",
    "    ax.legend()\n",
    "    # ax.set_title(f\"{metric.upper()} vs Fisher regularisation\")\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# ---- coarse grid: one row per (model, lambda), prompts pooled\n",
    "agg_lambda = aggregate_metric_over_lambda(\n",
    "    results_lambda,\n",
    "    metric=\"f1\",\n",
    "    score_metric=\"accuracy\",\n",
    "    agg_prompts=True,\n",
    "    agg_models=False\n",
    ")\n",
    "\n",
    "# NOTE(review): best_reg_lambda is not defined in this part of the notebook --\n",
    "# presumably set in an earlier cell; confirm before running in isolation.\n",
    "fig, ax = plot_lambda_sensitivity(\n",
    "    agg_lambda,\n",
    "    model_names=model_names,\n",
    "    metric=\"f1\",\n",
    "    best_reg_lambda=best_reg_lambda\n",
    ")\n",
    "\n",
    "# Only the coarse-grid F1 figure is saved; the figures below are display-only.\n",
    "fig.savefig(\"img/LP_lambda_sensitivity.pdf\", bbox_inches='tight')\n",
    "\n",
    "fig, ax = plot_lambda_sensitivity(\n",
    "    agg_lambda,\n",
    "    model_names=model_names,\n",
    "    metric=\"accuracy\",\n",
    "    metric_label=\"score\",\n",
    "    best_reg_lambda=best_reg_lambda\n",
    ")\n",
    "\n",
    "# ---- fine grid: same aggregation and the same two plots\n",
    "agg_lambda_fine = aggregate_metric_over_lambda(\n",
    "    results_lambda_fine,\n",
    "    metric=\"f1\",\n",
    "    score_metric=\"accuracy\",\n",
    "    agg_prompts=True,\n",
    "    agg_models=False,\n",
    ")\n",
    "fig, ax = plot_lambda_sensitivity(\n",
    "    agg_lambda_fine,\n",
    "    model_names=model_names,\n",
    "    metric=\"f1\",\n",
    "    best_reg_lambda=best_reg_lambda,\n",
    ")\n",
    "fig, ax = plot_lambda_sensitivity(\n",
    "    agg_lambda_fine,\n",
    "    model_names=model_names,\n",
    "    metric=\"accuracy\",\n",
    "    metric_label=\"score\",\n",
    "    best_reg_lambda=best_reg_lambda,\n",
    ")\n",
    "\n",
    "# ---- coarse grid pooled over models too: one row per lambda\n",
    "# (score_metric keeps its default, 'accuracy')\n",
    "agg_lambda_global = aggregate_metric_over_lambda(\n",
    "    results_lambda,\n",
    "    metric=\"f1\",\n",
    "    agg_prompts=True,\n",
    "    agg_models=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2d902da",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_lambda_sensitivity_2d(\n",
    "    agg_df,\n",
    "    model_names,\n",
    "    metric=\"f1\",\n",
    "    metric_label=\"metric\",\n",
    "    ratio=(3, 2),\n",
    "    scale=3,\n",
    "    best_reg_lambda=None,\n",
    "):\n",
    "    \"\"\"\n",
    "    Mean-versus-std view of the lambda sensitivity.\n",
    "\n",
    "    Every model becomes a trajectory in the (mean, std) plane, its points\n",
    "    ordered by increasing lambda. When ``best_reg_lambda`` is given, the\n",
    "    point whose lambda is closest to it is marked with a star in the\n",
    "    colour of the trajectory.\n",
    "\n",
    "    ``metric`` is only used, upper-cased, in the labels and title;\n",
    "    ``metric_label`` is the column prefix that is actually plotted.\n",
    "    Returns ``(fig, ax)``.\n",
    "    \"\"\"\n",
    "\n",
    "    figure_size = [scale * side for side in ratio]\n",
    "    fig, ax = plt.subplots(figsize=figure_size)\n",
    "\n",
    "    mean_col = f\"{metric_label}_mean\"\n",
    "    std_col = f\"{metric_label}_std\"\n",
    "\n",
    "    for mid, name in model_names.items():\n",
    "        model_rows = agg_df[agg_df[\"model_id\"] == mid]\n",
    "        if model_rows.empty:\n",
    "            continue\n",
    "\n",
    "        # points must follow lambda for the line to be a trajectory\n",
    "        model_rows = model_rows.sort_values(\"lambda_reg\")\n",
    "        x = model_rows[mean_col].values\n",
    "        y = model_rows[std_col].values\n",
    "\n",
    "        (trajectory,) = ax.plot(x, y, marker=\"o\", label=name, alpha=0.9)\n",
    "\n",
    "        if best_reg_lambda is None:\n",
    "            continue\n",
    "\n",
    "        # star on the grid point nearest to the requested lambda\n",
    "        gap = np.abs(model_rows[\"lambda_reg\"].values - best_reg_lambda)\n",
    "        idx = np.argmin(gap)\n",
    "        ax.plot(\n",
    "            x[idx],\n",
    "            y[idx],\n",
    "            marker=\"*\",\n",
    "            markersize=14,\n",
    "            color=trajectory.get_color(),\n",
    "            zorder=5\n",
    "        )\n",
    "\n",
    "    metric_upper = metric.upper()\n",
    "    ax.set_xlabel(f\"Mean {metric_upper}\")\n",
    "    ax.set_ylabel(f\"Std {metric_upper}\")\n",
    "    ax.grid(True, alpha=0.4)\n",
    "    ax.legend()\n",
    "    ax.set_title(f\"{metric_upper} mean-variance trade-off vs Fisher regularisation\")\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# Coarse-grid F1 trajectories in the (mean, std) plane, starred at the grid\n",
    "# point nearest to best_reg_lambda.\n",
    "fig, ax = plot_lambda_sensitivity_2d(\n",
    "    agg_df=agg_lambda,\n",
    "    model_names=model_names,\n",
    "    metric=\"f1\",\n",
    "    metric_label=\"metric\",\n",
    "    best_reg_lambda=best_reg_lambda,\n",
    "    scale=3\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96bbb025",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each model, the coarse-grid row with the highest mean metric (F1)\n",
    "agg_lambda.loc[agg_lambda.groupby('model_id')['metric_mean'].idxmax()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43d40b14",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full coarse-grid table, one row per (model_id, lambda_reg)\n",
    "agg_lambda"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29c38fb6",
   "metadata": {},
   "source": [
    "## Choose a single λ that minimises the average relative regret with respect to the best achievable λ for each (model, prompt) task.\n",
    "\n",
    "Let:\n",
    " - $s_{m,p}(\\lambda)$ be the score (e.g. F1) for model $m$, prompt $p$, and regularisation $\\lambda$\n",
    " - $s^*_{m,p} = \\max_{\\lambda} s_{m,p}(\\lambda)$\n",
    "\n",
    "Define the relative loss (regret):\n",
    "$$\n",
    "    \\ell_{m,p}(\\lambda)\n",
    "    \\;=\\;\n",
    "    \\frac{s^*_{m,p} - s_{m,p}(\\lambda)}{s^*_{m,p}}\n",
    "$$\n",
    "\n",
    "Then define the average relative loss:\n",
    "$$\n",
    "    L(\\lambda) \\;=\\; \\mathbb{E}_{m,p}[\\ell_{m,p}(\\lambda)]\n",
    "$$\n",
    "\n",
    "The selected $\\lambda$ is the minimiser of the average relative loss:\n",
    "$$\n",
    "    \\lambda^\\star = \\arg\\min_{\\lambda} L(\\lambda)\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50e9e8a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def aggregate_over_runs(\n",
    "    df,\n",
    "    metric=\"f1\"\n",
    "):\n",
    "    \"\"\"\n",
    "    Average metric over runs (splits, iterations, etc.)\n",
    "    \"\"\"\n",
    "    return (\n",
    "        df\n",
    "        .groupby([\"model_id\", \"prompt_id\", \"lambda_reg\"])[metric]\n",
    "        .mean()\n",
    "        .reset_index(name=metric)\n",
    "    )\n",
    "\n",
    "def compute_best_scores(df_agg, metric=\"f1\"):\n",
    "    \"\"\"\n",
    "    Compute s*_{m,p} = max_lambda s(m,p,lambda)\n",
    "    \"\"\"\n",
    "    return (\n",
    "        df_agg\n",
    "        .groupby([\"model_id\", \"prompt_id\"])[metric]\n",
    "        .max()\n",
    "        .reset_index(name=\"best_score\")\n",
    "    )\n",
    "\n",
    "def compute_relative_loss(df_agg, df_best, metric=\"f1\"):\n",
    "    \"\"\"\n",
    "    Add relative loss column\n",
    "    \"\"\"\n",
    "    df = df_agg.merge(\n",
    "        df_best,\n",
    "        on=[\"model_id\", \"prompt_id\"],\n",
    "        how=\"left\"\n",
    "    )\n",
    "\n",
    "    df[\"relative_loss\"] = (\n",
    "        (df[\"best_score\"] - df[metric]) / df[\"best_score\"]\n",
    "    )\n",
    "\n",
    "    return df\n",
    "\n",
    "def aggregate_relative_loss(df_rel):\n",
    "    \"\"\"\n",
    "    Compute L(lambda)\n",
    "    \"\"\"\n",
    "    return (\n",
    "        df_rel\n",
    "        .groupby(\"lambda_reg\")[\"relative_loss\"]\n",
    "        .agg([\"mean\", \"std\", \"max\"])\n",
    "        .reset_index()\n",
    "        .rename(columns={\n",
    "            \"mean\": \"rel_loss_mean\",\n",
    "            \"std\": \"rel_loss_std\",\n",
    "            \"max\": \"rel_loss_max\"\n",
    "        })\n",
    "    )\n",
    "\n",
    "def select_lambda_min_regret(df_lambda_loss):\n",
    "    idx = df_lambda_loss[\"rel_loss_mean\"].idxmin()\n",
    "    return df_lambda_loss.loc[idx]\n",
    "\n",
    "# ---- min-regret lambda on the fine grid, accuracy as the score\n",
    "df_agg = aggregate_over_runs(results_lambda_fine, metric=\"accuracy\")\n",
    "df_best = compute_best_scores(df_agg, metric=\"accuracy\")\n",
    "df_rel  = compute_relative_loss(df_agg, df_best, metric=\"accuracy\")\n",
    "df_loss = aggregate_relative_loss(df_rel)\n",
    "\n",
    "lambda_star = select_lambda_min_regret(df_loss)\n",
    "\n",
    "print(\"lambda_star (acc)\", lambda_star, '\\n')\n",
    "\n",
    "# ---- same with F1. This overwrites df_agg, df_best, df_rel, df_loss and\n",
    "# lambda_star, so the following cells see the F1 versions.\n",
    "df_agg = aggregate_over_runs(results_lambda_fine, metric=\"f1\")\n",
    "df_best = compute_best_scores(df_agg, metric=\"f1\")\n",
    "df_rel  = compute_relative_loss(df_agg, df_best, metric=\"f1\")\n",
    "df_loss = aggregate_relative_loss(df_rel)\n",
    "\n",
    "lambda_star = select_lambda_min_regret(df_loss)\n",
    "\n",
    "print(\"lambda_star (f1)\", lambda_star, '\\n')\n",
    "\n",
    "# Minimax alternative on the F1 table: the lambda whose worst-case relative\n",
    "# loss (rel_loss_max) is smallest\n",
    "lambda_star_worst = (\n",
    "    df_loss\n",
    "    .sort_values(\"rel_loss_max\")\n",
    "    .iloc[0]\n",
    ")\n",
    "\n",
    "print(\"lambda_star_worst\", lambda_star_worst)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "592895d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_average_best_lambda(df_agg, metric=\"f1\"):\n",
    "    best_per_task = (\n",
    "        df_agg\n",
    "        .sort_values(metric, ascending=False)\n",
    "        .groupby([\"model_id\", \"prompt_id\"])\n",
    "        .first()\n",
    "        .reset_index()\n",
    "    )\n",
    "\n",
    "    return {\n",
    "        \"mean_best_lambda\": best_per_task[\"lambda_reg\"].mean(),\n",
    "        \"median_best_lambda\": best_per_task[\"lambda_reg\"].median()\n",
    "    }\n",
    "\n",
    "# df_agg and lambda_star are the F1 versions left by the previous cell, which\n",
    "# matches the default metric='f1' used here.\n",
    "avg_best = compute_average_best_lambda(df_agg)\n",
    "\n",
    "print(\"Min-regret lambda:\", lambda_star[\"lambda_reg\"])\n",
    "print(\"Average best lambda:\", avg_best)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2621fd62",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean relative loss (+/- one std) against lambda on a log x axis, using the\n",
    "# df_loss / lambda_star computed last in the regret cell (the F1 versions).\n",
    "fig, ax = plt.subplots(figsize=(6, 4))\n",
    "\n",
    "ax.plot(\n",
    "    df_loss[\"lambda_reg\"],\n",
    "    df_loss[\"rel_loss_mean\"],\n",
    "    marker=\"o\"\n",
    ")\n",
    "\n",
    "ax.fill_between(\n",
    "    df_loss[\"lambda_reg\"],\n",
    "    df_loss[\"rel_loss_mean\"] - df_loss[\"rel_loss_std\"],\n",
    "    df_loss[\"rel_loss_mean\"] + df_loss[\"rel_loss_std\"],\n",
    "    alpha=0.2\n",
    ")\n",
    "\n",
    "# vertical marker at the min-regret lambda\n",
    "ax.axvline(\n",
    "    lambda_star[\"lambda_reg\"],\n",
    "    color=\"red\",\n",
    "    linestyle=\"--\",\n",
    "    label=r\"$\\lambda^\\star$ (min regret)\"\n",
    ")\n",
    "\n",
    "ax.set_xscale(\"log\")\n",
    "ax.set_xlabel(r\"Regularisation $\\lambda$\")\n",
    "ax.set_ylabel(\"Relative performance loss\")\n",
    "ax.grid(True, alpha=0.4)\n",
    "ax.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8e46c4c",
   "metadata": {},
   "source": [
    "## Images for pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77eb9e02",
   "metadata": {},
   "outputs": [],
   "source": [
    "mid = 2  #model id\n",
    "pid = 82 #prompt ic\n",
    "rid = 11 #response id\n",
    "\n",
    "def train_test_split_80_20(X, y, random_state=42):\n",
    "    sss = StratifiedShuffleSplit(\n",
    "        n_splits=1, test_size=0.2, random_state=random_state\n",
    "    )\n",
    "    trn_idx, tst_idx = next(sss.split(X, y))\n",
    "    return X[trn_idx], X[tst_idx], y[trn_idx], y[tst_idx], trn_idx, tst_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "830d9696",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_distance_block(geometry_store, key, space):\n",
    "    \"\"\"\n",
    "    Returns a dict with keys D_GG, D_HH, D_GH\n",
    "    \"\"\"\n",
    "    d = geometry_store[key]\n",
    "\n",
    "    key = {'embedding': '', 'fisher': '_z'}\n",
    "\n",
    "    if space in d:  # nested\n",
    "        return d[space]\n",
    "\n",
    "    # flattened fallback\n",
    "    return {\n",
    "        \"D_GG\": d[f\"D_GG{key[space]}\"],\n",
    "        \"D_HH\": d[f\"D_HH{key[space]}\"],\n",
    "        \"D_GH\": d[f\"D_GH{key[space]}\"],\n",
    "    }\n",
    "\n",
    "def plot_distance_violin(\n",
    "    geometry_store,\n",
    "    key,\n",
    "    space=\"embedding\",\n",
    "    colors=None,\n",
    "    figsize=(5, 4),\n",
    "    savepath=None,\n",
    "    transparent=False,\n",
    "    rotate=False\n",
    "):\n",
    "    \"\"\"\n",
    "    Half-violin plot for GG / HH and boxplot for GH\n",
    "    in the selected space ('embedding' or 'fisher')\n",
    "    \"\"\"\n",
    "    if colors is None:\n",
    "        colors = {\n",
    "            \"GG\": \"#6E9B34\",\n",
    "            \"HH\": \"#AA4D39\",\n",
    "            \"GH\": \"#27586B\",\n",
    "        }\n",
    "\n",
    "    d = get_distance_block(geometry_store, key, space)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=figsize)\n",
    "\n",
    "    # ---- GG: upper half violin ----\n",
    "    vp = ax.violinplot(\n",
    "        d[\"D_GG\"],\n",
    "        positions=[0],\n",
    "        widths=0.5,\n",
    "        showmeans=False,\n",
    "        showmedians=False,\n",
    "        showextrema=False,\n",
    "        side=\"high\",\n",
    "        orientation=\"horizontal\",\n",
    "    )\n",
    "    for pc in vp[\"bodies\"]:\n",
    "        pc.set_facecolor(colors[\"GG\"])\n",
    "        pc.set_edgecolor(\"black\")\n",
    "        pc.set_linewidth(1.0)\n",
    "        pc.set_alpha(0.5)\n",
    "\n",
    "    # ---- HH: lower half violin ----\n",
    "    vp = ax.violinplot(\n",
    "        d[\"D_HH\"],\n",
    "        positions=[0],\n",
    "        widths=0.5,\n",
    "        showmeans=False,\n",
    "        showmedians=False,\n",
    "        showextrema=False,\n",
    "        side=\"low\",\n",
    "        orientation=\"horizontal\",\n",
    "    )\n",
    "    for pc in vp[\"bodies\"]:\n",
    "        pc.set_facecolor(colors[\"HH\"])\n",
    "        pc.set_edgecolor(\"black\")\n",
    "        pc.set_linewidth(1.0)\n",
    "        pc.set_alpha(0.5)\n",
    "\n",
    "    # ---- GH: boxplot ----\n",
    "    bp = ax.boxplot(\n",
    "        d[\"D_GH\"],\n",
    "        positions=[0],\n",
    "        widths=0.1,\n",
    "        patch_artist=True,\n",
    "        showfliers=False,\n",
    "        orientation=\"horizontal\",\n",
    "    )\n",
    "    for box in bp[\"boxes\"]:\n",
    "        box.set_facecolor(colors[\"GH\"])\n",
    "        box.set_edgecolor(\"black\")\n",
    "        box.set_alpha(0.9)\n",
    "\n",
    "    for elem in [\"whiskers\", \"caps\", \"medians\"]:\n",
    "        for artist in bp[elem]:\n",
    "            artist.set_color(\"black\")\n",
    "\n",
    "    xmin, xmax = ax.get_xlim()\n",
    "    y0 = 0\n",
    "\n",
    "    ax.annotate(\n",
    "        r\"$L_2$\",\n",
    "        xy=(-.001, y0),\n",
    "        xytext=(xmax*1.1, y0),\n",
    "        # textcoords=\"offset points\",\n",
    "        arrowprops=dict(\n",
    "            arrowstyle=\"<-\",\n",
    "            color=\"0.6\",\n",
    "            linewidth=1.2,\n",
    "            shrinkA=0,\n",
    "            shrinkB=0,\n",
    "        ),\n",
    "        ha=\"left\",\n",
    "        va=\"center\",\n",
    "        color=\"0.4\",\n",
    "        rotation=-90 if rotate else 0,\n",
    "        fontsize=40,\n",
    "        zorder=0,\n",
    "    )\n",
    "\n",
    "    ax.axvline(0, 0.25, 0.75, color='grey')\n",
    "\n",
    "    ax.set_axis_off()\n",
    "\n",
    "    if savepath is not None:\n",
    "        fig.savefig(savepath, bbox_inches=\"tight\", pad_inches=0, transparent=transparent)\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "\n",
    "plot_distance_violin(\n",
    "    geometry_store,\n",
    "    key=(mid, pid),\n",
    "    space=\"embedding\",\n",
    "    savepath=\"img/P_violin_embedding.pdf\",\n",
    "    transparent=True,\n",
    "    rotate=True\n",
    ")\n",
    "plot_distance_violin(\n",
    "    geometry_store,\n",
    "    key=(mid, pid),\n",
    "    space=\"fisher\",\n",
    "    savepath=\"img/P_violin_fisher.pdf\",\n",
    "    transparent=True,\n",
    "    rotate=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f544cd8d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6528cf48",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0e9dd94",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tsne_projection(\n",
    "    X_train, y_train,\n",
    "    X_test, y_test,\n",
    "    perplexity=30,\n",
    "    random_state=42\n",
    "):\n",
    "    X_all = np.vstack([X_train, X_test])\n",
    "    y_all = np.concatenate([y_train, y_test])\n",
    "\n",
    "    tsne = TSNE(\n",
    "        n_components=2,\n",
    "        perplexity=perplexity,\n",
    "        init=\"pca\",\n",
    "        learning_rate=\"auto\",\n",
    "        random_state=random_state,\n",
    "    )\n",
    "\n",
    "    Z_all = tsne.fit_transform(X_all)\n",
    "\n",
    "    n_tr = len(X_train)\n",
    "    Z_tr, Z_te = Z_all[:n_tr], Z_all[n_tr:]\n",
    "\n",
    "    return Z_tr, Z_te, y_train, y_test\n",
    "\n",
    "# COLORS = {\n",
    "#     0: \"#6E9B34\",  # G\n",
    "#     1: \"#AA4D39\",  # H\n",
    "# }\n",
    "\n",
    "# def plot_tsne_with_test_point(\n",
    "#     Z_train, y_train,\n",
    "#     Z_test=None, y_test=None,\n",
    "#     test_id=None,\n",
    "#     ax=None\n",
    "# ):\n",
    "#     if ax is None:\n",
    "#         fig, ax = plt.subplots(figsize=(5, 4))\n",
    "\n",
    "#     # ---- training clouds ----\n",
    "#     for cls in [0, 1]:\n",
    "#         mask = (y_train == cls)\n",
    "#         ax.scatter(\n",
    "#             Z_train[mask, 0],\n",
    "#             Z_train[mask, 1],\n",
    "#             s=22,\n",
    "#             alpha=0.5,\n",
    "#             c=COLORS[cls],\n",
    "#             marker='s' if cls == 0 else 'o',\n",
    "#             label=f\"Train {'G' if cls == 0 else 'H'}\",\n",
    "#         )\n",
    "\n",
    "#     # ---- selected test point ----\n",
    "#     if Z_test is not None:\n",
    "#         ax.scatter(\n",
    "#             Z_test[test_id, 0],\n",
    "#             Z_test[test_id, 1],\n",
    "#             s=90,\n",
    "#             c=\"black\",\n",
    "#             edgecolor=\"white\",\n",
    "#             linewidth=1.2,\n",
    "#             zorder=5,\n",
    "#             label=f\"Test {test_id}\",\n",
    "#         )\n",
    "\n",
    "#     ax.set_axis_off()\n",
    "#     # ax.legend(frameon=False, loc=\"best\")\n",
    "#     return fig, ax\n",
    "\n",
    "# X, y = extract_prompt_data(df, 2, 82)\n",
    "\n",
    "# X_trn, X_tst, y_trn, y_tst = train_test_split_80_20(\n",
    "#     X, y, random_state=42\n",
    "# )\n",
    "\n",
    "# Z_tr, Z_te, y_tr, y_te = tsne_projection(\n",
    "#     X_trn, y_trn, X_tst, y_tst,\n",
    "#     perplexity=30\n",
    "# )\n",
    "\n",
    "# for tid in range(10, 20):\n",
    "#     print(tid)\n",
    "#     plot_tsne_with_test_point(\n",
    "#         Z_tr, y_tr,\n",
    "#         Z_te, y_te,\n",
    "#         test_id=tid\n",
    "#     )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb1e6513",
   "metadata": {},
   "outputs": [],
   "source": [
    "EDGE_COLORS = {\n",
    "    \"GG\": \"#6E9B34\",  # green\n",
    "    \"HH\": \"#AA4D39\",  # red\n",
    "    \"GH\": \"#27586B\",  # blue\n",
    "}\n",
    "\n",
    "import networkx as nx\n",
    "import itertools\n",
    "\n",
    "def build_tsne_complete_graph(Z, y):\n",
    "    \"\"\"\n",
    "    Z : (n,2) t-SNE coordinates\n",
    "    y : (n,) class labels {0,1}\n",
    "    \"\"\"\n",
    "    G = nx.Graph()\n",
    "\n",
    "    # ---- nodes ----\n",
    "    for i, (pos, cls) in enumerate(zip(Z, y)):\n",
    "        G.add_node(\n",
    "            i,\n",
    "            pos=tuple(pos),\n",
    "            cls=int(cls),\n",
    "        )\n",
    "\n",
    "    # ---- complete edges ----\n",
    "    for i, j in itertools.combinations(range(len(Z)), 2):\n",
    "        ci, cj = y[i], y[j]\n",
    "\n",
    "        if ci == 0 and cj == 0:\n",
    "            etype = \"GG\"\n",
    "        elif ci == 1 and cj == 1:\n",
    "            etype = \"HH\"\n",
    "        else:\n",
    "            etype = \"GH\"\n",
    "\n",
    "        G.add_edge(i, j, etype=etype)\n",
    "\n",
    "    return G\n",
    "\n",
    "def plot_tsne_complete_graph(\n",
    "    G,\n",
    "    ax=None,\n",
    "    node_size=30,\n",
    "    edge_alphas={\n",
    "        \"GG\": 0.1,\n",
    "        \"HH\": 0.08,\n",
    "        \"GH\": 0.04,\n",
    "    },\n",
    "    edge_lw=0.6,\n",
    "    colors=None,\n",
    "):\n",
    "    if ax is None:\n",
    "        fig, ax = plt.subplots(figsize=(5, 4))\n",
    "    else:\n",
    "        fig = ax.figure\n",
    "\n",
    "    if colors is None:\n",
    "        colors = {\n",
    "            0: \"#6E9B34\",\n",
    "            1: \"#AA4D39\",\n",
    "        }\n",
    "\n",
    "    pos = nx.get_node_attributes(G, \"pos\")\n",
    "    cls = nx.get_node_attributes(G, \"cls\")\n",
    "\n",
    "    # ---- edges by type ----\n",
    "    for etype, color in EDGE_COLORS.items():\n",
    "        edges = [\n",
    "            (u, v)\n",
    "            for u, v, d in G.edges(data=True)\n",
    "            if d[\"etype\"] == etype\n",
    "        ]\n",
    "\n",
    "        nx.draw_networkx_edges(\n",
    "            G,\n",
    "            pos,\n",
    "            edgelist=edges,\n",
    "            edge_color=color,\n",
    "            alpha=edge_alphas[etype],\n",
    "            width=edge_lw,\n",
    "            ax=ax,\n",
    "        )\n",
    "\n",
    "    # ---- nodes ----\n",
    "    for c in [0, 1]:\n",
    "        nodes = [n for n in G.nodes if cls[n] == c]\n",
    "        nx.draw_networkx_nodes(\n",
    "            G,\n",
    "            pos,\n",
    "            nodelist=nodes,\n",
    "            node_color=colors[c],\n",
    "            node_size=node_size,\n",
    "            node_shape='s' if c == 0 else 'o',\n",
    "            alpha=0.8,\n",
    "            ax=ax,\n",
    "        )\n",
    "\n",
    "    ax.set_axis_off()\n",
    "    return fig, ax\n",
    "\n",
    "\n",
    "X, y = extract_prompt_data(df, 2, 82)\n",
    "\n",
    "X_trn, X_tst, y_trn, y_tst, trainIdxs, testIdxs = train_test_split_80_20(\n",
    "    X, y, random_state=42\n",
    ")\n",
    "\n",
    "# ---- t-SNE ----\n",
    "Z_tr, Z_te, y_tr, y_te = tsne_projection(\n",
    "    X_trn, y_trn,\n",
    "    X_tst, y_tst,\n",
    "    perplexity=30,\n",
    "    random_state=42,\n",
    ")\n",
    "\n",
    "# ---- graph ----\n",
    "G = build_tsne_complete_graph(Z_tr, y_tr)\n",
    "\n",
    "# ---- plot ----\n",
    "fig, ax = plot_tsne_complete_graph(G, )\n",
    "fig.savefig('img/P_embedding.pdf', bbox_inches='tight', transparent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c9eeb7a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efc59ef1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # 1D intra-class distances\n",
    "# d_GG = pdist(z_G[:, None], metric=\"euclidean\").ravel()\n",
    "# d_HH = pdist(z_H[:, None], metric=\"euclidean\").ravel()\n",
    "\n",
    "# def draw_fisher_distance_violin(\n",
    "#     ax,\n",
    "#     z_center,\n",
    "#     distances,\n",
    "#     angle_deg=-30,\n",
    "#     width=0.08,\n",
    "#     color=\"C0\",\n",
    "#     alpha=0.35,\n",
    "#     zorder=1,\n",
    "#     side=\"both\",  # \"top\" or \"bottom\"\n",
    "# ):\n",
    "#     theta = np.deg2rad(angle_deg)\n",
    "#     d = np.array([np.cos(theta), np.sin(theta)])\n",
    "#     d_perp = np.array([-d[1], d[0]])\n",
    "\n",
    "#     # KDE over distances\n",
    "#     kde = gaussian_kde(distances)\n",
    "\n",
    "#     r = np.linspace(0, distances.max(), 300)\n",
    "#     density = kde(r)\n",
    "#     density = density / density.max() * width\n",
    "\n",
    "#     # anchor at class centre\n",
    "#     center = z_center * d\n",
    "\n",
    "#     curve = center + r[:, None] * d\n",
    "#     offset = density[:, None] * d_perp\n",
    "\n",
    "#     if side == \"top\":\n",
    "#         poly = np.vstack([curve, curve[::-1] + offset[::-1]])\n",
    "#     elif side == \"bottom\":\n",
    "#         poly = np.vstack([curve, curve[::-1] - offset[::-1]])\n",
    "#     else:\n",
    "#         poly = np.vstack([\n",
    "#             curve + offset,\n",
    "#             (curve - offset)[::-1]\n",
    "#         ])\n",
    "\n",
    "#     ax.fill(\n",
    "#         poly[:, 0],\n",
    "#         poly[:, 1],\n",
    "#         color=color,\n",
    "#         alpha=alpha,\n",
    "#         linewidth=0,\n",
    "#         zorder=zorder,\n",
    "#     )\n",
    "\n",
    "# zG_center = z_G.mean()\n",
    "# zH_center = z_H.mean()\n",
    "\n",
    "# fig, ax = plot_tsne_with_test_point(\n",
    "#     Z_tr_j, y_tr,\n",
    "#     Z_test=None, y_test=None,\n",
    "#     test_id=7\n",
    "# )\n",
    "\n",
    "# draw_fisher_distance_violin(\n",
    "#     ax,\n",
    "#     z_center=zG_center,\n",
    "#     distances=d_GG,\n",
    "#     angle_deg=-30,\n",
    "#     width=0.06,\n",
    "#     color=COLORS[0],\n",
    "#     side=\"bottom\"\n",
    "# )\n",
    "\n",
    "# draw_fisher_distance_violin(\n",
    "#     ax,\n",
    "#     z_center=zH_center,\n",
    "#     distances=d_HH,\n",
    "#     angle_deg=-30,\n",
    "#     width=0.06,\n",
    "#     color=COLORS[1],\n",
    "#     side=\"top\"\n",
    "# )\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56756718",
   "metadata": {},
   "source": [
    "### Test point, ie LP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d3b5c36",
   "metadata": {},
   "outputs": [],
   "source": [
    "STAR_EDGE_COLORS = {\n",
    "    0: \"#6E9B34\",  # G → green\n",
    "    1: \"#F28E2B\",  # H → orange (distinct from HH red)\n",
    "}\n",
    "\n",
    "def build_tsne_star_graph(\n",
    "    Z_train,\n",
    "    y_train,\n",
    "    z_test,\n",
    "):\n",
    "    \"\"\"\n",
    "    Star graph centred at test point\n",
    "\n",
    "    Z_train : (n,2)\n",
    "    y_train : (n,)\n",
    "    z_test  : (2,)\n",
    "    \"\"\"\n",
    "    G = nx.Graph()\n",
    "\n",
    "    test_id = \"test\"\n",
    "\n",
    "    # ---- test node ----\n",
    "    G.add_node(\n",
    "        test_id,\n",
    "        pos=tuple(z_test),\n",
    "        cls=\"test\",\n",
    "    )\n",
    "\n",
    "    # ---- training nodes + star edges ----\n",
    "    for i, (pos, cls) in enumerate(zip(Z_train, y_train)):\n",
    "        G.add_node(\n",
    "            i,\n",
    "            pos=tuple(pos),\n",
    "            cls=int(cls),\n",
    "        )\n",
    "\n",
    "        G.add_edge(\n",
    "            test_id,\n",
    "            i,\n",
    "            etype=int(cls),   # 0=G, 1=H\n",
    "        )\n",
    "\n",
    "    return G\n",
    "\n",
    "def plot_tsne_star_graph(\n",
    "    G,\n",
    "    ax=None,\n",
    "    node_size=30,\n",
    "    test_size=400,\n",
    "    edge_alpha=0.5,\n",
    "    edge_lw=0.9,\n",
    "    title=None\n",
    "):\n",
    "    if ax is None:\n",
    "        fig, ax = plt.subplots(figsize=(5, 4))\n",
    "    else:\n",
    "        fig = ax.figure\n",
    "\n",
    "    pos = nx.get_node_attributes(G, \"pos\")\n",
    "    cls = nx.get_node_attributes(G, \"cls\")\n",
    "\n",
    "    # ---- edges (test → train) ----\n",
    "    for cls_id, color in STAR_EDGE_COLORS.items():\n",
    "        edges = [\n",
    "            (u, v)\n",
    "            for u, v, d in G.edges(data=True)\n",
    "            if d[\"etype\"] == cls_id\n",
    "        ]\n",
    "\n",
    "        nx.draw_networkx_edges(\n",
    "            G,\n",
    "            pos,\n",
    "            edgelist=edges,\n",
    "            edge_color=color,\n",
    "            alpha=edge_alpha,\n",
    "            width=edge_lw,\n",
    "            ax=ax,\n",
    "        )\n",
    "\n",
    "    # ---- training nodes ----\n",
    "    for c in [0, 1]:\n",
    "        nodes = [n for n in G.nodes if cls.get(n) == c]\n",
    "        nx.draw_networkx_nodes(\n",
    "            G,\n",
    "            pos,\n",
    "            nodelist=nodes,\n",
    "            node_color=COLORS[c],\n",
    "            node_size=node_size,\n",
    "            node_shape='s' if c == 0 else 'o',\n",
    "            alpha=0.8,\n",
    "            ax=ax,\n",
    "        )\n",
    "\n",
    "    # ---- test node (black star) ----\n",
    "    nx.draw_networkx_nodes(\n",
    "        G,\n",
    "        pos,\n",
    "        nodelist=[\"test\"],\n",
    "        node_color=\"black\",\n",
    "        node_size=test_size,\n",
    "        node_shape=\"*\",\n",
    "        linewidths=1.2,\n",
    "        edgecolors=\"white\",\n",
    "        ax=ax,\n",
    "    )\n",
    "    \n",
    "    if title is not None:\n",
    "        ax.set_title(title)\n",
    "\n",
    "    ax.set_axis_off()\n",
    "    return fig, ax\n",
    "\n",
    "X, y = extract_prompt_data(df, 2, 82)\n",
    "\n",
    "X_trn, X_tst, y_trn, y_tst, _, _ = train_test_split_80_20(\n",
    "    X, y, random_state=42\n",
    ")\n",
    "\n",
    "Z_tr, Z_te, y_tr, y_te = tsne_projection(\n",
    "    X_trn, y_trn,\n",
    "    X_tst, y_tst,\n",
    "    perplexity=30,\n",
    "    random_state=42,\n",
    ")\n",
    "\n",
    "G_star = build_tsne_star_graph(\n",
    "    Z_tr,\n",
    "    y_tr,\n",
    "    Z_te[rid],\n",
    ")\n",
    "\n",
    "fig, ax = plot_tsne_star_graph(G_star)\n",
    "fig.savefig(f\"img/P_embedding_LP{rid}.pdf\", bbox_inches=\"tight\", transparent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f908227a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fisher_point_to_oblique(z, angle_deg=-30):\n",
    "    theta = np.deg2rad(angle_deg)\n",
    "    d = np.array([np.cos(theta), np.sin(theta)])\n",
    "    return z * d\n",
    "\n",
    "\n",
    "def fisher_to_oblique(z, angle_deg=-30):\n",
    "    \"\"\"\n",
    "    Map 1D Fisher coordinates to 2D points on an oblique line\n",
    "    \"\"\"\n",
    "    theta = np.deg2rad(angle_deg)\n",
    "    direction = np.array([np.cos(theta), np.sin(theta)])\n",
    "    return z[:, None] * direction[None, :]\n",
    "\n",
    "def add_orthogonal_jitter(Z, scale=0.01):\n",
    "    direction = Z.mean(axis=0)\n",
    "    direction /= np.linalg.norm(direction)\n",
    "    perp = np.array([-direction[1], direction[0]])\n",
    "    return Z + scale * np.random.randn(len(Z), 1) * perp\n",
    "\n",
    "\n",
    "def plot_fisher_with_optional_test(\n",
    "    Z_train_j, y_train,\n",
    "    z_test=None,\n",
    "    angle_deg=-30,\n",
    "    ax=None,\n",
    "    train_size=22,\n",
    "    test_size=400,\n",
    "):\n",
    "    if ax is None:\n",
    "        fig, ax = plt.subplots(figsize=(5, 4))\n",
    "    else:\n",
    "        fig = ax.figure\n",
    "\n",
    "    # ---- training points (jittered) ----\n",
    "    for cls in [0, 1]:\n",
    "        mask = (y_train == cls)\n",
    "        ax.scatter(\n",
    "            Z_train_j[mask, 0],\n",
    "            Z_train_j[mask, 1],\n",
    "            s=train_size,\n",
    "            alpha=0.6,\n",
    "            c=COLORS[cls],\n",
    "            marker='s' if cls == 0 else 'o',\n",
    "            zorder=2,\n",
    "        )\n",
    "\n",
    "    # ---- optional test point (NO jitter, on axis) ----\n",
    "    if z_test is not None:\n",
    "        Zt = fisher_point_to_oblique(z_test, angle_deg=angle_deg)\n",
    "        ax.scatter(\n",
    "            Zt[0],\n",
    "            Zt[1],\n",
    "            s=test_size,\n",
    "            c=\"black\",\n",
    "            marker=\"*\",\n",
    "            edgecolor=\"white\",\n",
    "            linewidth=1.2,\n",
    "            zorder=5,\n",
    "        )\n",
    "\n",
    "    ax.set_axis_off()\n",
    "    return fig, ax\n",
    "\n",
    "\n",
    "\n",
    "def draw_fisher_axis(ax, z, angle_deg=-30, lw=1.5):\n",
    "    \"\"\"\n",
    "    Draw the Fisher axis corresponding to 1D coordinates z\n",
    "    \"\"\"\n",
    "    theta = np.deg2rad(angle_deg)\n",
    "    direction = np.array([np.cos(theta), np.sin(theta)])\n",
    "\n",
    "    t_min, t_max = z.min(), z.max()\n",
    "    t_min-=.05\n",
    "    t_max+=.05\n",
    "    line = np.vstack([\n",
    "        t_min * direction,\n",
    "        t_max * direction\n",
    "    ])\n",
    "\n",
    "    ax.plot(\n",
    "        line[:, 0],\n",
    "        line[:, 1],\n",
    "        \"--\",\n",
    "        color=\"black\",\n",
    "        lw=lw,\n",
    "        alpha=0.6,\n",
    "        zorder=0,\n",
    "    )\n",
    "\n",
    "def draw_fisher_violin(\n",
    "    ax,\n",
    "    z,\n",
    "    angle_deg=-30,\n",
    "    width=0.08,\n",
    "    color=\"C0\",\n",
    "    alpha=0.35,\n",
    "    zorder=1,\n",
    "    dotop=True,\n",
    "    dobot=True\n",
    "):\n",
    "    \"\"\"\n",
    "    Draw a half-symmetric violin along the Fisher axis\n",
    "    \"\"\"\n",
    "    from scipy.stats import gaussian_kde\n",
    "    \n",
    "    theta = np.deg2rad(angle_deg)\n",
    "    d = np.array([np.cos(theta), np.sin(theta)])\n",
    "    d_perp = np.array([-d[1], d[0]])\n",
    "\n",
    "    # KDE in Fisher space\n",
    "    kde = gaussian_kde(z)\n",
    "\n",
    "    z_grid = np.linspace(z.min(), z.max(), 300)\n",
    "    density = kde(z_grid)\n",
    "    density = density / density.max() * width\n",
    "\n",
    "    centerline = z_grid[:, None] * d[None, :]\n",
    "\n",
    "    upper = centerline + density[:, None] * d_perp\n",
    "    lower = centerline - density[:, None] * d_perp\n",
    "\n",
    "    if dotop and dobot:\n",
    "        violin = np.vstack([upper, lower[::-1]])\n",
    "    elif dotop:\n",
    "        violin = np.vstack([(upper[0]+lower[0])/2, upper, (upper[-1]+lower[-1])/2])\n",
    "    elif dobot:\n",
    "        violin = np.vstack([(upper[0]+lower[0])/2, lower, (upper[-1]+lower[-1])/2])\n",
    "    else:\n",
    "        raise Exception()\n",
    "\n",
    "    ax.fill(\n",
    "        violin[:, 0],\n",
    "        violin[:, 1],\n",
    "        color=color,\n",
    "        alpha=alpha,\n",
    "        linewidth=0,\n",
    "        zorder=zorder,\n",
    "    )\n",
    "\n",
    "\n",
    "\n",
    "# ---- fit Fisher ----\n",
    "X_G, X_H = split_by_label(X_trn, y_trn)\n",
    "v = fisher_direction(X_G, X_H, lambda_reg=best_reg_lambda)\n",
    "\n",
    "z_tr = X_trn @ v\n",
    "z_te = X_tst @ v\n",
    "\n",
    "Z_tr = fisher_to_oblique(z_tr, angle_deg=-30)\n",
    "Z_te = fisher_to_oblique(z_te, angle_deg=-30)\n",
    "Z_tr_j = add_orthogonal_jitter(Z_tr, scale=0.01)\n",
    "\n",
    "# ---- training scatter + optional test ----\n",
    "fig, ax = plot_fisher_with_optional_test(\n",
    "    Z_tr_j, y_tr,\n",
    "    z_test=None,\n",
    "    angle_deg=-30,\n",
    ")\n",
    "\n",
    "# ---- Fisher axis ----\n",
    "draw_fisher_axis(ax, z_tr, angle_deg=-30, lw=1.5)\n",
    "\n",
    "# ---- optional violins (unchanged) ----\n",
    "# draw_fisher_violin(...)\n",
    "# draw_fisher_violin(...)\n",
    "\n",
    "ax.invert_yaxis()\n",
    "fig.savefig(\"img/P_fisher.pdf\", bbox_inches=\"tight\", transparent=True)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# ---- training scatter + optional test ----\n",
    "fig, ax = plot_fisher_with_optional_test(\n",
    "    Z_tr_j, y_tr,\n",
    "    z_test=z_te[rid],\n",
    "    angle_deg=-30,\n",
    ")\n",
    "\n",
    "# ---- Fisher axis ----\n",
    "draw_fisher_axis(ax, z_tr, angle_deg=-30, lw=1.5)\n",
    "\n",
    "# ---- optional violins (unchanged) ----\n",
    "# draw_fisher_violin(\n",
    "#     ax,\n",
    "#     z_tr[y_tr == 0],\n",
    "#     angle_deg=-30,\n",
    "#     width=0.07,\n",
    "#     color=COLORS[0],\n",
    "#     alpha=0.35,\n",
    "#     zorder=1,\n",
    "#     dobot=False\n",
    "# )\n",
    "\n",
    "# draw_fisher_violin(\n",
    "#     ax,\n",
    "#     z_tr[y_tr == 1],\n",
    "#     angle_deg=-30,\n",
    "#     width=0.07,\n",
    "#     color=COLORS[1],\n",
    "#     alpha=0.35,\n",
    "#     zorder=1,\n",
    "#     dotop=False\n",
    "# )\n",
    "\n",
    "ax.invert_yaxis()\n",
    "fig.savefig(f\"img/P_fisher_LP{rid}.pdf\", bbox_inches=\"tight\", transparent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df8cecb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class FisherDistanceExtractor:\n",
    "    def __init__(self, lambda_reg=1e-3, normalise=True, normalise_by_trace=True):\n",
    "        self.lambda_reg = lambda_reg\n",
    "        self.normalise = normalise\n",
    "        self.normalise_by_trace = normalise_by_trace\n",
    "\n",
    "    def fit(self, X_train, y_train):\n",
    "        X_G, X_H = split_by_label(X_train, y_train)\n",
    "\n",
    "        # store original embeddings\n",
    "        self.X_G = X_G\n",
    "        self.X_H = X_H\n",
    "\n",
    "        # Fisher direction\n",
    "        self.v = fisher_direction(\n",
    "            X_G, X_H,\n",
    "            lambda_reg=self.lambda_reg,\n",
    "            normalise=self.normalise,\n",
    "            normalise_by_trace=self.normalise_by_trace\n",
    "        )\n",
    "\n",
    "        # Fisher projections\n",
    "        self.Z_G = (X_G @ self.v)[:, None]\n",
    "        self.Z_H = (X_H @ self.v)[:, None]\n",
    "\n",
    "        # reference intra-class distances\n",
    "        self.ref_G_fisher = pdist(self.Z_G)\n",
    "        self.ref_H_fisher = pdist(self.Z_H)\n",
    "\n",
    "        self.ref_G_embed = pdist(self.X_G)\n",
    "        self.ref_H_embed = pdist(self.X_H)\n",
    "\n",
    "    def extract_test_distances(self, X_test, y_test):\n",
    "        rows = []\n",
    "\n",
    "        for test_id, (x, y_true) in enumerate(zip(X_test, y_test)):\n",
    "            z = (x @ self.v).reshape(1, 1)\n",
    "\n",
    "            # Fisher distances\n",
    "            dG_f = cdist(z, self.Z_G).ravel()\n",
    "            dH_f = cdist(z, self.Z_H).ravel()\n",
    "\n",
    "            # Embedding distances\n",
    "            dG_e = cdist(x[None, :], self.X_G).ravel()\n",
    "            dH_e = cdist(x[None, :], self.X_H).ravel()\n",
    "\n",
    "            for j in range(len(dG_f)):\n",
    "                rows.append({\n",
    "                    \"test_id\": test_id,\n",
    "                    \"y_true\": int(y_true),\n",
    "                    \"target_class\": \"G\",\n",
    "                    \"train_id\": j,\n",
    "                    \"distance_fisher\": dG_f[j],\n",
    "                    \"distance_embed\": dG_e[j],\n",
    "                })\n",
    "\n",
    "            for j in range(len(dH_f)):\n",
    "                rows.append({\n",
    "                    \"test_id\": test_id,\n",
    "                    \"y_true\": int(y_true),\n",
    "                    \"target_class\": \"H\",\n",
    "                    \"train_id\": j,\n",
    "                    \"distance_fisher\": dH_f[j],\n",
    "                    \"distance_embed\": dH_e[j],\n",
    "                })\n",
    "\n",
    "        # ---- reference geometry (optional but useful) ----\n",
    "        for j, d in enumerate(self.ref_G_fisher):\n",
    "            rows.append({\n",
    "                \"test_id\": None,\n",
    "                \"y_true\": None,\n",
    "                \"target_class\": \"G_ref\",\n",
    "                \"train_id\": j,\n",
    "                \"distance_fisher\": d,\n",
    "                \"distance_embed\": self.ref_G_embed[j],\n",
    "            })\n",
    "\n",
    "        for j, d in enumerate(self.ref_H_fisher):\n",
    "            rows.append({\n",
    "                \"test_id\": None,\n",
    "                \"y_true\": None,\n",
    "                \"target_class\": \"H_ref\",\n",
    "                \"train_id\": j,\n",
    "                \"distance_fisher\": d,\n",
    "                \"distance_embed\": self.ref_H_embed[j],\n",
    "            })\n",
    "\n",
    "        return pd.DataFrame(rows)\n",
    "    \n",
    "    def extract_wasserstein_scores(self, X_test, y_test):\n",
    "        rows = []\n",
    "\n",
    "        for test_id, (x, y_true) in enumerate(zip(X_test, y_test)):\n",
    "            z = (x @ self.v).reshape(1, 1)\n",
    "\n",
    "            # Fisher distances\n",
    "            D_G = cdist(z, self.Z_G).ravel()\n",
    "            D_H = cdist(z, self.Z_H).ravel()\n",
    "\n",
    "            # Wasserstein distances\n",
    "            W_G = (\n",
    "                wasserstein_distance(D_G, self.ref_G_fisher)\n",
    "                if len(self.ref_G_fisher) > 0 else np.inf\n",
    "            )\n",
    "            W_H = (\n",
    "                wasserstein_distance(D_H, self.ref_H_fisher)\n",
    "                if len(self.ref_H_fisher) > 0 else np.inf\n",
    "            )\n",
    "\n",
    "            y_pred = 0 if W_G <= W_H else 1\n",
    "            correct = int(y_pred == y_true)\n",
    "\n",
    "            rows.append({\n",
    "                \"test_id\": test_id,\n",
    "                \"y_true\": int(y_true),\n",
    "                \"y_pred\": int(y_pred),\n",
    "                \"correct\": correct,\n",
    "                \"W_G\": W_G,\n",
    "                \"W_H\": W_H,\n",
    "                \"margin\": W_H - W_G,  # positive → G\n",
    "            })\n",
    "\n",
    "        return pd.DataFrame(rows)\n",
    "    \n",
    "def generate_fisher_distance_table(\n",
    "    df,\n",
    "    model_id,\n",
    "    prompt_id,\n",
    "    lambda_reg=1,\n",
    "    random_state=42\n",
    "):\n",
    "    \"\"\"Fit a FisherDistanceExtractor for one (model, prompt) pair and tabulate its test output.\n",
    "\n",
    "    The arrays returned by ``extract_prompt_data`` are split with\n",
    "    ``train_test_split_80_20``; the extractor is fitted on the training part\n",
    "    and queried with the test part.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (dist_df, wass_df)\n",
    "        The results of ``extract_test_distances`` and\n",
    "        ``extract_wasserstein_scores``, each extended with ``model_id`` and\n",
    "        ``prompt_id`` columns.\n",
    "    \"\"\"\n",
    "    features, labels = extract_prompt_data(df, model_id, prompt_id)\n",
    "\n",
    "    # The split helper returns six values; the two index arrays are not needed here.\n",
    "    X_trn, X_tst, y_trn, y_tst, _, _ = train_test_split_80_20(\n",
    "        features, labels, random_state=random_state\n",
    "    )\n",
    "\n",
    "    fisher = FisherDistanceExtractor(lambda_reg=lambda_reg)\n",
    "    fisher.fit(X_trn, y_trn)\n",
    "\n",
    "    dist_df = fisher.extract_test_distances(X_tst, y_tst)\n",
    "    wass_df = fisher.extract_wasserstein_scores(X_tst, y_tst)\n",
    "\n",
    "    # Tag both tables so results of several (model, prompt) runs can be told apart.\n",
    "    tags = ((\"model_id\", model_id), (\"prompt_id\", prompt_id))\n",
    "    for table in (dist_df, wass_df):\n",
    "        for column, value in tags:\n",
    "            table[column] = value\n",
    "\n",
    "    return dist_df, wass_df\n",
    "\n",
    "def get_test_distance_block(dist_df, test_id, space=\"fisher\"):\n",
    "    \"\"\"\n",
    "    space ∈ {\"fisher\", \"embedding\"}\n",
    "    \"\"\"\n",
    "    col = \"distance_fisher\" if space == \"fisher\" else \"distance_embed\"\n",
    "\n",
    "    dG = dist_df[\n",
    "        (dist_df[\"test_id\"] == test_id) &\n",
    "        (dist_df[\"target_class\"] == \"G\")\n",
    "    ][col].values\n",
    "\n",
    "    dH = dist_df[\n",
    "        (dist_df[\"test_id\"] == test_id) &\n",
    "        (dist_df[\"target_class\"] == \"H\")\n",
    "    ][col].values\n",
    "\n",
    "    return {\n",
    "        \"D_G\": dG,\n",
    "        \"D_H\": dH,\n",
    "    }\n",
    "\n",
    "def plot_test_distance_violin(\n",
    "    dist_df,\n",
    "    test_id,\n",
    "    space=\"fisher\",\n",
    "    colors=None,\n",
    "    figsize=(5, 4),\n",
    "    savepath=None,\n",
    "    title=\"\",\n",
    "    transparent=False,\n",
    "    rotate=False\n",
    "):\n",
    "    \"\"\"\n",
    "    Draw a split horizontal violin of one test sample's distances to G and H.\n",
    "\n",
    "    The G distances are drawn on the \"high\" side (upper half) and the H\n",
    "    distances on the \"low\" side (lower half) of a single violin at position 0.\n",
    "    The axes are switched off; an arrow labelled ``$L_2$`` is annotated along\n",
    "    y = 0 and a short grey vertical line marks x = 0.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dist_df :\n",
    "        Distance table, forwarded to ``get_test_distance_block``.\n",
    "    test_id :\n",
    "        Test sample to plot.\n",
    "    space :\n",
    "        Forwarded to ``get_test_distance_block`` (\"fisher\" or \"embedding\").\n",
    "    colors : dict, optional\n",
    "        Face colours under the keys ``\"G\"`` and ``\"H\"``; a green/red pair is\n",
    "        used when omitted.\n",
    "    figsize :\n",
    "        Passed to ``plt.subplots``.\n",
    "    savepath : optional\n",
    "        When not None the figure is saved there. For string / path-like\n",
    "        targets the parent directory is created if it does not exist yet.\n",
    "    title :\n",
    "        Axes title; skipped when None.\n",
    "    transparent :\n",
    "        Forwarded to ``fig.savefig``.\n",
    "    rotate :\n",
    "        Rotate the ``$L_2$`` label by -90 degrees.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (fig, ax)\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the sample has no G or no H distances (e.g. unknown ``test_id``).\n",
    "    \"\"\"\n",
    "    if colors is None:\n",
    "        colors = {\n",
    "            \"G\": \"#6E9B34\",\n",
    "            \"H\": \"#AA4D39\",\n",
    "        }\n",
    "\n",
    "    d = get_test_distance_block(dist_df, test_id, space=space)\n",
    "\n",
    "    # Fail before a figure is created: an empty sample would otherwise error\n",
    "    # inside violinplot and leave an orphan figure behind.\n",
    "    for key in (\"D_G\", \"D_H\"):\n",
    "        if len(d[key]) == 0:\n",
    "            raise ValueError(\n",
    "                f\"no {key} values for test_id={test_id!r} (space={space!r})\"\n",
    "            )\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=figsize)\n",
    "\n",
    "    # upper half: G (\"high\" side), lower half: H (\"low\" side)\n",
    "    for cls, key, side in ((\"G\", \"D_G\", \"high\"), (\"H\", \"D_H\", \"low\")):\n",
    "        vp = ax.violinplot(\n",
    "            d[key],\n",
    "            positions=[0],\n",
    "            widths=0.5,\n",
    "            showmeans=False,\n",
    "            showmedians=False,\n",
    "            showextrema=False,\n",
    "            side=side,\n",
    "            orientation=\"horizontal\",\n",
    "        )\n",
    "        for pc in vp[\"bodies\"]:\n",
    "            pc.set_facecolor(colors[cls])\n",
    "            pc.set_edgecolor(\"black\")\n",
    "            pc.set_linewidth(1.0)\n",
    "            pc.set_alpha(0.5)\n",
    "\n",
    "    # ---- axis arrow ----\n",
    "    # Only the upper x-limit is needed to place the label beyond the violin.\n",
    "    _, xmax = ax.get_xlim()\n",
    "    ax.annotate(\n",
    "        r\"$L_2$\",\n",
    "        xy=(-0.001, 0),\n",
    "        xytext=(xmax * 1.1, 0),\n",
    "        arrowprops=dict(\n",
    "            arrowstyle=\"<-\",\n",
    "            color=\"0.6\",\n",
    "            linewidth=1.2,\n",
    "            shrinkA=0,\n",
    "            shrinkB=0,\n",
    "        ),\n",
    "        ha=\"left\",\n",
    "        va=\"center\",\n",
    "        color=\"0.4\",\n",
    "        rotation=-90 if rotate else 0,\n",
    "        fontsize=40,\n",
    "        zorder=0,\n",
    "    )\n",
    "\n",
    "    if title is not None:\n",
    "        ax.set_title(title)\n",
    "\n",
    "    # ---- zero reference ----\n",
    "    ax.axvline(0, 0.25, 0.75, color=\"grey\")\n",
    "\n",
    "    ax.set_axis_off()\n",
    "\n",
    "    if savepath is not None:\n",
    "        # savefig does not create missing directories; do it for path-like targets\n",
    "        # (file-like objects are passed through untouched).\n",
    "        if isinstance(savepath, (str, os.PathLike)):\n",
    "            out_dir = os.path.dirname(os.fspath(savepath))\n",
    "            if out_dir:\n",
    "                os.makedirs(out_dir, exist_ok=True)\n",
    "        fig.savefig(savepath, bbox_inches=\"tight\", pad_inches=0, transparent=transparent)\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# Build the distance / Wasserstein tables for the model (mid) and prompt (pid)\n",
    "# selected earlier in the notebook, using the regularisation held in best_reg_lambda.\n",
    "dist_df, wass_df = generate_fisher_distance_table(\n",
    "    df,\n",
    "    mid,\n",
    "    pid,\n",
    "    lambda_reg=best_reg_lambda,\n",
    "    random_state=42\n",
    ")\n",
    "\n",
    "# Map every test_id to whether its y_true equals 1, printing each pair on the way.\n",
    "# NOTE(review): the reference rows carry test_id=None / y_true=None; presumably\n",
    "# groupby's default NaN-key handling skips them -- confirm.\n",
    "tid2H = {}\n",
    "for k, v in dist_df.groupby(['test_id', 'y_true']):\n",
    "    print(int(k[0]), k[1]==1)\n",
    "    tid2H[int(k[0])] = k[1]==1\n",
    "\n",
    "\n",
    "# Fisher space: violin for test sample rid, saved as a transparent PDF under img/\n",
    "plot_test_distance_violin(\n",
    "    dist_df,\n",
    "    test_id=rid,\n",
    "    space=\"fisher\",\n",
    "    savepath=f\"img/P_violin_fisher_LP{rid}.pdf\",\n",
    "    transparent = True,\n",
    "    rotate=True\n",
    ")\n",
    "\n",
    "# Original embedding space: same sample and settings, read from the distance_embed column\n",
    "plot_test_distance_violin(\n",
    "    dist_df,\n",
    "    test_id=rid,\n",
    "    space=\"embedding\",\n",
    "    savepath=f\"img/P_violin_embedding_LP{rid}.pdf\",\n",
    "    transparent = True,\n",
    "    rotate=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17d0099d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# for rid in range(10, 20):\n",
    "#     # Fisher space\n",
    "#     plot_test_distance_violin(\n",
    "#         dist_df,\n",
    "#         test_id=rid,\n",
    "#         space=\"fisher\",\n",
    "#         savepath=None,\n",
    "#         title=str(rid)\n",
    "#     )\n",
    "\n",
    "#     # Original embedding space\n",
    "#     plot_test_distance_violin(\n",
    "#         dist_df,\n",
    "#         test_id=rid,\n",
    "#         space=\"embedding\",\n",
    "#         savepath=None,\n",
    "#         title=str(rid)\n",
    "#     )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d90c151",
   "metadata": {},
   "outputs": [],
   "source": [
    "# for tid in [x for x in range(20) if not tid2H[x]]:\n",
    "#     G_star = build_tsne_star_graph(\n",
    "#         Z_tr,\n",
    "#         y_tr,\n",
    "#         Z_te[tid],\n",
    "#     )\n",
    "\n",
    "#     fig, ax = plot_tsne_star_graph(G_star, title = str(tid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0919fd8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the per-test-sample Wasserstein score table returned by generate_fisher_distance_table.\n",
    "wass_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f295491b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look up entry rid of testIdxs.\n",
    "# NOTE(review): testIdxs is defined outside this part of the notebook; presumably it holds the\n",
    "# test-split indices in the same order as dist_df's test_id -- confirm it comes from the same split.\n",
    "testIdxs[rid]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49139417",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Response text of the inspected test sample for the selected prompt and model.\n",
    "df.loc[\n",
    "    (df['response_index'] == testIdxs[rid])\n",
    "    & (df['prompt_id'] == pid)\n",
    "    & (df['model_id'] == mid),\n",
    "    'response',\n",
    "].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57a8b5fe",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "848b739b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "843507a7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.14.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
