{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad59dd92",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from scipy.spatial.distance import pdist, cdist\n",
    "from scipy.stats import wasserstein_distance\n",
    "\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "\n",
    "import networkx as nx\n",
    "\n",
    "from sklearn.model_selection import StratifiedShuffleSplit\n",
    "\n",
    "import HALL_lib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bf87ecc",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['text.usetex'] = True\n",
    "plt.rcParams['text.latex.preamble'] = r'''\n",
    "\\usepackage{mathtools}\n",
    "% more packages here\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b600274",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============== NEWDATASET ==============\n",
    "\n",
    "# cache_dir = \"cache-newdataset\"\n",
    "# maxResponsesPerPrompt = 500\n",
    "\n",
    "# model_names = {\n",
    "#     0 : \"Llama-8B\",\n",
    "#     1 : \"Phi-14B\",\n",
    "#     2 : \"Mistral-7B\",\n",
    "#     3 : \"Solar-11B\",\n",
    "#     4 : \"Gemma-9B\",\n",
    "#     5 : \"Qwen-15B\",\n",
    "#     6 : \"DeepSeek-7B\",\n",
    "#     7 : \"Yi-9B\"\n",
    "# }\n",
    "\n",
    "# ============== OLDDATASET ==============\n",
    "\n",
    "# cache_dir = \"cache-olddataset-20-22\"\n",
    "# maxResponsesPerPrompt=100\n",
    "\n",
    "# model_names = {\n",
    "#     0: 'Mistral-7B',\n",
    "#     1: 'Gemma-9B',\n",
    "#     2: 'Solar-11B',\n",
    "#     3: 'Phi-14B',\n",
    "#     4: 'Qwen-32B',\n",
    "#     5: 'Gemma-27B',\n",
    "#     6: 'Qwen-15B'\n",
    "# }\n",
    "\n",
    "# best_reg_lambda = 1.9\n",
    "\n",
    "# ============== OLDDATASET ==============\n",
    "\n",
    "cache_dir = \"cache-icml\"\n",
    "maxResponsesPerPrompt=150\n",
    "\n",
    "model_names = {\n",
    "    0: 'Mistral-7B',\n",
    "    1: 'Gemma-9B',\n",
    "    2: 'Solar-11B',\n",
    "    3: 'Phi-14B',\n",
    "    4: 'Qwen-15B',\n",
    "    5: 'Gemma-27B',\n",
    "    6: 'Qwen-32B',\n",
    "    7: \"DeepSeek-7B\",\n",
    "    8: \"Llama-8B\",\n",
    "    9: \"Yi-9B\"\n",
    "}\n",
    "\n",
    "best_reg_lambda = 1.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a810ae8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset_fn = \"../llm-hallucinations/olddataset.parquet\"\n",
    "dataset_fn = \"../llm-hallucinations/dataset_icml.parquet\"\n",
    "\n",
    "df = HALL_lib.loadParquet(dataset_fn, unifyYears=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7f99912",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_order, model_rank = HALL_lib.build_model_size_order(model_names)\n",
    "\n",
    "false_premise = {pid : fp for pid, fp in df[['prompt_id', \"false_premise\"]].value_counts().keys()} if \"false_premise\" in df.columns else None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b67494f",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_df, geometry_store, null_store = HALL_lib.run_structural_analysis(\n",
    "    df,\n",
    "    lambda_reg=best_reg_lambda,\n",
    "    n_permutations=100,\n",
    "    random_state=42,\n",
    "    min_per_class_plot=5,\n",
    "    use_cache=True,\n",
    "    cache_dir=cache_dir+'/S-data',\n",
    "    overwrite_cache=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dbc5db8e",
   "metadata": {},
   "source": [
    "## Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77eb9e02",
   "metadata": {},
   "outputs": [],
   "source": [
    "mid = 2  # model id\n",
    "pid = 82 # prompt ic\n",
    "rid = 18 # response id (it was 11 on the old dataset)\n",
    "\n",
    "def train_test_split_80_20(X, y, random_state=42):\n",
    "    sss = StratifiedShuffleSplit(\n",
    "        n_splits=1, test_size=0.2, random_state=random_state\n",
    "    )\n",
    "    trn_idx, tst_idx = next(sss.split(X, y))\n",
    "    return X[trn_idx], X[tst_idx], y[trn_idx], y[tst_idx], trn_idx, tst_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "830d9696",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_distance_block(geometry_store, key, space):\n",
    "    \"\"\"\n",
    "    Returns a dict with keys D_GG, D_HH, D_GH\n",
    "    \"\"\"\n",
    "    d = geometry_store[key]\n",
    "\n",
    "    key = {'embedding': '', 'fisher': '_z'}\n",
    "\n",
    "    if space in d:  # nested\n",
    "        return d[space]\n",
    "\n",
    "    # flattened fallback\n",
    "    return {\n",
    "        \"D_GG\": d[f\"D_GG{key[space]}\"],\n",
    "        \"D_HH\": d[f\"D_HH{key[space]}\"],\n",
    "        \"D_GH\": d[f\"D_GH{key[space]}\"],\n",
    "    }\n",
    "\n",
    "def plot_distance_violin(\n",
    "    geometry_store,\n",
    "    key,\n",
    "    space=\"embedding\",\n",
    "    colors=None,\n",
    "    figsize=(5, 4),\n",
    "    savepath=None,\n",
    "    transparent=False,\n",
    "    rotate=False\n",
    "):\n",
    "    \"\"\"\n",
    "    Half-violin plot for GG / HH and boxplot for GH\n",
    "    in the selected space ('embedding' or 'fisher')\n",
    "    \"\"\"\n",
    "    if colors is None:\n",
    "        colors = {\n",
    "            \"GG\": \"#6E9B34\",\n",
    "            \"HH\": \"#AA4D39\",\n",
    "            \"GH\": \"#27586B\",\n",
    "        }\n",
    "\n",
    "    d = get_distance_block(geometry_store, key, space)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=figsize)\n",
    "\n",
    "    # ---- GG: upper half violin ----\n",
    "    vp = ax.violinplot(\n",
    "        d[\"D_GG\"],\n",
    "        positions=[0],\n",
    "        widths=0.5,\n",
    "        showmeans=False,\n",
    "        showmedians=False,\n",
    "        showextrema=False,\n",
    "        side=\"high\",\n",
    "        orientation=\"horizontal\",\n",
    "    )\n",
    "    for pc in vp[\"bodies\"]:\n",
    "        pc.set_facecolor(colors[\"GG\"])\n",
    "        pc.set_edgecolor(\"black\")\n",
    "        pc.set_linewidth(1.0)\n",
    "        pc.set_alpha(0.5)\n",
    "\n",
    "    # ---- HH: lower half violin ----\n",
    "    vp = ax.violinplot(\n",
    "        d[\"D_HH\"],\n",
    "        positions=[0],\n",
    "        widths=0.5,\n",
    "        showmeans=False,\n",
    "        showmedians=False,\n",
    "        showextrema=False,\n",
    "        side=\"low\",\n",
    "        orientation=\"horizontal\",\n",
    "    )\n",
    "    for pc in vp[\"bodies\"]:\n",
    "        pc.set_facecolor(colors[\"HH\"])\n",
    "        pc.set_edgecolor(\"black\")\n",
    "        pc.set_linewidth(1.0)\n",
    "        pc.set_alpha(0.5)\n",
    "\n",
    "    # ---- GH: boxplot ----\n",
    "    bp = ax.boxplot(\n",
    "        d[\"D_GH\"],\n",
    "        positions=[0],\n",
    "        widths=0.1,\n",
    "        patch_artist=True,\n",
    "        showfliers=False,\n",
    "        orientation=\"horizontal\",\n",
    "    )\n",
    "    for box in bp[\"boxes\"]:\n",
    "        box.set_facecolor(colors[\"GH\"])\n",
    "        box.set_edgecolor(\"black\")\n",
    "        box.set_alpha(0.9)\n",
    "\n",
    "    for elem in [\"whiskers\", \"caps\", \"medians\"]:\n",
    "        for artist in bp[elem]:\n",
    "            artist.set_color(\"black\")\n",
    "\n",
    "    xmin, xmax = ax.get_xlim()\n",
    "    y0 = 0\n",
    "\n",
    "    ax.annotate(\n",
    "        r\"$L_2$\",\n",
    "        xy=(-.001, y0),\n",
    "        xytext=(xmax*1.1, y0),\n",
    "        # textcoords=\"offset points\",\n",
    "        arrowprops=dict(\n",
    "            arrowstyle=\"<-\",\n",
    "            color=\"0.6\",\n",
    "            linewidth=1.2,\n",
    "            shrinkA=0,\n",
    "            shrinkB=0,\n",
    "        ),\n",
    "        ha=\"left\",\n",
    "        va=\"center\",\n",
    "        color=\"0.4\",\n",
    "        rotation=-90 if rotate else 0,\n",
    "        fontsize=40,\n",
    "        zorder=0,\n",
    "    )\n",
    "\n",
    "    ax.axvline(0, 0.25, 0.75, color='grey')\n",
    "\n",
    "    ax.set_axis_off()\n",
    "\n",
    "    if savepath is not None:\n",
    "        fig.savefig(savepath, bbox_inches=\"tight\", pad_inches=0, transparent=transparent)\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "\n",
    "plot_distance_violin(\n",
    "    geometry_store,\n",
    "    key=(mid, pid),\n",
    "    space=\"embedding\",\n",
    "    savepath=\"img/P_violin_embedding.pdf\",\n",
    "    transparent=True,\n",
    "    rotate=True\n",
    ")\n",
    "plot_distance_violin(\n",
    "    geometry_store,\n",
    "    key=(mid, pid),\n",
    "    space=\"fisher\",\n",
    "    savepath=\"img/P_violin_fisher.pdf\",\n",
    "    transparent=True,\n",
    "    rotate=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f544cd8d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6528cf48",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0e9dd94",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tsne_projection(\n",
    "    X_train, y_train,\n",
    "    X_test, y_test,\n",
    "    perplexity=30,\n",
    "    random_state=42\n",
    "):\n",
    "    X_all = np.vstack([X_train, X_test])\n",
    "    y_all = np.concatenate([y_train, y_test])\n",
    "\n",
    "    tsne = TSNE(\n",
    "        n_components=2,\n",
    "        perplexity=perplexity,\n",
    "        init=\"pca\",\n",
    "        learning_rate=\"auto\",\n",
    "        random_state=random_state,\n",
    "    )\n",
    "\n",
    "    Z_all = tsne.fit_transform(X_all)\n",
    "\n",
    "    n_tr = len(X_train)\n",
    "    Z_tr, Z_te = Z_all[:n_tr], Z_all[n_tr:]\n",
    "\n",
    "    return Z_tr, Z_te, y_train, y_test\n",
    "\n",
    "# COLORS = {\n",
    "#     0: \"#6E9B34\",  # G\n",
    "#     1: \"#AA4D39\",  # H\n",
    "# }\n",
    "\n",
    "# def plot_tsne_with_test_point(\n",
    "#     Z_train, y_train,\n",
    "#     Z_test=None, y_test=None,\n",
    "#     test_id=None,\n",
    "#     ax=None\n",
    "# ):\n",
    "#     if ax is None:\n",
    "#         fig, ax = plt.subplots(figsize=(5, 4))\n",
    "\n",
    "#     # ---- training clouds ----\n",
    "#     for cls in [0, 1]:\n",
    "#         mask = (y_train == cls)\n",
    "#         ax.scatter(\n",
    "#             Z_train[mask, 0],\n",
    "#             Z_train[mask, 1],\n",
    "#             s=22,\n",
    "#             alpha=0.5,\n",
    "#             c=COLORS[cls],\n",
    "#             marker='s' if cls == 0 else 'o',\n",
    "#             label=f\"Train {'G' if cls == 0 else 'H'}\",\n",
    "#         )\n",
    "\n",
    "#     # ---- selected test point ----\n",
    "#     if Z_test is not None:\n",
    "#         ax.scatter(\n",
    "#             Z_test[test_id, 0],\n",
    "#             Z_test[test_id, 1],\n",
    "#             s=90,\n",
    "#             c=\"black\",\n",
    "#             edgecolor=\"white\",\n",
    "#             linewidth=1.2,\n",
    "#             zorder=5,\n",
    "#             label=f\"Test {test_id}\",\n",
    "#         )\n",
    "\n",
    "#     ax.set_axis_off()\n",
    "#     # ax.legend(frameon=False, loc=\"best\")\n",
    "#     return fig, ax\n",
    "\n",
    "# X, y = extract_prompt_data(df, 2, 82)\n",
    "\n",
    "# X_trn, X_tst, y_trn, y_tst = train_test_split_80_20(\n",
    "#     X, y, random_state=42\n",
    "# )\n",
    "\n",
    "# Z_tr, Z_te, y_tr, y_te = tsne_projection(\n",
    "#     X_trn, y_trn, X_tst, y_tst,\n",
    "#     perplexity=30\n",
    "# )\n",
    "\n",
    "# for tid in range(10, 20):\n",
    "#     print(tid)\n",
    "#     plot_tsne_with_test_point(\n",
    "#         Z_tr, y_tr,\n",
    "#         Z_te, y_te,\n",
    "#         test_id=tid\n",
    "#     )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb1e6513",
   "metadata": {},
   "outputs": [],
   "source": [
    "EDGE_COLORS = {\n",
    "    \"GG\": \"#6E9B34\",  # green\n",
    "    \"HH\": \"#AA4D39\",  # red\n",
    "    \"GH\": \"#27586B\",  # blue\n",
    "}\n",
    "\n",
    "import networkx as nx\n",
    "import itertools\n",
    "\n",
    "def build_tsne_complete_graph(Z, y):\n",
    "    \"\"\"\n",
    "    Z : (n,2) t-SNE coordinates\n",
    "    y : (n,) class labels {0,1}\n",
    "    \"\"\"\n",
    "    G = nx.Graph()\n",
    "\n",
    "    # ---- nodes ----\n",
    "    for i, (pos, cls) in enumerate(zip(Z, y)):\n",
    "        G.add_node(\n",
    "            i,\n",
    "            pos=tuple(pos),\n",
    "            cls=int(cls),\n",
    "        )\n",
    "\n",
    "    # ---- complete edges ----\n",
    "    for i, j in itertools.combinations(range(len(Z)), 2):\n",
    "        ci, cj = y[i], y[j]\n",
    "\n",
    "        if ci == 0 and cj == 0:\n",
    "            etype = \"GG\"\n",
    "        elif ci == 1 and cj == 1:\n",
    "            etype = \"HH\"\n",
    "        else:\n",
    "            etype = \"GH\"\n",
    "\n",
    "        G.add_edge(i, j, etype=etype)\n",
    "\n",
    "    return G\n",
    "\n",
    "def plot_tsne_complete_graph(\n",
    "    G,\n",
    "    ax=None,\n",
    "    node_size=30,\n",
    "    edge_alphas={\n",
    "        \"GG\": 0.1,\n",
    "        \"HH\": 0.08,\n",
    "        \"GH\": 0.04,\n",
    "    },\n",
    "    edge_lw=0.6,\n",
    "    colors=None,\n",
    "):\n",
    "    if ax is None:\n",
    "        fig, ax = plt.subplots(figsize=(5, 4))\n",
    "    else:\n",
    "        fig = ax.figure\n",
    "\n",
    "    if colors is None:\n",
    "        colors = {\n",
    "            0: \"#6E9B34\",\n",
    "            1: \"#AA4D39\",\n",
    "        }\n",
    "\n",
    "    pos = nx.get_node_attributes(G, \"pos\")\n",
    "    cls = nx.get_node_attributes(G, \"cls\")\n",
    "\n",
    "    # ---- edges by type ----\n",
    "    for etype, color in EDGE_COLORS.items():\n",
    "        edges = [\n",
    "            (u, v)\n",
    "            for u, v, d in G.edges(data=True)\n",
    "            if d[\"etype\"] == etype\n",
    "        ]\n",
    "\n",
    "        nx.draw_networkx_edges(\n",
    "            G,\n",
    "            pos,\n",
    "            edgelist=edges,\n",
    "            edge_color=color,\n",
    "            alpha=edge_alphas[etype],\n",
    "            width=edge_lw,\n",
    "            ax=ax,\n",
    "        )\n",
    "\n",
    "    # ---- nodes ----\n",
    "    for c in [0, 1]:\n",
    "        nodes = [n for n in G.nodes if cls[n] == c]\n",
    "        nx.draw_networkx_nodes(\n",
    "            G,\n",
    "            pos,\n",
    "            nodelist=nodes,\n",
    "            node_color=colors[c],\n",
    "            node_size=node_size,\n",
    "            node_shape='s' if c == 0 else 'o',\n",
    "            alpha=0.8,\n",
    "            ax=ax,\n",
    "        )\n",
    "\n",
    "    ax.set_axis_off()\n",
    "    return fig, ax\n",
    "\n",
    "\n",
    "X, y = HALL_lib.extract_prompt_data(df, 2, 82)\n",
    "\n",
    "X_trn, X_tst, y_trn, y_tst, trainIdxs, testIdxs = train_test_split_80_20(\n",
    "    X, y, random_state=42\n",
    ")\n",
    "\n",
    "# ---- t-SNE ----\n",
    "Z_tr, Z_te, y_tr, y_te = tsne_projection(\n",
    "    X_trn, y_trn,\n",
    "    X_tst, y_tst,\n",
    "    perplexity=30,\n",
    "    random_state=42,\n",
    ")\n",
    "\n",
    "# ---- graph ----\n",
    "G = build_tsne_complete_graph(Z_tr, y_tr)\n",
    "\n",
    "# ---- plot ----\n",
    "fig, ax = plot_tsne_complete_graph(G, )\n",
    "fig.savefig('img/P_embedding.pdf', bbox_inches='tight', transparent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c9eeb7a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efc59ef1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # 1D intra-class distances\n",
    "# d_GG = pdist(z_G[:, None], metric=\"euclidean\").ravel()\n",
    "# d_HH = pdist(z_H[:, None], metric=\"euclidean\").ravel()\n",
    "\n",
    "# def draw_fisher_distance_violin(\n",
    "#     ax,\n",
    "#     z_center,\n",
    "#     distances,\n",
    "#     angle_deg=-30,\n",
    "#     width=0.08,\n",
    "#     color=\"C0\",\n",
    "#     alpha=0.35,\n",
    "#     zorder=1,\n",
    "#     side=\"both\",  # \"top\" or \"bottom\"\n",
    "# ):\n",
    "#     theta = np.deg2rad(angle_deg)\n",
    "#     d = np.array([np.cos(theta), np.sin(theta)])\n",
    "#     d_perp = np.array([-d[1], d[0]])\n",
    "\n",
    "#     # KDE over distances\n",
    "#     kde = gaussian_kde(distances)\n",
    "\n",
    "#     r = np.linspace(0, distances.max(), 300)\n",
    "#     density = kde(r)\n",
    "#     density = density / density.max() * width\n",
    "\n",
    "#     # anchor at class centre\n",
    "#     center = z_center * d\n",
    "\n",
    "#     curve = center + r[:, None] * d\n",
    "#     offset = density[:, None] * d_perp\n",
    "\n",
    "#     if side == \"top\":\n",
    "#         poly = np.vstack([curve, curve[::-1] + offset[::-1]])\n",
    "#     elif side == \"bottom\":\n",
    "#         poly = np.vstack([curve, curve[::-1] - offset[::-1]])\n",
    "#     else:\n",
    "#         poly = np.vstack([\n",
    "#             curve + offset,\n",
    "#             (curve - offset)[::-1]\n",
    "#         ])\n",
    "\n",
    "#     ax.fill(\n",
    "#         poly[:, 0],\n",
    "#         poly[:, 1],\n",
    "#         color=color,\n",
    "#         alpha=alpha,\n",
    "#         linewidth=0,\n",
    "#         zorder=zorder,\n",
    "#     )\n",
    "\n",
    "# zG_center = z_G.mean()\n",
    "# zH_center = z_H.mean()\n",
    "\n",
    "# fig, ax = plot_tsne_with_test_point(\n",
    "#     Z_tr_j, y_tr,\n",
    "#     Z_test=None, y_test=None,\n",
    "#     test_id=7\n",
    "# )\n",
    "\n",
    "# draw_fisher_distance_violin(\n",
    "#     ax,\n",
    "#     z_center=zG_center,\n",
    "#     distances=d_GG,\n",
    "#     angle_deg=-30,\n",
    "#     width=0.06,\n",
    "#     color=COLORS[0],\n",
    "#     side=\"bottom\"\n",
    "# )\n",
    "\n",
    "# draw_fisher_distance_violin(\n",
    "#     ax,\n",
    "#     z_center=zH_center,\n",
    "#     distances=d_HH,\n",
    "#     angle_deg=-30,\n",
    "#     width=0.06,\n",
    "#     color=COLORS[1],\n",
    "#     side=\"top\"\n",
    "# )\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56756718",
   "metadata": {},
   "source": [
    "### Test point, ie LP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d3b5c36",
   "metadata": {},
   "outputs": [],
   "source": [
    "COLORS = {\n",
    "    0: \"#6E9B34\",  # G → green\n",
    "    1: \"#AA4D39\",  # H → orange (distinct from HH red)\n",
    "}\n",
    "\n",
    "def build_tsne_star_graph(\n",
    "    Z_train,\n",
    "    y_train,\n",
    "    z_test,\n",
    "):\n",
    "    \"\"\"\n",
    "    Star graph centred at test point\n",
    "\n",
    "    Z_train : (n,2)\n",
    "    y_train : (n,)\n",
    "    z_test  : (2,)\n",
    "    \"\"\"\n",
    "    G = nx.Graph()\n",
    "\n",
    "    test_id = \"test\"\n",
    "\n",
    "    # ---- test node ----\n",
    "    G.add_node(\n",
    "        test_id,\n",
    "        pos=tuple(z_test),\n",
    "        cls=\"test\",\n",
    "    )\n",
    "\n",
    "    # ---- training nodes + star edges ----\n",
    "    for i, (pos, cls) in enumerate(zip(Z_train, y_train)):\n",
    "        G.add_node(\n",
    "            i,\n",
    "            pos=tuple(pos),\n",
    "            cls=int(cls),\n",
    "        )\n",
    "\n",
    "        G.add_edge(\n",
    "            test_id,\n",
    "            i,\n",
    "            etype=int(cls),   # 0=G, 1=H\n",
    "        )\n",
    "\n",
    "    return G\n",
    "\n",
    "def plot_tsne_star_graph(\n",
    "    G,\n",
    "    ax=None,\n",
    "    node_size=30,\n",
    "    test_size=400,\n",
    "    edge_alpha=0.5,\n",
    "    edge_lw=0.9,\n",
    "    title=None\n",
    "):\n",
    "    if ax is None:\n",
    "        fig, ax = plt.subplots(figsize=(5, 4))\n",
    "    else:\n",
    "        fig = ax.figure\n",
    "\n",
    "    pos = nx.get_node_attributes(G, \"pos\")\n",
    "    cls = nx.get_node_attributes(G, \"cls\")\n",
    "\n",
    "    # ---- edges (test → train) ----\n",
    "    for cls_id, color in COLORS.items():\n",
    "        edges = [\n",
    "            (u, v)\n",
    "            for u, v, d in G.edges(data=True)\n",
    "            if d[\"etype\"] == cls_id\n",
    "        ]\n",
    "\n",
    "        nx.draw_networkx_edges(\n",
    "            G,\n",
    "            pos,\n",
    "            edgelist=edges,\n",
    "            edge_color=color,\n",
    "            alpha=edge_alpha,\n",
    "            width=edge_lw,\n",
    "            ax=ax,\n",
    "        )\n",
    "\n",
    "    # ---- training nodes ----\n",
    "    for c in [0, 1]:\n",
    "        nodes = [n for n in G.nodes if cls.get(n) == c]\n",
    "        nx.draw_networkx_nodes(\n",
    "            G,\n",
    "            pos,\n",
    "            nodelist=nodes,\n",
    "            node_color=COLORS[c],\n",
    "            node_size=node_size,\n",
    "            node_shape='s' if c == 0 else 'o',\n",
    "            alpha=0.8,\n",
    "            ax=ax,\n",
    "        )\n",
    "\n",
    "    # ---- test node (black star) ----\n",
    "    nx.draw_networkx_nodes(\n",
    "        G,\n",
    "        pos,\n",
    "        nodelist=[\"test\"],\n",
    "        node_color=\"black\",\n",
    "        node_size=test_size,\n",
    "        node_shape=\"*\",\n",
    "        linewidths=1.2,\n",
    "        edgecolors=\"white\",\n",
    "        ax=ax,\n",
    "    )\n",
    "    \n",
    "    if title is not None:\n",
    "        ax.set_title(title)\n",
    "\n",
    "    ax.set_axis_off()\n",
    "    return fig, ax\n",
    "\n",
    "X, y = HALL_lib.extract_prompt_data(df, 2, 82)\n",
    "\n",
    "X_trn, X_tst, y_trn, y_tst, _, _ = train_test_split_80_20(\n",
    "    X, y, random_state=42\n",
    ")\n",
    "\n",
    "Z_tr, Z_te, y_tr, y_te = tsne_projection(\n",
    "    X_trn, y_trn,\n",
    "    X_tst, y_tst,\n",
    "    perplexity=30,\n",
    "    random_state=42,\n",
    ")\n",
    "G_star = build_tsne_star_graph(\n",
    "    Z_tr,\n",
    "    y_tr,\n",
    "    Z_te[rid],\n",
    ")\n",
    "\n",
    "fig, ax = plot_tsne_star_graph(G_star)\n",
    "fig.savefig(f\"img/P_embedding_LP{rid}.pdf\", bbox_inches=\"tight\", transparent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f908227a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fisher_point_to_oblique(z, angle_deg=-30):\n",
    "    theta = np.deg2rad(angle_deg)\n",
    "    d = np.array([np.cos(theta), np.sin(theta)])\n",
    "    return z * d\n",
    "\n",
    "\n",
    "def fisher_to_oblique(z, angle_deg=-30):\n",
    "    \"\"\"\n",
    "    Map 1D Fisher coordinates to 2D points on an oblique line\n",
    "    \"\"\"\n",
    "    theta = np.deg2rad(angle_deg)\n",
    "    direction = np.array([np.cos(theta), np.sin(theta)])\n",
    "    return z[:, None] * direction[None, :]\n",
    "\n",
    "def add_orthogonal_jitter(Z, scale=0.01):\n",
    "    direction = Z.mean(axis=0)\n",
    "    direction /= np.linalg.norm(direction)\n",
    "    perp = np.array([-direction[1], direction[0]])\n",
    "    return Z + scale * np.random.randn(len(Z), 1) * perp\n",
    "\n",
    "\n",
    "def plot_fisher_with_optional_test(\n",
    "    Z_train_j, y_train,\n",
    "    z_test=None,\n",
    "    angle_deg=-30,\n",
    "    ax=None,\n",
    "    train_size=22,\n",
    "    test_size=400,\n",
    "):\n",
    "    if ax is None:\n",
    "        fig, ax = plt.subplots(figsize=(5, 4))\n",
    "    else:\n",
    "        fig = ax.figure\n",
    "\n",
    "    # ---- training points (jittered) ----\n",
    "    for cls in [0, 1]:\n",
    "        mask = (y_train == cls)\n",
    "        ax.scatter(\n",
    "            Z_train_j[mask, 0],\n",
    "            Z_train_j[mask, 1],\n",
    "            s=train_size,\n",
    "            alpha=0.6,\n",
    "            c=COLORS[cls],\n",
    "            marker='s' if cls == 0 else 'o',\n",
    "            zorder=2,\n",
    "        )\n",
    "\n",
    "    # ---- optional test point (NO jitter, on axis) ----\n",
    "    if z_test is not None:\n",
    "        Zt = fisher_point_to_oblique(z_test, angle_deg=angle_deg)\n",
    "        ax.scatter(\n",
    "            Zt[0],\n",
    "            Zt[1],\n",
    "            s=test_size,\n",
    "            c=\"black\",\n",
    "            marker=\"*\",\n",
    "            edgecolor=\"white\",\n",
    "            linewidth=1.2,\n",
    "            zorder=5,\n",
    "        )\n",
    "\n",
    "    ax.set_axis_off()\n",
    "    return fig, ax\n",
    "\n",
    "\n",
    "\n",
    "def draw_fisher_axis(ax, z, angle_deg=-30, lw=1.5):\n",
    "    \"\"\"\n",
    "    Draw the Fisher axis corresponding to 1D coordinates z\n",
    "    \"\"\"\n",
    "    theta = np.deg2rad(angle_deg)\n",
    "    direction = np.array([np.cos(theta), np.sin(theta)])\n",
    "\n",
    "    t_min, t_max = z.min(), z.max()\n",
    "    t_min-=.05\n",
    "    t_max+=.05\n",
    "    line = np.vstack([\n",
    "        t_min * direction,\n",
    "        t_max * direction\n",
    "    ])\n",
    "\n",
    "    ax.plot(\n",
    "        line[:, 0],\n",
    "        line[:, 1],\n",
    "        \"--\",\n",
    "        color=\"black\",\n",
    "        lw=lw,\n",
    "        alpha=0.6,\n",
    "        zorder=0,\n",
    "    )\n",
    "\n",
    "def draw_fisher_violin(\n",
    "    ax,\n",
    "    z,\n",
    "    angle_deg=-30,\n",
    "    width=0.08,\n",
    "    color=\"C0\",\n",
    "    alpha=0.35,\n",
    "    zorder=1,\n",
    "    dotop=True,\n",
    "    dobot=True\n",
    "):\n",
    "    \"\"\"\n",
    "    Draw a half-symmetric violin along the Fisher axis\n",
    "    \"\"\"\n",
    "    from scipy.stats import gaussian_kde\n",
    "    \n",
    "    theta = np.deg2rad(angle_deg)\n",
    "    d = np.array([np.cos(theta), np.sin(theta)])\n",
    "    d_perp = np.array([-d[1], d[0]])\n",
    "\n",
    "    # KDE in Fisher space\n",
    "    kde = gaussian_kde(z)\n",
    "\n",
    "    z_grid = np.linspace(z.min(), z.max(), 300)\n",
    "    density = kde(z_grid)\n",
    "    density = density / density.max() * width\n",
    "\n",
    "    centerline = z_grid[:, None] * d[None, :]\n",
    "\n",
    "    upper = centerline + density[:, None] * d_perp\n",
    "    lower = centerline - density[:, None] * d_perp\n",
    "\n",
    "    if dotop and dobot:\n",
    "        violin = np.vstack([upper, lower[::-1]])\n",
    "    elif dotop:\n",
    "        violin = np.vstack([(upper[0]+lower[0])/2, upper, (upper[-1]+lower[-1])/2])\n",
    "    elif dobot:\n",
    "        violin = np.vstack([(upper[0]+lower[0])/2, lower, (upper[-1]+lower[-1])/2])\n",
    "    else:\n",
    "        raise Exception()\n",
    "\n",
    "    ax.fill(\n",
    "        violin[:, 0],\n",
    "        violin[:, 1],\n",
    "        color=color,\n",
    "        alpha=alpha,\n",
    "        linewidth=0,\n",
    "        zorder=zorder,\n",
    "    )\n",
    "\n",
    "\n",
    "\n",
    "# ---- fit Fisher ----\n",
    "X_G, X_H = HALL_lib.split_by_label(X_trn, y_trn)\n",
    "v = HALL_lib.fisher_direction(X_G, X_H, lambda_reg=best_reg_lambda)\n",
    "\n",
    "z_tr = X_trn @ v\n",
    "z_te = X_tst @ v\n",
    "\n",
    "Z_tr = fisher_to_oblique(z_tr, angle_deg=-30)\n",
    "Z_te = fisher_to_oblique(z_te, angle_deg=-30)\n",
    "Z_tr_j = add_orthogonal_jitter(Z_tr, scale=0.01)\n",
    "\n",
    "# ---- training scatter + optional test ----\n",
    "fig, ax = plot_fisher_with_optional_test(\n",
    "    Z_tr_j, y_tr,\n",
    "    z_test=None,\n",
    "    angle_deg=-30,\n",
    ")\n",
    "\n",
    "# ---- Fisher axis ----\n",
    "draw_fisher_axis(ax, z_tr, angle_deg=-30, lw=1.5)\n",
    "\n",
    "# ---- optional violins (unchanged) ----\n",
    "# draw_fisher_violin(...)\n",
    "# draw_fisher_violin(...)\n",
    "\n",
    "ax.invert_yaxis()\n",
    "fig.savefig(\"img/P_fisher.pdf\", bbox_inches=\"tight\", transparent=True)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# ---- training scatter + optional test ----\n",
    "fig, ax = plot_fisher_with_optional_test(\n",
    "    Z_tr_j, y_tr,\n",
    "    z_test=z_te[rid],\n",
    "    angle_deg=-30,\n",
    ")\n",
    "\n",
    "# ---- Fisher axis ----\n",
    "draw_fisher_axis(ax, z_tr, angle_deg=-30, lw=1.5)\n",
    "\n",
    "# ---- optional violins (unchanged) ----\n",
    "# draw_fisher_violin(\n",
    "#     ax,\n",
    "#     z_tr[y_tr == 0],\n",
    "#     angle_deg=-30,\n",
    "#     width=0.07,\n",
    "#     color=COLORS[0],\n",
    "#     alpha=0.35,\n",
    "#     zorder=1,\n",
    "#     dobot=False\n",
    "# )\n",
    "\n",
    "# draw_fisher_violin(\n",
    "#     ax,\n",
    "#     z_tr[y_tr == 1],\n",
    "#     angle_deg=-30,\n",
    "#     width=0.07,\n",
    "#     color=COLORS[1],\n",
    "#     alpha=0.35,\n",
    "#     zorder=1,\n",
    "#     dotop=False\n",
    "# )\n",
    "\n",
    "ax.invert_yaxis()\n",
    "fig.savefig(f\"img/P_fisher_LP{rid}.pdf\", bbox_inches=\"tight\", transparent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df8cecb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class FisherDistanceExtractor:\n",
    "    def __init__(self, lambda_reg=1e-3, normalise=True, normalise_by_trace=True):\n",
    "        self.lambda_reg = lambda_reg\n",
    "        self.normalise = normalise\n",
    "        self.normalise_by_trace = normalise_by_trace\n",
    "\n",
    "    def fit(self, X_train, y_train):\n",
    "        X_G, X_H = HALL_lib.split_by_label(X_train, y_train)\n",
    "\n",
    "        # store original embeddings\n",
    "        self.X_G = X_G\n",
    "        self.X_H = X_H\n",
    "\n",
    "        # Fisher direction\n",
    "        self.v = HALL_lib.fisher_direction(\n",
    "            X_G, X_H,\n",
    "            lambda_reg=self.lambda_reg,\n",
    "            normalise=self.normalise,\n",
    "            normalise_by_trace=self.normalise_by_trace\n",
    "        )\n",
    "\n",
    "        # Fisher projections\n",
    "        self.Z_G = (X_G @ self.v)[:, None]\n",
    "        self.Z_H = (X_H @ self.v)[:, None]\n",
    "\n",
    "        # reference intra-class distances\n",
    "        self.ref_G_fisher = pdist(self.Z_G)\n",
    "        self.ref_H_fisher = pdist(self.Z_H)\n",
    "\n",
    "        self.ref_G_embed = pdist(self.X_G)\n",
    "        self.ref_H_embed = pdist(self.X_H)\n",
    "\n",
    "    def extract_test_distances(self, X_test, y_test):\n",
    "        rows = []\n",
    "\n",
    "        for test_id, (x, y_true) in enumerate(zip(X_test, y_test)):\n",
    "            z = (x @ self.v).reshape(1, 1)\n",
    "\n",
    "            # Fisher distances\n",
    "            dG_f = cdist(z, self.Z_G).ravel()\n",
    "            dH_f = cdist(z, self.Z_H).ravel()\n",
    "\n",
    "            # Embedding distances\n",
    "            dG_e = cdist(x[None, :], self.X_G).ravel()\n",
    "            dH_e = cdist(x[None, :], self.X_H).ravel()\n",
    "\n",
    "            for j in range(len(dG_f)):\n",
    "                rows.append({\n",
    "                    \"test_id\": test_id,\n",
    "                    \"y_true\": int(y_true),\n",
    "                    \"target_class\": \"G\",\n",
    "                    \"train_id\": j,\n",
    "                    \"distance_fisher\": dG_f[j],\n",
    "                    \"distance_embed\": dG_e[j],\n",
    "                })\n",
    "\n",
    "            for j in range(len(dH_f)):\n",
    "                rows.append({\n",
    "                    \"test_id\": test_id,\n",
    "                    \"y_true\": int(y_true),\n",
    "                    \"target_class\": \"H\",\n",
    "                    \"train_id\": j,\n",
    "                    \"distance_fisher\": dH_f[j],\n",
    "                    \"distance_embed\": dH_e[j],\n",
    "                })\n",
    "\n",
    "        # ---- reference geometry (optional but useful) ----\n",
    "        for j, d in enumerate(self.ref_G_fisher):\n",
    "            rows.append({\n",
    "                \"test_id\": None,\n",
    "                \"y_true\": None,\n",
    "                \"target_class\": \"G_ref\",\n",
    "                \"train_id\": j,\n",
    "                \"distance_fisher\": d,\n",
    "                \"distance_embed\": self.ref_G_embed[j],\n",
    "            })\n",
    "\n",
    "        for j, d in enumerate(self.ref_H_fisher):\n",
    "            rows.append({\n",
    "                \"test_id\": None,\n",
    "                \"y_true\": None,\n",
    "                \"target_class\": \"H_ref\",\n",
    "                \"train_id\": j,\n",
    "                \"distance_fisher\": d,\n",
    "                \"distance_embed\": self.ref_H_embed[j],\n",
    "            })\n",
    "\n",
    "        return pd.DataFrame(rows)\n",
    "    \n",
    "    def extract_wasserstein_scores(self, X_test, y_test):\n",
    "        rows = []\n",
    "\n",
    "        for test_id, (x, y_true) in enumerate(zip(X_test, y_test)):\n",
    "            z = (x @ self.v).reshape(1, 1)\n",
    "\n",
    "            # Fisher distances\n",
    "            D_G = cdist(z, self.Z_G).ravel()\n",
    "            D_H = cdist(z, self.Z_H).ravel()\n",
    "\n",
    "            # Wasserstein distances\n",
    "            W_G = (\n",
    "                wasserstein_distance(D_G, self.ref_G_fisher)\n",
    "                if len(self.ref_G_fisher) > 0 else np.inf\n",
    "            )\n",
    "            W_H = (\n",
    "                wasserstein_distance(D_H, self.ref_H_fisher)\n",
    "                if len(self.ref_H_fisher) > 0 else np.inf\n",
    "            )\n",
    "\n",
    "            y_pred = 0 if W_G <= W_H else 1\n",
    "            correct = int(y_pred == y_true)\n",
    "\n",
    "            rows.append({\n",
    "                \"test_id\": test_id,\n",
    "                \"y_true\": int(y_true),\n",
    "                \"y_pred\": int(y_pred),\n",
    "                \"correct\": correct,\n",
    "                \"W_G\": W_G,\n",
    "                \"W_H\": W_H,\n",
    "                \"margin\": W_H - W_G,  # positive → G\n",
    "            })\n",
    "\n",
    "        return pd.DataFrame(rows)\n",
    "    \n",
    "def generate_fisher_distance_table(\n",
    "    df,\n",
    "    model_id,\n",
    "    prompt_id,\n",
    "    lambda_reg=1,\n",
    "    random_state=42\n",
    "):\n",
    "    X, y = HALL_lib.extract_prompt_data(df, model_id, prompt_id)\n",
    "\n",
    "    X_trn, X_tst, y_trn, y_tst, _, tstIdxs = train_test_split_80_20(\n",
    "        X, y, random_state=random_state\n",
    "    )\n",
    "\n",
    "    extractor = FisherDistanceExtractor(lambda_reg=lambda_reg)\n",
    "    extractor.fit(X_trn, y_trn)\n",
    "\n",
    "    dist_df = extractor.extract_test_distances(X_tst, y_tst)\n",
    "    wass_df = extractor.extract_wasserstein_scores(X_tst, y_tst)\n",
    "\n",
    "    for d in (dist_df, wass_df):\n",
    "        d[\"model_id\"] = model_id\n",
    "        d[\"prompt_id\"] = prompt_id\n",
    "\n",
    "    return dist_df, wass_df\n",
    "\n",
    "def get_test_distance_block(dist_df, test_id, space=\"fisher\"):\n",
    "    \"\"\"\n",
    "    space ∈ {\"fisher\", \"embedding\"}\n",
    "    \"\"\"\n",
    "    col = \"distance_fisher\" if space == \"fisher\" else \"distance_embed\"\n",
    "\n",
    "    dG = dist_df[\n",
    "        (dist_df[\"test_id\"] == test_id) &\n",
    "        (dist_df[\"target_class\"] == \"G\")\n",
    "    ][col].values\n",
    "\n",
    "    dH = dist_df[\n",
    "        (dist_df[\"test_id\"] == test_id) &\n",
    "        (dist_df[\"target_class\"] == \"H\")\n",
    "    ][col].values\n",
    "\n",
    "    return {\n",
    "        \"D_G\": dG,\n",
    "        \"D_H\": dH,\n",
    "    }\n",
    "\n",
    "def plot_test_distance_violin(\n",
    "    dist_df,\n",
    "    test_id,\n",
    "    space=\"fisher\",\n",
    "    colors=None,\n",
    "    figsize=(5, 4),\n",
    "    savepath=None,\n",
    "    title=\"\",\n",
    "    transparent=False,\n",
    "    rotate=False\n",
    "):\n",
    "    if colors is None:\n",
    "        colors = {\n",
    "            \"G\": \"#6E9B34\",\n",
    "            \"H\": \"#AA4D39\",\n",
    "        }\n",
    "\n",
    "    d = get_test_distance_block(dist_df, test_id, space=space)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=figsize)\n",
    "\n",
    "    # upper: G\n",
    "    vp = ax.violinplot(\n",
    "        d[\"D_G\"],\n",
    "        positions=[0],\n",
    "        widths=0.5,\n",
    "        showmeans=False,\n",
    "        showmedians=False,\n",
    "        showextrema=False,\n",
    "        side=\"high\",\n",
    "        orientation=\"horizontal\",\n",
    "    )\n",
    "    for pc in vp[\"bodies\"]:\n",
    "        pc.set_facecolor(colors[\"G\"])\n",
    "        pc.set_edgecolor(\"black\")\n",
    "        pc.set_linewidth(1.0)\n",
    "        pc.set_alpha(0.5)\n",
    "\n",
    "    # lower: H\n",
    "    vp = ax.violinplot(\n",
    "        d[\"D_H\"],\n",
    "        positions=[0],\n",
    "        widths=0.5,\n",
    "        showmeans=False,\n",
    "        showmedians=False,\n",
    "        showextrema=False,\n",
    "        side=\"low\",\n",
    "        orientation=\"horizontal\",\n",
    "    )\n",
    "    for pc in vp[\"bodies\"]:\n",
    "        pc.set_facecolor(colors[\"H\"])\n",
    "        pc.set_edgecolor(\"black\")\n",
    "        pc.set_linewidth(1.0)\n",
    "        pc.set_alpha(0.5)\n",
    "\n",
    "    # ---- axis arrow ----\n",
    "    xmin, xmax = ax.get_xlim()\n",
    "    ax.annotate(\n",
    "        r\"$L_2$\",\n",
    "        xy=(-0.001, 0),\n",
    "        xytext=(xmax * 1.1, 0),\n",
    "        arrowprops=dict(\n",
    "            arrowstyle=\"<-\",\n",
    "            color=\"0.6\",\n",
    "            linewidth=1.2,\n",
    "            shrinkA=0,\n",
    "            shrinkB=0,\n",
    "        ),\n",
    "        ha=\"left\",\n",
    "        va=\"center\",\n",
    "        color=\"0.4\",\n",
    "        rotation=-90 if rotate else 0,\n",
    "        fontsize=40,\n",
    "        zorder=0,\n",
    "    )\n",
    "\n",
    "    if title is not None:\n",
    "        ax.set_title(title)\n",
    "\n",
    "    # ---- zero reference ----\n",
    "    ax.axvline(0, 0.25, 0.75, color=\"grey\")\n",
    "\n",
    "    ax.set_axis_off()\n",
    "\n",
    "    if savepath is not None:\n",
    "        fig.savefig(savepath, bbox_inches=\"tight\", pad_inches=0, transparent=transparent)\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "dist_df, wass_df = generate_fisher_distance_table(\n",
    "    df,\n",
    "    mid,\n",
    "    pid,\n",
    "    lambda_reg=best_reg_lambda,\n",
    "    random_state=42\n",
    ")\n",
    "\n",
    "tid2H = {}\n",
    "for k, v in dist_df.groupby(['test_id', 'y_true']):\n",
    "    print(int(k[0]), k[1]==1)\n",
    "    tid2H[int(k[0])] = k[1]==1\n",
    "\n",
    "\n",
    "# Fisher space\n",
    "plot_test_distance_violin(\n",
    "    dist_df,\n",
    "    test_id=rid,\n",
    "    space=\"fisher\",\n",
    "    savepath=f\"img/P_violin_fisher_LP{rid}.pdf\",\n",
    "    transparent = True,\n",
    "    rotate=True\n",
    ")\n",
    "\n",
    "# Original embedding space\n",
    "plot_test_distance_violin(\n",
    "    dist_df,\n",
    "    test_id=rid,\n",
    "    space=\"embedding\",\n",
    "    savepath=f\"img/P_violin_embedding_LP{rid}.pdf\",\n",
    "    transparent = True,\n",
    "    rotate=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4cd7518",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17d0099d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# for rid in range(10, 20):\n",
    "#     # Fisher space\n",
    "#     plot_test_distance_violin(\n",
    "#         dist_df,\n",
    "#         test_id=rid,\n",
    "#         space=\"fisher\",\n",
    "#         savepath=None,\n",
    "#         title=str(rid)\n",
    "#     )\n",
    "\n",
    "#     # Original embedding space\n",
    "#     plot_test_distance_violin(\n",
    "#         dist_df,\n",
    "#         test_id=rid,\n",
    "#         space=\"embedding\",\n",
    "#         savepath=None,\n",
    "#         title=str(rid)\n",
    "#     )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d90c151",
   "metadata": {},
   "outputs": [],
   "source": [
    "# for tid in [x for x in range(20) if not tid2H[x]]:\n",
    "#     G_star = build_tsne_star_graph(\n",
    "#         Z_tr,\n",
    "#         y_tr,\n",
    "#         Z_te[tid],\n",
    "#     )\n",
    "\n",
    "#     fig, ax = plot_tsne_star_graph(G_star, title = str(tid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0919fd8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "wass_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27f6e0fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "testIdxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f295491b",
   "metadata": {},
   "outputs": [],
   "source": [
    "testIdxs[rid]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49139417",
   "metadata": {},
   "outputs": [],
   "source": [
    "df[(df['response_index'] == testIdxs[rid]) & (df['prompt_id'] == pid) & (df['model_id'] == mid)]['response'].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57a8b5fe",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "848b739b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "843507a7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92a0bc68",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c49b80b7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
