{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad59dd92",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import HALL_lib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bf87ecc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render all figure text with LaTeX; the preamble lists the extra packages to load.\n",
    "plt.rcParams.update({\n",
    "    'text.usetex': True,\n",
    "    'text.latex.preamble': r'''\n",
    "\\usepackage{mathtools}\n",
    "% more packages here\n",
    "''',\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b600274",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============== NEWDATASET ==============\n",
    "\n",
    "# cache_dir = \"cache-newdataset\"\n",
    "# maxResponsesPerPrompt = 500\n",
    "\n",
    "# model_names = {\n",
    "#     0 : \"Llama-8B\",\n",
    "#     1 : \"Phi-14B\",\n",
    "#     2 : \"Mistral-7B\",\n",
    "#     3 : \"Solar-11B\",\n",
    "#     4 : \"Gemma-9B\",\n",
    "#     5 : \"Qwen-15B\",\n",
    "#     6 : \"DeepSeek-7B\",\n",
    "#     7 : \"Yi-9B\"\n",
    "# }\n",
    "\n",
    "# ============== OLDDATASET ==============\n",
    "\n",
    "# cache_dir = \"cache-olddataset-20-22\"\n",
    "# maxResponsesPerPrompt=100\n",
    "\n",
    "# model_names = {\n",
    "#     0: 'Mistral-7B',\n",
    "#     1: 'Gemma-9B',\n",
    "#     2: 'Solar-11B',\n",
    "#     3: 'Phi-14B',\n",
    "#     4: 'Qwen-32B',\n",
    "#     5: 'Gemma-27B',\n",
    "#     6: 'Qwen-15B'\n",
    "# }\n",
    "\n",
    "# best_reg_lambda = 1.9\n",
    "\n",
    "# ============== ICML DATASET (active) ==============\n",
    "\n",
    "# embedding_label = \"stemmed_response_embeddings\"\n",
    "# cache_dir = \"cache-icml-stemmed\"\n",
    "embedding_label = \"response_embeddings\"\n",
    "cache_dir = \"cache-icml\"\n",
    "maxResponsesPerPrompt = 150\n",
    "\n",
    "# model_id -> display name used for tick labels and legends.\n",
    "model_names = {\n",
    "    0: \"Mistral-7B\",\n",
    "    1: \"Gemma-9B\",\n",
    "    2: \"Solar-11B\",\n",
    "    3: \"Phi-14B\",\n",
    "    4: \"Qwen-15B\",\n",
    "    5: \"Gemma-27B\",\n",
    "    6: \"Qwen-32B\",\n",
    "    7: \"DeepSeek-7B\",\n",
    "    8: \"Llama-8B\",\n",
    "    9: \"Yi-9B\",\n",
    "}\n",
    "\n",
    "# model_id -> LaTeX macro written into the exported tables.\n",
    "model_latextags = {\n",
    "    0: \"\\\\MISTRAL\",\n",
    "    7: \"\\\\DEEPSEEK\",\n",
    "    8: \"\\\\LLAMA\",\n",
    "    1: \"\\\\GEMMANINE\",\n",
    "    9: \"\\\\YI\",\n",
    "    2: \"\\\\SOLAR\",\n",
    "    3: \"\\\\PHI\",\n",
    "    4: \"\\\\QWENFOURTEEN\",\n",
    "    5: \"\\\\GEMMATWENTYSEVEN\",\n",
    "    6: \"\\\\QWENTHIRTYTWO\",\n",
    "}\n",
    "\n",
    "# Passed as lambda_reg to the HALL_lib analyses below.\n",
    "best_reg_lambda = 1.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a810ae8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative path: resolved against the notebook's working directory.\n",
    "dataset_fn = \"../llm-hallucinations/dataset_icml.parquet\"\n",
    "df = HALL_lib.loadParquet(dataset_fn, unifyYears=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7f99912",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_order, model_rank = HALL_lib.build_model_size_order(model_names)\n",
    "\n",
    "# prompt_id -> false_premise value, read from the distinct (prompt_id, false_premise)\n",
    "# pairs of the dataset; None when the dataset has no such column.\n",
    "if \"false_premise\" in df.columns:\n",
    "    false_premise = dict(\n",
    "        (pid, fp)\n",
    "        for pid, fp in df[[\"prompt_id\", \"false_premise\"]].value_counts().keys()\n",
    "    )\n",
    "else:\n",
    "    false_premise = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89869d9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Structural analysis of the embeddings; its cache directory is <cache_dir>/S-data.\n",
    "results_df, geometry_store, null_store = HALL_lib.run_structural_analysis(\n",
    "    df,\n",
    "    lambda_reg=best_reg_lambda,\n",
    "    n_permutations=100,\n",
    "    random_state=42,\n",
    "    min_per_class_plot=5,\n",
    "    embedding_label=embedding_label,\n",
    "    use_cache=True,\n",
    "    cache_dir=f\"{cache_dir}/S-data\",\n",
    "    overwrite_cache=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dd570d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_ids = list(model_names)\n",
    "\n",
    "# For every model, the sorted distinct prompt ids that appear in the dataset.\n",
    "prompt_ids_by_model = {\n",
    "    mid: sorted(df.loc[df[\"model_id\"] == mid, \"prompt_id\"].unique())\n",
    "    for mid in model_ids\n",
    "}\n",
    "\n",
    "# Label-propagation study with the Fisher projector; its cache directory is <cache_dir>/LP-fisher.\n",
    "results_lp_max = HALL_lib.run_full_label_propagation_study(\n",
    "    df=df,\n",
    "    model_ids=model_ids,\n",
    "    prompt_ids_by_model=prompt_ids_by_model,\n",
    "    projector_class=HALL_lib.FisherProjection,\n",
    "    projector_kwargs=dict(\n",
    "        lambda_reg=best_reg_lambda,\n",
    "        normalise=True,\n",
    "        normalise_by_trace=True,\n",
    "    ),\n",
    "    train_fractions=None,\n",
    "    n_iter=20,\n",
    "    test_fraction=1/3,\n",
    "    n_splits=20,\n",
    "    ref_lambda_reg=None,\n",
    "    embedding_label=embedding_label,\n",
    "    use_cache=True,\n",
    "    cache_dir=f\"{cache_dir}/LP-fisher\",\n",
    "    overwrite_cache=False,\n",
    "    logskip=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54e18992",
   "metadata": {},
   "outputs": [],
   "source": [
    "# F1 aggregated over train fractions only; the plots below read the\n",
    "# model_id, prompt_id and metric_mean columns of this table.\n",
    "agg_per_prompt = HALL_lib.aggregate_metric_over_prompts(\n",
    "    results_lp_max,\n",
    "    metric=\"f1\",\n",
    "    agg_prompts=False,\n",
    "    agg_models=False,\n",
    "    agg_train_frac=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "484fb394",
   "metadata": {},
   "source": [
    "## How prompts behave within a model in terms of accuracy, ordered by F1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2f99ddf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_prompt_heatmap_per_model(\n",
    "    df_prompt,\n",
    "    model_id,\n",
    "    model_name,\n",
    "    sort_by=\"metric_mean\",\n",
    "    value_col=\"score_mean\",\n",
    "    metric=\"f1\",\n",
    "    score='accuracy',\n",
    "    ratio=(3, 5),\n",
    "    scale=2\n",
    "):\n",
    "    \"\"\"\n",
    "    Single-column heatmap of one model's prompts.\n",
    "\n",
    "    The rows of ``df_prompt`` whose ``model_id`` matches are sorted in\n",
    "    ascending order of ``sort_by``; their ``value_col`` entries are\n",
    "    colour-coded, with the first sorted row drawn at the bottom. Tick labels\n",
    "    on the y axis are the ``prompt_id`` values.\n",
    "\n",
    "    ``metric`` only appears in the title text and ``score`` is accepted but\n",
    "    not used. The figure size is ``scale * ratio``.\n",
    "\n",
    "    Returns the ``(fig, ax)`` pair.\n",
    "    \"\"\"\n",
    "    model_rows = df_prompt.loc[df_prompt[\"model_id\"] == model_id]\n",
    "    ordered = model_rows.sort_values(sort_by)\n",
    "\n",
    "    # imshow needs a 2-D array: one row per prompt, a single column.\n",
    "    column = ordered[value_col].values[:, np.newaxis]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=[scale * side for side in ratio])\n",
    "    image = ax.imshow(column, aspect=\"auto\", origin=\"lower\")\n",
    "\n",
    "    ax.set_yticks(np.arange(len(ordered)))\n",
    "    ax.set_yticklabels(ordered[\"prompt_id\"])\n",
    "    ax.set_xticks([0])\n",
    "    ax.set_xticklabels([value_col])\n",
    "    ax.set_title(f\"{model_name}\\nPrompts ordered by {metric}\")\n",
    "\n",
    "    fig.colorbar(image, ax=ax, label=value_col)\n",
    "    return fig, ax\n",
    "\n",
    "# ===== ===== ===== ===== ===== ===== ===== =====\n",
    "\n",
    "plot_prompt_heatmap_per_model(agg_per_prompt, 0, model_names[0], metric=\"f1\", score='accuracy')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14a7a6f3",
   "metadata": {},
   "source": [
    "## Which prompts are intrinsically hard, regardless of model?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d1ab800",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_prompt_difficulty(results_lp_max, ratio=(5, 3), scale=2):\n",
    "    \"\"\"\n",
    "    Horizontal bar chart of the mean F1 of every prompt, with the standard\n",
    "    deviation of F1 as error bar.\n",
    "\n",
    "    All rows of ``results_lp_max`` that share a ``prompt_id`` are pooled; the\n",
    "    bars are handed to ``barh`` in ascending order of mean F1. The figure size\n",
    "    is ``scale * ratio``.\n",
    "\n",
    "    Returns the ``(fig, ax)`` pair.\n",
    "    \"\"\"\n",
    "    difficulty = (\n",
    "        results_lp_max\n",
    "        .groupby(\"prompt_id\")\n",
    "        .agg(mean_f1=(\"f1\", \"mean\"), std_f1=(\"f1\", \"std\"))\n",
    "        .sort_values(\"mean_f1\")\n",
    "        .reset_index()\n",
    "    )\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=[scale * side for side in ratio])\n",
    "    ax.barh(difficulty[\"prompt_id\"], difficulty[\"mean_f1\"], xerr=difficulty[\"std_f1\"])\n",
    "\n",
    "    ax.set_xlabel(\"Mean F1 across models\")\n",
    "    ax.set_title(\"Prompt difficulty ranking\")\n",
    "    ax.grid(True, axis=\"x\", alpha=0.4)\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# ===== ===== ===== ===== ===== ===== ===== =====\n",
    "\n",
    "plot_prompt_difficulty(results_lp_max)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5169ede0",
   "metadata": {},
   "source": [
    "## Are models failing on the same prompts?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80c5deeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_model_agreement_heatmap(agg_per_prompt, model_names, ratio=(4, 3), scale=2):\n",
    "    \"\"\"\n",
    "    Heatmap of the pairwise correlation between the models' prompt-level scores.\n",
    "\n",
    "    ``agg_per_prompt`` is pivoted into a table with one row per ``prompt_id``\n",
    "    and one column per ``model_id`` holding ``metric_mean``; the columns are\n",
    "    then correlated with ``DataFrame.corr()``. Tick labels are looked up in\n",
    "    ``model_names`` and the colour scale is fixed to [-1, 1]. The figure size\n",
    "    is ``scale * ratio``.\n",
    "\n",
    "    Returns the ``(fig, ax)`` pair.\n",
    "    \"\"\"\n",
    "    per_model = agg_per_prompt.pivot(\n",
    "        index=\"prompt_id\",\n",
    "        columns=\"model_id\",\n",
    "        values=\"metric_mean\",\n",
    "    )\n",
    "    corr_df = per_model.corr()\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=[scale * side for side in ratio])\n",
    "    image = ax.imshow(corr_df.values, vmin=-1, vmax=1)\n",
    "\n",
    "    ticks = np.arange(len(corr_df))\n",
    "    ax.set_xticks(ticks)\n",
    "    ax.set_yticks(ticks)\n",
    "    ax.set_xticklabels([model_names[m] for m in corr_df.columns], rotation=30)\n",
    "    ax.set_yticklabels([model_names[m] for m in corr_df.index])\n",
    "\n",
    "    fig.colorbar(image, ax=ax, label=\"Prompt-level F1 correlation\")\n",
    "    # ax.set_title(\"Model agreement on prompt difficulty\")\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# ===== ===== ===== ===== ===== ===== ===== =====\n",
    "\n",
    "plot_model_agreement_heatmap(agg_per_prompt, model_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e987e697",
   "metadata": {},
   "source": [
    "## Compare mean accuracy and std (only 100%)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54056315",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_mean_vs_variability(df_prompt, model_names, ratio=(4, 3), scale=2):\n",
    "    \"\"\"\n",
    "    Scatter one point per model: mean of ``metric_mean`` over its rows in\n",
    "    ``df_prompt`` (x) against the standard deviation of the same values (y).\n",
    "\n",
    "    ``model_names`` maps model_id to the legend label and also fixes which\n",
    "    models are drawn. The figure size is ``scale * ratio``.\n",
    "\n",
    "    Returns the ``(fig, ax)`` pair.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=[scale * side for side in ratio])\n",
    "\n",
    "    for mid, name in model_names.items():\n",
    "        scores = df_prompt.loc[df_prompt[\"model_id\"] == mid, \"metric_mean\"]\n",
    "        ax.scatter(scores.mean(), scores.std(), label=name)\n",
    "\n",
    "    ax.set_xlabel(\"Mean F1 across prompts\")\n",
    "    ax.set_ylabel(\"Std F1 across prompts\")\n",
    "    ax.set_title(\"Performance vs stability\")\n",
    "    ax.legend()\n",
    "    ax.grid(True, alpha=0.4)\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# ===== ===== ===== ===== ===== ===== ===== =====\n",
    "\n",
    "agg_per_prompt = HALL_lib.aggregate_metric_over_prompts(\n",
    "    results_lp_max,\n",
    "    metric=\"f1\",\n",
    "    agg_prompts=False,\n",
    "    agg_models=False,\n",
    "    agg_train_frac=True,\n",
    ")\n",
    "plot_mean_vs_variability(agg_per_prompt, model_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02ee69e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_metric_boxplots_two_panels(\n",
    "    results_lp,\n",
    "    model_names,\n",
    "    model_order,\n",
    "    train_fraction,\n",
    "    metrics=(\"accuracy\", \"f1\"),\n",
    "    ratio=(6, 3),\n",
    "    scale=2,\n",
    "    width_ratios=[1, 1],\n",
    "    xlims=None,\n",
    "    cmap_name=\"tab10\"\n",
    "):\n",
    "    \"\"\"\n",
    "    Per-model distributions of prompt-level scores, drawn as horizontal\n",
    "    violins with an overlaid box plot, one panel per metric.\n",
    "\n",
    "    Every metric is first averaged within each (model_id, prompt_id,\n",
    "    train_fraction) group of ``results_lp``; only the groups whose\n",
    "    train_fraction equals ``train_fraction`` are plotted.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    results_lp : pandas.DataFrame\n",
    "        Needs the columns model_id, prompt_id, train_fraction and one column\n",
    "        per entry of ``metrics``.\n",
    "    model_names : dict\n",
    "        model_id -> tick label.\n",
    "    model_order : sequence\n",
    "        Model ids; they are traversed in reverse, and models without any\n",
    "        value for ``train_fraction`` are left out.\n",
    "    train_fraction\n",
    "        Value compared with ``==`` against the train_fraction column.\n",
    "    metrics : sequence of str\n",
    "        Columns to plot. The figure has two panels, so at most two entries.\n",
    "    ratio, scale\n",
    "        The figure size is ``scale * ratio``.\n",
    "    width_ratios : sequence of two numbers\n",
    "        Relative panel widths, forwarded to ``plt.subplots``.\n",
    "    xlims : sequence or None\n",
    "        x limits per panel, indexed like ``metrics``. With None the limits\n",
    "        chosen by matplotlib are kept.\n",
    "    cmap_name : str\n",
    "        Name of the colormap from which the model colours are taken.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (fig, axes)\n",
    "    \"\"\"\n",
    "    fig, axes = plt.subplots(\n",
    "        1, 2,\n",
    "        figsize=[scale * x for x in ratio],\n",
    "        sharey=True,\n",
    "        width_ratios=width_ratios\n",
    "    )\n",
    "\n",
    "    cmap = plt.get_cmap(cmap_name)\n",
    "\n",
    "    # Average every requested metric (not only accuracy and f1), so that the\n",
    "    # \"<metric>_mean\" column read below exists for any entry of `metrics`.\n",
    "    df_prompt = (\n",
    "        results_lp\n",
    "        .groupby([\"model_id\", \"prompt_id\", \"train_fraction\"])\n",
    "        .agg(**{f\"{metric}_mean\": (metric, \"mean\") for metric in metrics})\n",
    "        .reset_index()\n",
    "    )\n",
    "\n",
    "    for it, metric in enumerate(metrics):\n",
    "        ax = axes[it]\n",
    "        data = []\n",
    "        labels = []\n",
    "        colors = []\n",
    "\n",
    "        for i, mid in enumerate(reversed(model_order)):\n",
    "            name = model_names[mid]\n",
    "\n",
    "            vals = df_prompt[\n",
    "                (df_prompt[\"model_id\"] == mid) &\n",
    "                (df_prompt[\"train_fraction\"] == train_fraction)\n",
    "            ][f\"{metric}_mean\"].values\n",
    "\n",
    "            # Models with no value for this train_fraction get no row.\n",
    "            if len(vals) > 0:\n",
    "                data.append(vals)\n",
    "                labels.append(name)\n",
    "                colors.append(cmap(i % cmap.N))\n",
    "\n",
    "        positions = np.arange(1, len(data) + 1)\n",
    "\n",
    "        # ---- violins (horizontal)\n",
    "        vp = ax.violinplot(\n",
    "            data,\n",
    "            widths=.75,\n",
    "            positions=positions,\n",
    "            vert=False,\n",
    "            showmeans=False,\n",
    "            showextrema=False,\n",
    "        )\n",
    "\n",
    "        for body, c in zip(vp[\"bodies\"], colors):\n",
    "            body.set_facecolor(c)\n",
    "            body.set_alpha(0.35)\n",
    "            body.set_edgecolor(\"none\")\n",
    "\n",
    "        # ---- boxes (horizontal)\n",
    "        bp = ax.boxplot(\n",
    "            data,\n",
    "            positions=positions,\n",
    "            vert=False,\n",
    "            widths=0.25,\n",
    "            showfliers=False,\n",
    "            patch_artist=True,\n",
    "            medianprops=dict(color=\"black\", linewidth=1.5),\n",
    "        )\n",
    "\n",
    "        for box, c in zip(bp[\"boxes\"], colors):\n",
    "            box.set_facecolor(c)\n",
    "            box.set_alpha(0.6)\n",
    "\n",
    "        ax.set_yticks(positions)\n",
    "        ax.set_yticklabels(labels)\n",
    "        ax.set_title(metric.capitalize())\n",
    "        ax.grid(True, axis=\"x\", alpha=0.4)\n",
    "\n",
    "        # xlims defaults to None; indexing it unconditionally raised a TypeError.\n",
    "        if xlims is not None:\n",
    "            ax.set_xlim(xlims[it])\n",
    "\n",
    "    fig.tight_layout()\n",
    "    return fig, axes\n",
    "\n",
    "# ===== ===== ===== ===== ===== ===== ===== =====\n",
    "\n",
    "fig, _ = plot_metric_boxplots_two_panels(\n",
    "    results_lp_max,\n",
    "    model_names=model_names,\n",
    "    model_order=model_order,\n",
    "    train_fraction=1,\n",
    "    width_ratios=[5, 9],\n",
    "    xlims=[(0.5, 1.0), (.1, 1.0)],\n",
    "    ratio=(3,1),\n",
    "    scale=3\n",
    ")\n",
    "\n",
    "# fig.savefig(\"img/LP_statistics.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3df478f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Correlation between hallucination rate and label-propagation performance?\n",
    "\n",
    "# Columns taken from the label-propagation results ...\n",
    "lp_cols = [\n",
    "    \"model_id\", \"prompt_id\",\n",
    "    \"f1\",\n",
    "    \"accuracy\",\n",
    "    \"precision\",\n",
    "    \"recall\",\n",
    "]\n",
    "\n",
    "# ... and from the structural analysis.\n",
    "struct_cols = [\n",
    "    \"model_id\", \"prompt_id\",\n",
    "    \"n_G\", \"n_H\", \"frac_H\",\n",
    "    \"valid_geom\",\n",
    "]\n",
    "\n",
    "# Inner join: only (model, prompt) pairs present in both tables survive.\n",
    "df_plot = (\n",
    "    results_lp_max[lp_cols]\n",
    "    .merge(\n",
    "        results_df[struct_cols],\n",
    "        on=[\"model_id\", \"prompt_id\"],\n",
    "        how=\"inner\"\n",
    "    )\n",
    ")\n",
    "\n",
    "# Keep rows with n_G >= 5, n_H >= 5 and a valid geometry.\n",
    "df_plot = df_plot[\n",
    "    (df_plot[\"n_G\"] >= 5) &\n",
    "    (df_plot[\"n_H\"] >= 5) &\n",
    "    (df_plot[\"valid_geom\"])\n",
    "].copy()\n",
    "\n",
    "# One colour per model that is left after filtering.\n",
    "models = sorted(df_plot[\"model_id\"].unique())\n",
    "cmap = plt.get_cmap(\"tab10\", len(models))\n",
    "color_map = {m: cmap(i) for i, m in enumerate(models)}\n",
    "\n",
    "colors = df_plot[\"model_id\"].map(color_map)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7.5, 5.5))\n",
    "\n",
    "# One point per row of df_plot.\n",
    "ax.scatter(\n",
    "    df_plot[\"frac_H\"],\n",
    "    df_plot[\"accuracy\"],\n",
    "    c=colors,\n",
    "    alpha=0.75,\n",
    "    edgecolor=\"black\",\n",
    "    linewidth=0.4,\n",
    "    s=45,\n",
    ")\n",
    "\n",
    "ax.set_xlabel(\"Hallucination rate\")\n",
    "ax.set_ylabel(\"Label Propagator Accuracy score\")\n",
    "\n",
    "ax.set_xlim(0, 1)\n",
    "ax.set_ylim(0, 1)\n",
    "ax.grid(True, alpha=0.3)\n",
    "\n",
    "# Per-model centroids; f1_mean is taken from the f1 column.\n",
    "centroids = (\n",
    "    df_plot\n",
    "    .groupby(\"model_id\")\n",
    "    .agg(\n",
    "        frac_H_mean=(\"frac_H\", \"mean\"),\n",
    "        accuracy_mean=(\"accuracy\", \"mean\"),\n",
    "        f1_mean=(\"f1\", \"mean\"),\n",
    "    )\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "ax.scatter(\n",
    "    centroids[\"frac_H_mean\"],\n",
    "    centroids[\"accuracy_mean\"],\n",
    "    c=[color_map[m] for m in centroids[\"model_id\"]],\n",
    "    s=180,\n",
    "    marker=\"X\",\n",
    "    edgecolor=\"black\",\n",
    "    linewidth=1.2,\n",
    "    zorder=5,\n",
    "    label=\"Model average\",\n",
    ")\n",
    "\n",
    "# Legend: the handles are given explicitly, so artist labels are not picked up\n",
    "# automatically; the centroid marker therefore gets its own proxy handle.\n",
    "handles = [\n",
    "    plt.Line2D([0], [0], marker='o', linestyle='',\n",
    "               color=color_map[m], label=model_names[m])\n",
    "    for m in models\n",
    "]\n",
    "handles.append(\n",
    "    plt.Line2D([0], [0], marker='X', linestyle='', markersize=11,\n",
    "               markerfacecolor='white', markeredgecolor='black',\n",
    "               label=\"Model average\")\n",
    ")\n",
    "ax.legend(handles=handles, title=\"Model\", fontsize=9, loc=\"lower left\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37abaadf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-model summary (accuracy, F1, margins) exported as a LaTeX table.\n",
    "# Accuracy/F1 spreads are the std over the result rows of a model; the margin\n",
    "# \"std\" entries are the mean of the per-row std columns.\n",
    "agg = (\n",
    "    results_lp_max\n",
    "    .groupby(\"model_id\")\n",
    "    .agg(\n",
    "        acc_mean=(\"accuracy\", \"mean\"),\n",
    "        acc_std=(\"accuracy\", \"std\"),\n",
    "        f1_mean=(\"f1\", \"mean\"),\n",
    "        f1_std=(\"f1\", \"std\"),\n",
    "\n",
    "        margin_mean=(\"mean_margin\", \"mean\"),\n",
    "        margin_std=(\"std_margin\", \"mean\"),\n",
    "\n",
    "        margin_abs_mean=(\"mean_abs_margin\", \"mean\"),\n",
    "        margin_abs_std=(\"std_abs_margin\", \"mean\"),\n",
    "    )\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "# Replace the numeric id by the LaTeX macro of the model.\n",
    "agg[\"model\"] = agg[\"model_id\"].map(model_latextags)\n",
    "\n",
    "agg = agg[[\n",
    "    \"model\",\n",
    "    \"acc_mean\", \"acc_std\",\n",
    "    \"f1_mean\", \"f1_std\",\n",
    "    \"margin_mean\", \"margin_std\",\n",
    "    \"margin_abs_mean\", \"margin_abs_std\",\n",
    "]]\n",
    "\n",
    "def fmt_pct(mean, std):\n",
    "    \"\"\"Format two fractions as percentages, 'mean (std)', with one decimal.\"\"\"\n",
    "    return f\"{100*mean:.1f} ({100*std:.1f})\"\n",
    "\n",
    "def fmt_margin(mean, std):\n",
    "    \"\"\"Format 'mean (std)' with one decimal, without rescaling.\"\"\"\n",
    "    return f\"{mean:.1f} ({std:.1f})\"\n",
    "\n",
    "table_df = pd.DataFrame({\n",
    "    \"Model\": agg[\"model\"],\n",
    "\n",
    "    \"Accuracy\": [\n",
    "        fmt_pct(m, s)\n",
    "        for m, s in zip(agg[\"acc_mean\"], agg[\"acc_std\"])\n",
    "    ],\n",
    "\n",
    "    \"F1\": [\n",
    "        fmt_pct(m, s)\n",
    "        for m, s in zip(agg[\"f1_mean\"], agg[\"f1_std\"])\n",
    "    ],\n",
    "\n",
    "    \"Margin\": [\n",
    "        fmt_margin(m, s)\n",
    "        for m, s in zip(agg[\"margin_mean\"], agg[\"margin_std\"])\n",
    "    ],\n",
    "\n",
    "    r\"$|\\mathrm{Margin}|$\": [\n",
    "        fmt_margin(m, s)\n",
    "        for m, s in zip(\n",
    "            agg[\"margin_abs_mean\"],\n",
    "            agg[\"margin_abs_std\"],\n",
    "        )\n",
    "    ],\n",
    "})\n",
    "\n",
    "# \"l\" for the model name plus one \"c\" per remaining column, so the column\n",
    "# spec always has exactly as many entries as the table has columns.\n",
    "latex_table = table_df.to_latex(\n",
    "    index=False,\n",
    "    escape=False,\n",
    "    column_format=\"l\" + \"c\" * (table_df.shape[1] - 1),\n",
    ")\n",
    "\n",
    "print(latex_table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce7922e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-model summary with class-conditional margins, exported as a LaTeX table.\n",
    "# Accuracy/F1 spreads are the std over the result rows of a model; the margin\n",
    "# \"std\" entries are the mean of the per-row std columns.\n",
    "agg = (\n",
    "    results_lp_max\n",
    "    .groupby(\"model_id\")\n",
    "    .agg(\n",
    "        acc_mean=(\"accuracy\", \"mean\"),\n",
    "        acc_std=(\"accuracy\", \"std\"),\n",
    "\n",
    "        f1_mean=(\"f1\", \"mean\"),\n",
    "        f1_std=(\"f1\", \"std\"),\n",
    "\n",
    "        # signed margins by class\n",
    "        m0_mean=(\"mean_margin_class_False\", \"mean\"),\n",
    "        m0_std=(\"std_margin_class_False\", \"mean\"),\n",
    "\n",
    "        m1_mean=(\"mean_margin_class_True\", \"mean\"),\n",
    "        m1_std=(\"std_margin_class_True\", \"mean\"),\n",
    "\n",
    "        # absolute margins by class (optional but useful)\n",
    "        am0_mean=(\"mean_abs_margin_class_False\", \"mean\"),\n",
    "        am0_std=(\"std_abs_margin_class_False\", \"mean\"),\n",
    "\n",
    "        am1_mean=(\"mean_abs_margin_class_True\", \"mean\"),\n",
    "        am1_std=(\"std_abs_margin_class_True\", \"mean\"),\n",
    "    )\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "# Replace the numeric id by the LaTeX macro of the model.\n",
    "agg[\"model\"] = agg[\"model_id\"].map(model_latextags)\n",
    "\n",
    "def fmt_pct(mean, std):\n",
    "    \"\"\"Format two fractions as percentages, 'mean (std)', with one decimal.\"\"\"\n",
    "    return f\"{100*mean:.1f} ({100*std:.1f})\"\n",
    "\n",
    "def fmt_margin(mean, std):\n",
    "    \"\"\"Format 'mean (std)' with one decimal, without rescaling.\"\"\"\n",
    "    return f\"{mean:.1f} ({std:.1f})\"\n",
    "\n",
    "table_df = pd.DataFrame({\n",
    "    \"Model\": agg[\"model\"],\n",
    "\n",
    "    \"Accuracy\": [\n",
    "        fmt_pct(m, s)\n",
    "        for m, s in zip(agg[\"acc_mean\"], agg[\"acc_std\"])\n",
    "    ],\n",
    "\n",
    "    \"F1\": [\n",
    "        fmt_pct(m, s)\n",
    "        for m, s in zip(agg[\"f1_mean\"], agg[\"f1_std\"])\n",
    "    ],\n",
    "\n",
    "    r\"$M\\,|\\,y=0$\": [\n",
    "        fmt_margin(m, s)\n",
    "        for m, s in zip(agg[\"m0_mean\"], agg[\"m0_std\"])\n",
    "    ],\n",
    "\n",
    "    r\"$M\\,|\\,y=1$\": [\n",
    "        fmt_margin(m, s)\n",
    "        for m, s in zip(agg[\"m1_mean\"], agg[\"m1_std\"])\n",
    "    ],\n",
    "\n",
    "    r\"$|M|\\,|\\,y=0$\": [\n",
    "        fmt_margin(m, s)\n",
    "        for m, s in zip(agg[\"am0_mean\"], agg[\"am0_std\"])\n",
    "    ],\n",
    "\n",
    "    r\"$|M|\\,|\\,y=1$\": [\n",
    "        fmt_margin(m, s)\n",
    "        for m, s in zip(agg[\"am1_mean\"], agg[\"am1_std\"])\n",
    "    ],\n",
    "})\n",
    "\n",
    "# \"l\" for the model name plus one \"c\" per remaining column, so the column\n",
    "# spec always has exactly as many entries as the table has columns.\n",
    "latex_table = table_df.to_latex(\n",
    "    index=False,\n",
    "    escape=False,\n",
    "    column_format=\"l\" + \"c\" * (table_df.shape[1] - 1),\n",
    ")\n",
    "\n",
    "print(latex_table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "011ac850",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final per-model LaTeX table: rows in model_order, an average row at the\n",
    "# bottom, best and second-best values highlighted.\n",
    "\n",
    "# Output column -> (source column, reducer), in output order.\n",
    "agg_spec = {\n",
    "    \"acc_mean\": (\"accuracy\", \"mean\"),\n",
    "    \"acc_std\": (\"accuracy\", \"std\"),\n",
    "    \"f1_mean\": (\"f1\", \"mean\"),\n",
    "    \"f1_std\": (\"f1\", \"std\"),\n",
    "}\n",
    "\n",
    "# Signed (m*) and absolute (am*) margins for class False (0) and True (1);\n",
    "# their \"std\" entry is the mean of the per-row std column.\n",
    "for short, source in [\n",
    "    (\"m0\", \"margin_class_False\"),\n",
    "    (\"m1\", \"margin_class_True\"),\n",
    "    (\"am0\", \"abs_margin_class_False\"),\n",
    "    (\"am1\", \"abs_margin_class_True\"),\n",
    "]:\n",
    "    agg_spec[f\"{short}_mean\"] = (f\"mean_{source}\", \"mean\")\n",
    "    agg_spec[f\"{short}_std\"] = (f\"std_{source}\", \"mean\")\n",
    "\n",
    "agg = results_lp_max.groupby(\"model_id\").agg(**agg_spec).reset_index()\n",
    "agg[\"model\"] = agg[\"model_id\"].map(model_latextags)\n",
    "\n",
    "# Ordering\n",
    "agg = agg.set_index(\"model_id\").loc[model_order].reset_index()\n",
    "\n",
    "# Average row: column-wise mean of the per-model statistics.\n",
    "avg_row = {\"model_id\": \"avg\", \"model\": r\"\\textbf{Average}\"}\n",
    "avg_row.update({col: agg[col].mean() for col in agg_spec})\n",
    "\n",
    "agg = pd.concat([agg, pd.DataFrame([avg_row])], ignore_index=True)\n",
    "\n",
    "# Decoration helpers\n",
    "\n",
    "def rank_decor(values):\n",
    "    \"\"\"Return one tag per value: 'bold' for the largest, 'underline' for the second largest, else None.\"\"\"\n",
    "    descending = np.argsort(values)[::-1]\n",
    "    deco = [None] * len(values)\n",
    "    for position, tag in zip(descending[:2], (\"bold\", \"underline\")):\n",
    "        deco[position] = tag\n",
    "    return deco\n",
    "\n",
    "def apply_deco(s, d):\n",
    "    \"\"\"Wrap the text s in the LaTeX command selected by tag d; other tags leave s untouched.\"\"\"\n",
    "    openers = {\"bold\": r\"\\textbf{\", \"underline\": r\"\\underline{\"}\n",
    "    if d in openers:\n",
    "        return openers[d] + s + \"}\"\n",
    "    return s\n",
    "\n",
    "# Only the real models are ranked; the trailing average row is padded with\n",
    "# None so that the tag lists line up with the rows of agg.\n",
    "mask = agg[\"model_id\"] != \"avg\"\n",
    "padding = [None] * int((~mask).sum())\n",
    "highlight = {\n",
    "    col: rank_decor(agg.loc[mask, col].values) + padding\n",
    "    for col in (\"acc_mean\", \"f1_mean\", \"am0_mean\", \"am1_mean\")\n",
    "}\n",
    "\n",
    "# Formatters\n",
    "def fmt_pct(mean, std):\n",
    "    \"\"\"Format two fractions as percentages, 'mean (std)', with one decimal.\"\"\"\n",
    "    return f\"{100*mean:.1f} ({100*std:.1f})\"\n",
    "\n",
    "def fmt_margin(mean, std):\n",
    "    \"\"\"Format 'mean (std)' with one decimal, without rescaling.\"\"\"\n",
    "    return f\"{mean:.1f} ({std:.1f})\"\n",
    "\n",
    "# Build table\n",
    "\n",
    "rows = []\n",
    "\n",
    "for pos, (_, r) in enumerate(agg.iterrows()):\n",
    "    rows.append({\n",
    "        \"Model\": r[\"model\"],\n",
    "\n",
    "        \"Accuracy\": apply_deco(\n",
    "            fmt_pct(r[\"acc_mean\"], r[\"acc_std\"]),\n",
    "            highlight[\"acc_mean\"][pos]\n",
    "        ),\n",
    "\n",
    "        \"F1\": apply_deco(\n",
    "            fmt_pct(r[\"f1_mean\"], r[\"f1_std\"]),\n",
    "            highlight[\"f1_mean\"][pos]\n",
    "        ),\n",
    "\n",
    "        # signed margins: no highlighting\n",
    "        r\"$M\\,|\\,y=0$\": fmt_margin(r[\"m0_mean\"], r[\"m0_std\"]),\n",
    "        r\"$M\\,|\\,y=1$\": fmt_margin(r[\"m1_mean\"], r[\"m1_std\"]),\n",
    "\n",
    "        # absolute margins: highlighted\n",
    "        r\"$|M|\\,|\\,y=0$\": apply_deco(\n",
    "            fmt_margin(r[\"am0_mean\"], r[\"am0_std\"]),\n",
    "            highlight[\"am0_mean\"][pos]\n",
    "        ),\n",
    "\n",
    "        r\"$|M|\\,|\\,y=1$\": apply_deco(\n",
    "            fmt_margin(r[\"am1_mean\"], r[\"am1_std\"]),\n",
    "            highlight[\"am1_mean\"][pos]\n",
    "        ),\n",
    "    })\n",
    "\n",
    "# ... and export\n",
    "\n",
    "table_df = pd.DataFrame(rows)\n",
    "\n",
    "latex_table = table_df.to_latex(\n",
    "    index=False,\n",
    "    escape=False,\n",
    "    column_format=\"lcccccc\",\n",
    ")\n",
    "\n",
    "print(latex_table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b760b27",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc630bb2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.14.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
