{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34dd1199",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "import HALL_lib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b92d01f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['text.usetex'] = True\n",
    "plt.rcParams['text.latex.preamble'] = r'''\n",
    "\\usepackage{mathtools}\n",
    "% more packages here\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f762d3bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============== NEWDATASET ==============\n",
    "\n",
    "# cache_dir = \"cache-newdataset\"\n",
    "# maxResponsesPerPrompt = 500\n",
    "\n",
    "# model_names = {\n",
    "#     0 : \"Llama-8B\",\n",
    "#     1 : \"Phi-14B\",\n",
    "#     2 : \"Mistral-7B\",\n",
    "#     3 : \"Solar-11B\",\n",
    "#     4 : \"Gemma-9B\",\n",
    "#     5 : \"Qwen-15B\",\n",
    "#     6 : \"DeepSeek-7B\",\n",
    "#     7 : \"Yi-9B\"\n",
    "# }\n",
    "\n",
    "# ============== OLDDATASET ==============\n",
    "\n",
    "# cache_dir = \"cache-olddataset-20-22\"\n",
    "# maxResponsesPerPrompt=100\n",
    "\n",
    "# model_names = {\n",
    "#     0: 'Mistral-7B',\n",
    "#     1: 'Gemma-9B',\n",
    "#     2: 'Solar-11B',\n",
    "#     3: 'Phi-14B',\n",
    "#     4: 'Qwen-32B',\n",
    "#     5: 'Gemma-27B',\n",
    "#     6: 'Qwen-15B'\n",
    "# }\n",
    "\n",
    "# best_reg_lambda = 1.9\n",
    "\n",
    "# ============== OLDDATASET ==============\n",
    "\n",
    "embedding_label=\"response_embeddings\"\n",
    "cache_dir = \"cache-icml\"\n",
    "maxResponsesPerPrompt=150\n",
    "\n",
    "model_names = {\n",
    "    0: 'Mistral-7B',\n",
    "    1: 'Gemma-9B',\n",
    "    2: 'Solar-11B',\n",
    "    3: 'Phi-14B',\n",
    "    4: 'Qwen-15B',\n",
    "    5: 'Gemma-27B',\n",
    "    6: 'Qwen-32B',\n",
    "    7: \"DeepSeek-7B\",\n",
    "    8: \"Llama-8B\",\n",
    "    9: \"Yi-9B\"\n",
    "}\n",
    "\n",
    "best_reg_lambda = 1.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9921fbd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset_fn = \"../llm-hallucinations/olddataset.parquet\"\n",
    "dataset_fn = \"../llm-hallucinations/dataset_icml.parquet\"\n",
    "\n",
    "df = HALL_lib.loadParquet(dataset_fn, unifyYears=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a0b882f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "minutes = 0 #3 * 60\n",
    "\n",
    "for _ in tqdm(range(minutes), desc=\"Sleeping before execution\"):\n",
    "    time.sleep(60)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3649ba2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_ids = list(model_names.keys())\n",
    "\n",
    "prompt_ids_by_model = {\n",
    "    mid: sorted(df[df[\"model_id\"] == mid][\"prompt_id\"].unique())\n",
    "    for mid in model_ids\n",
    "}\n",
    "\n",
    "results_fisher = HALL_lib.run_full_label_propagation_study(\n",
    "    df=df,\n",
    "    model_ids=model_ids,\n",
    "    prompt_ids_by_model=prompt_ids_by_model,\n",
    "    projector_class=HALL_lib.FisherProjection,\n",
    "    projector_kwargs={\n",
    "        \"lambda_reg\": best_reg_lambda,\n",
    "        \"normalise\": True,\n",
    "        \"normalise_by_trace\": True,\n",
    "    },\n",
    "    train_fractions=None,\n",
    "    n_iter=100,\n",
    "    test_fraction=1/3,\n",
    "    n_splits=10,\n",
    "    use_cache=True,\n",
    "    ref_lambda_reg=best_reg_lambda,\n",
    "    overwrite_cache=False,\n",
    "    cache_dir=f\"{cache_dir}/PA_fisher\",\n",
    "    logskip=True,\n",
    "    embedding_label=embedding_label\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d916767c",
   "metadata": {},
   "outputs": [],
   "source": [
    "ncomps = [1, 2, 3, 5, 10, 15]\n",
    "methods = {\n",
    "    \"wPCA\" : HALL_lib.WhitenedPCAProjection,\n",
    "    \"EP\" : HALL_lib.RandomProjection,\n",
    "    \"UMAP\" : HALL_lib.SupervisedUMAPProjection\n",
    "}\n",
    "\n",
    "results = {}\n",
    "for method_label, method_cls in methods.items():\n",
    "    results[method_label] = {}\n",
    "    for ncomp in ncomps:\n",
    "        print(\"Running\", method_label, \"with\", ncomp, \"components.\")\n",
    "        args = {\n",
    "            \"n_components\": ncomp,\n",
    "            \"random_state\": 42,\n",
    "        }\n",
    "        if method_label == 'UMAP':\n",
    "            args[\"n_neighbors\"] = 10\n",
    "            args[\"min_dist\"] = 0.1\n",
    "        results[method_label][ncomp] = HALL_lib.run_full_label_propagation_study(\n",
    "            df=df,\n",
    "            model_ids=model_ids,\n",
    "            prompt_ids_by_model=prompt_ids_by_model,\n",
    "            projector_class=method_cls,\n",
    "            projector_kwargs=args,\n",
    "            train_fractions=None,\n",
    "            n_iter=100,\n",
    "            test_fraction=1/3,\n",
    "            ref_lambda_reg=best_reg_lambda,\n",
    "            n_splits=10,\n",
    "            use_cache=True,\n",
    "            cache_dir=f\"{cache_dir}/PA_{method_label}_{ncomp}\",\n",
    "            logskip=True,\n",
    "            embedding_label=embedding_label\n",
    "        )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5473a0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# results_umap_1 = run_full_label_propagation_study(\n",
    "#     df=df,\n",
    "#     model_ids=model_ids,\n",
    "#     prompt_ids_by_model=prompt_ids_by_model,\n",
    "#     projector_class=SupervisedUMAPProjection,\n",
    "#     projector_kwargs={\n",
    "#         \"n_components\": 1,\n",
    "#         \"n_neighbors\": 15,\n",
    "#         \"min_dist\": 0.1,\n",
    "#         \"random_state\": 42,\n",
    "#     },\n",
    "#     train_fractions=None,\n",
    "#     n_iter=5,  # UMAP is expensive\n",
    "#     test_fraction=0.2,\n",
    "#     ref_lambda_reg=best_reg_lambda,\n",
    "#     n_splits=5,\n",
    "#     use_cache=True,\n",
    "#     cache_dir=f\"{cache_dir}/PA_umap_1\",\n",
    "#     logskip=True\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "535426f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tag_results(df, method, n_components):\n",
    "    out = df.copy()\n",
    "    out[\"method\"] = method\n",
    "    out[\"n_components\"] = n_components\n",
    "    return out\n",
    "\n",
    "dfs = []\n",
    "\n",
    "dfs.append(tag_results(results_fisher, \"Fisher\", 1))\n",
    "\n",
    "for method_label in methods:\n",
    "    for ncomp in ncomps:\n",
    "        dfs.append(tag_results(results[method_label][ncomp], method_label, ncomp))\n",
    "\n",
    "df_all = pd.concat(dfs, ignore_index=True)\n",
    "\n",
    "# ====\n",
    "\n",
    "summary = (\n",
    "    df_all\n",
    "    .groupby([\"method\", \"n_components\"])\n",
    "    .agg(\n",
    "        acc_mean=(\"accuracy\", \"mean\"),\n",
    "        acc_std=(\"accuracy\", \"std\"),\n",
    "        f1_mean=(\"f1\", \"mean\"),\n",
    "        f1_std=(\"f1\", \"std\"),\n",
    "        margin_mean=(\"mean_margin\", \"mean\"),\n",
    "        agreement_mean=(\"agreement_fisher\", \"mean\"),\n",
    "        agreement_std=(\"agreement_fisher\", \"std\"),\n",
    "        agreement_conf_mean=(\"agreement_fisher_confident\", \"mean\"),\n",
    "        agreement_conf_std=(\"agreement_fisher_confident\", \"std\"),\n",
    "    )\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "# ====\n",
    "\n",
    "plt.figure(figsize=(6, 4))\n",
    "\n",
    "for method in summary[\"method\"].unique():\n",
    "    sub = summary[summary[\"method\"] == method]\n",
    "    plt.errorbar(\n",
    "        sub[\"n_components\"],\n",
    "        sub[\"f1_mean\"],\n",
    "        yerr=sub[\"f1_std\"],\n",
    "        marker=\"o\",\n",
    "        label=method,\n",
    "        capsize=3,\n",
    "    )\n",
    "\n",
    "plt.xlabel(\"Number of components\")\n",
    "plt.ylabel(\"Accuracy\")\n",
    "plt.title(\"Label propagation accuracy vs embedding dimension\")\n",
    "plt.legend()\n",
    "plt.grid(alpha=0.3)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d85af242",
   "metadata": {},
   "outputs": [],
   "source": [
    "summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "038c8427",
   "metadata": {},
   "outputs": [],
   "source": [
    "methods = summary[\"method\"].unique()\n",
    "methods = methods[[1,2,0,3]]\n",
    "\n",
    "offset_map = {\n",
    "    m: (i - (len(methods) - 1) / 2) * 0.15\n",
    "    for i, m in enumerate(methods)\n",
    "}\n",
    "\n",
    "marker_map = {\n",
    "    \"Fisher\": \"o\",\n",
    "    \"EP\": \"s\",\n",
    "    \"UMAP\": \"^\",\n",
    "    \"wPCA\": \"D\",\n",
    "}\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharex=True)\n",
    "\n",
    "# === Accuracy panel ===\n",
    "for method in methods:\n",
    "    sub = summary[summary[\"method\"] == method]\n",
    "    x = sub[\"n_components\"] + offset_map[method]\n",
    "\n",
    "    axes[0].errorbar(\n",
    "        x,\n",
    "        sub[\"acc_mean\"],\n",
    "        yerr=sub[\"acc_std\"],\n",
    "        marker=marker_map.get(method, \"o\"),\n",
    "        linestyle=\"-\",\n",
    "        capsize=3,\n",
    "        label=method,\n",
    "    )\n",
    "\n",
    "axes[0].set_ylabel(\"Accuracy\")\n",
    "axes[0].set_title(\"Label propagation accuracy\")\n",
    "axes[0].grid(alpha=0.3)\n",
    "\n",
    "\n",
    "# === F1 panel ===\n",
    "for method in methods:\n",
    "    sub = summary[summary[\"method\"] == method]\n",
    "    x = sub[\"n_components\"] + offset_map[method]\n",
    "\n",
    "    axes[1].errorbar(\n",
    "        x,\n",
    "        sub[\"f1_mean\"],\n",
    "        yerr=sub[\"f1_std\"],\n",
    "        marker=marker_map.get(method, \"o\"),\n",
    "        linestyle=\"-\",\n",
    "        capsize=3,\n",
    "        label=method,\n",
    "    )\n",
    "\n",
    "axes[1].set_ylabel(\"F1 score\")\n",
    "axes[1].set_title(\"Label propagation F1 score\")\n",
    "axes[1].grid(alpha=0.3)\n",
    "\n",
    "\n",
    "# === Shared formatting ===\n",
    "for ax in axes:\n",
    "    ax.set_xlabel(\"Number of components\")\n",
    "\n",
    "axes[1].legend(loc=\"lower right\", ncols=4)\n",
    "plt.tight_layout()\n",
    "fig.savefig(\"img/ProjectorsAnalysis.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c4fcfdf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b1cf21e",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_by_method = (\n",
    "    summary\n",
    "    .sort_values(\"acc_mean\", ascending=False)\n",
    "    .groupby(\"method\")\n",
    "    .first()\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "table = best_by_method[\n",
    "    [\"method\", \"n_components\", \"acc_mean\", \"acc_std\", \"f1_mean\", \"f1_std\", \"margin_mean\", \"agreement_mean\", \"agreement_std\", \"agreement_conf_mean\", \"agreement_conf_std\"]\n",
    "]\n",
    "\n",
    "table = table.rename(columns={\n",
    "    \"n_components\": \"dim\",\n",
    "    \"acc_mean\": \"accuracy\",\n",
    "    \"acc_std\": \"std\",\n",
    "    \"margin_mean\": \"mean margin\",\n",
    "    \"f1_mean\": \"f1\",\n",
    "})\n",
    "\n",
    "table"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "639c52bb",
   "metadata": {},
   "source": [
    "The fact that Fisher outperforms UMAP while operating in one dimension strongly suggests that class-separating information is predominantly linear and aligned with between-class mean differences, rather than encoded in nonlinear local geometry."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00507010",
   "metadata": {},
   "outputs": [],
   "source": [
    "summary = (\n",
    "    df_all\n",
    "    .groupby([\"method\", \"n_components\"])\n",
    "    .agg(\n",
    "        # performance\n",
    "        acc_mean=(\"accuracy\", \"mean\"),\n",
    "        acc_std=(\"accuracy\", \"std\"),\n",
    "\n",
    "        f1_mean=(\"f1\", \"mean\"),\n",
    "        f1_std=(\"f1\", \"std\"),\n",
    "\n",
    "        # signed margin (already computed per run)\n",
    "        margin_mean=(\"mean_margin\", \"mean\"),\n",
    "        margin_std=(\"std_margin\", \"mean\"),\n",
    "\n",
    "        # absolute margin (already computed per run)\n",
    "        abs_margin_mean=(\"mean_abs_margin\", \"mean\"),\n",
    "        abs_margin_std=(\"std_abs_margin\", \"mean\"),\n",
    "\n",
    "        # Fisher agreement (across runs)\n",
    "        agreement_mean=(\"agreement_fisher\", \"mean\"),\n",
    "        agreement_std=(\"agreement_fisher\", \"std\"),\n",
    "    )\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "# ordering\n",
    "\n",
    "method_order = [\"Fisher\", \"UMAP\", \"wPCA\", \"EP\"]\n",
    "\n",
    "summary[\"method\"] = pd.Categorical(\n",
    "    summary[\"method\"],\n",
    "    categories=method_order,\n",
    "    ordered=True,\n",
    ")\n",
    "\n",
    "summary = (\n",
    "    summary\n",
    "    .sort_values([\"method\", \"n_components\"])\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "\n",
    "# formatters\n",
    "\n",
    "def fmt_pct(mean, std):\n",
    "    return f\"{100*mean:.1f} ({100*std:.1f})\"\n",
    "\n",
    "def fmt_margin(mean, std):\n",
    "    return f\"{mean:.2f} ({std:.2f})\"\n",
    "\n",
    "# table\n",
    "\n",
    "rows = []\n",
    "\n",
    "for _, r in summary.iterrows():\n",
    "    rows.append({\n",
    "        \"Method\": f\"{r['method']} ({int(r['n_components'])})\",\n",
    "\n",
    "        \"Accuracy\": fmt_pct(r[\"acc_mean\"], r[\"acc_std\"]),\n",
    "        \"F1\": fmt_pct(r[\"f1_mean\"], r[\"f1_std\"]),\n",
    "\n",
    "        r\"$M$\": fmt_margin(r[\"margin_mean\"], r[\"margin_std\"]),\n",
    "        r\"$|M|$\": fmt_margin(r[\"abs_margin_mean\"], r[\"abs_margin_std\"]),\n",
    "\n",
    "        r\"$\\%\\,$Agree\": fmt_pct(\n",
    "            r[\"agreement_mean\"],\n",
    "            r[\"agreement_std\"],\n",
    "        ),\n",
    "    })\n",
    "\n",
    "# export\n",
    "\n",
    "table_df = pd.DataFrame(rows)\n",
    "\n",
    "latex_table = table_df.to_latex(\n",
    "    index=False,\n",
    "    escape=False,\n",
    "    column_format=\"lcccccc\",\n",
    ")\n",
    "\n",
    "print(latex_table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f0771c7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.14.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
