{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy import stats\n",
    "from tabulate import tabulate\n",
    "\n",
    "from pkg.utils.analyze import summarize_outputs_in_df\n",
    "\n",
    "# Directory handed to `summarize_outputs_in_df` in the next cell. Set the\n",
    "# KTST_EVAL_PATH environment variable to read results from elsewhere instead\n",
    "# of editing this machine-specific default path.\n",
    "eval_path = os.environ.get(\n",
    "    \"KTST_EVAL_PATH\",\n",
    "    \"/home/knowledge-tracing/ktst/ablation_attention_assist2009/eval\",\n",
    ")\n",
    "\n",
    "# Maps the run key \"{model.attn_variant} {model.use_decoder_only}\" to the\n",
    "# LaTeX row label used in the results table.\n",
    "ATTENTION_VARIANTS = {\n",
    "    \"standard False\": \"Standard MHA + PE (q$\\\\neq$k)\",\n",
    "    \"akt_monotonic False\": \"AKT (q$=$k)\",\n",
    "    \"alibi_monotonic False\": \"ALiBi (q$\\\\neq$k)\",\n",
    "    \"alibi_monotonic_q_k False\": \"ALiBi (q$=$k)\",\n",
    "    \"learnable_alibi_monotonic False\": \"Learnable ALiBi (q$\\\\neq$k)\",\n",
    "    \"learnable_alibi_monotonic_q_k False\": \"Learnable ALiBi (q$=$k)\",\n",
    "    \"learnable_alibi_monotonic_q_k True\": \"Learnable ALiBi (q$=$k) decoder-only\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_eval = summarize_outputs_in_df(outputs_dir=eval_path)\n",
    "# Label every run with its ATTENTION_VARIANTS entry (an unknown key raises KeyError).\n",
    "# Columns are read by label: `x[0]` on a label-indexed row Series only worked via\n",
    "# pandas' deprecated positional fallback (FutureWarning).\n",
    "df_eval[\"attention_variant\"] = df_eval[\n",
    "    [\"model.attn_variant\", \"model.use_decoder_only\"]\n",
    "].apply(\n",
    "    lambda row: ATTENTION_VARIANTS[\n",
    "        f\"{row['model.attn_variant']} {row['model.use_decoder_only']}\"\n",
    "    ],\n",
    "    axis=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _variant_scores(label):\n",
    "    \"\"\"Per-fold test AUC and accuracy arrays for one attention variant.\n",
    "\n",
    "    Runs are ordered by `data.val_fold_idx`; assuming every variant was run on\n",
    "    the same folds, position i of each array then refers to the same fold.\n",
    "    \"\"\"\n",
    "    runs = df_eval.loc[df_eval[\"attention_variant\"] == label]\n",
    "    runs = runs.sort_values(by=[\"data.val_fold_idx\"])\n",
    "    return {\n",
    "        \"Attention mechanism\": label,\n",
    "        \"auc\": runs[\"test_auc\"].values,\n",
    "        \"acc\": runs[\"test_accuracy\"].values,\n",
    "    }\n",
    "\n",
    "\n",
    "# One row per ATTENTION_VARIANTS key, in the dict's order; a variant without\n",
    "# runs gets empty score arrays.\n",
    "df = pd.DataFrame.from_dict(\n",
    "    {key: _variant_scores(label) for key, label in ATTENTION_VARIANTS.items()},\n",
    "    orient=\"index\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row that every other attention variant is tested against.\n",
    "REFERENCE_VARIANT = \"learnable_alibi_monotonic_q_k False\"\n",
    "# p-value threshold of the paired t-test behind the significance markers.\n",
    "SIG_ALPHA = 0.01\n",
    "\n",
    "\n",
    "def _mean(scores):\n",
    "    \"\"\"Mean of the per-fold scores, or None when the variant has no runs.\"\"\"\n",
    "    return np.asarray(scores).mean() if len(scores) > 0 else None\n",
    "\n",
    "\n",
    "def _std(scores):\n",
    "    \"\"\"Std (numpy default, ddof=0) of the per-fold scores, or None without runs.\"\"\"\n",
    "    # NOTE(review): ddof=0 is the population std; confirm this is the intended\n",
    "    # spread to report across folds (the sample std would use ddof=1).\n",
    "    return np.asarray(scores).std() if len(scores) > 0 else None\n",
    "\n",
    "\n",
    "def _sig_marker(key, scores, ref_scores):\n",
    "    \"\"\"LaTeX marker comparing one variant's fold scores with the reference's.\n",
    "\n",
    "    Returns an empty string when the variant has no runs, a blank spacer for\n",
    "    the reference row itself, \"$\\\\circ$\" when the paired t-test is not\n",
    "    significant at SIG_ALPHA, and otherwise \"$\\\\ast$\" if the reference has\n",
    "    the higher mean or \"$\\\\bullet$\" if it does not.\n",
    "    \"\"\"\n",
    "    if len(scores) == 0:\n",
    "        return \"\"\n",
    "    # Identify the reference by its row key rather than by comparing scores:\n",
    "    # two different variants can share individual fold scores.\n",
    "    if key == REFERENCE_VARIANT:\n",
    "        return \"$\\\\; \\\\: $\"\n",
    "    pvalue = stats.ttest_rel(scores, ref_scores).pvalue\n",
    "    # `not p <= alpha` also treats a NaN p-value (e.g. every paired difference\n",
    "    # is zero) as not significant.\n",
    "    if not pvalue <= SIG_ALPHA:\n",
    "        return \"$\\\\circ$\"\n",
    "    return \"$\\\\ast$\" if _mean(ref_scores) > _mean(scores) else \"$\\\\bullet$\"\n",
    "\n",
    "\n",
    "for col, label in {\"auc\": \"AUC\", \"acc\": \"ACC\"}.items():\n",
    "    ref_scores = df.loc[REFERENCE_VARIANT, col]\n",
    "    df[f\"{label}_mean\"] = df[col].apply(_mean)\n",
    "    df[f\"{label}_std\"] = df[col].apply(_std)\n",
    "    df[f\"{label}_sig\"] = [\n",
    "        _sig_marker(key, scores, ref_scores) for key, scores in df[col].items()\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_latex(df):\n",
    "    \"\"\"Print the AUC/ACC summary as a LaTeX table via tabulate (\"latex_raw\").\n",
    "\n",
    "    `df` must provide an \"Attention mechanism\" column and, for each metric in\n",
    "    (\"AUC\", \"ACC\"), the columns \"{metric}_mean\", \"{metric}_std\" and\n",
    "    \"{metric}_sig\". Per metric the best mean is bold and the runner-up is\n",
    "    underlined; rows whose mean is missing are printed as \"---\".\n",
    "\n",
    "    The formatted cells are built on a copy, so `df` itself is not modified.\n",
    "    \"\"\"\n",
    "    metrics = (\"AUC\", \"ACC\")\n",
    "\n",
    "    def fmt(value):\n",
    "        return \"{0:.4f}\".format(value)\n",
    "\n",
    "    def highlight(value, best, second):\n",
    "        text = fmt(value)\n",
    "        # A runner-up that ties the best value is bolded only, not also underlined.\n",
    "        if value == second and value != best:\n",
    "            text = \"\\\\underline{\" + text + \"}\"\n",
    "        if value == best:\n",
    "            text = \"\\\\textbf{\" + text + \"}\"\n",
    "        return text\n",
    "\n",
    "    def cell(row, metric, best, second):\n",
    "        if pd.isnull(row[f\"{metric}_mean\"]):\n",
    "            return \"---\"\n",
    "        return (\n",
    "            highlight(row[f\"{metric}_mean\"], best, second)\n",
    "            + \" $\\\\pm$ \"\n",
    "            + fmt(row[f\"{metric}_std\"])\n",
    "            + f\" {row[f'{metric}_sig']}\"\n",
    "        )\n",
    "\n",
    "    table = df.copy()\n",
    "    for metric in metrics:\n",
    "        top = table[f\"{metric}_mean\"].astype(float).nlargest(2).tolist()\n",
    "        # With fewer than two rows the missing rank stays None and never matches.\n",
    "        best = top[0] if len(top) > 0 else None\n",
    "        second = top[1] if len(top) > 1 else None\n",
    "        table[metric] = table.apply(cell, axis=1, args=(metric, best, second))\n",
    "\n",
    "    df_out = table[[\"Attention mechanism\", *metrics]].set_index(\"Attention mechanism\")\n",
    "    headers = [f\"\\\\textbf{{{d}}}\" for d in (\"Attention mechanism\", *metrics)]\n",
    "    print(\n",
    "        tabulate(\n",
    "            df_out,\n",
    "            headers=headers,\n",
    "            tablefmt=\"latex_raw\",\n",
    "            colalign=[\"right\"] + [\"center\"] * len(metrics),\n",
    "        )\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{rcc}\n",
      "\\hline\n",
      "         \\textbf{Attention mechanism} &              \\textbf{AUC}              &              \\textbf{ACC}               \\\\\n",
      "\\hline\n",
      "         Standard MHA + PE (q$\\neq$k) &       0.7744 $\\pm$ 0.0014 $\\ast$       &       0.7307 $\\pm$ 0.0015 $\\ast$        \\\\\n",
      "                          AKT (q$=$k) &       0.7958 $\\pm$ 0.0011 $\\ast$       &       0.7464 $\\pm$ 0.0009 $\\ast$        \\\\\n",
      "                     ALiBi (q$\\neq$k) &       0.7953 $\\pm$ 0.0018 $\\ast$       &       0.7456 $\\pm$ 0.0021 $\\circ$       \\\\\n",
      "                        ALiBi (q$=$k) & \\underline{0.7978} $\\pm$ 0.0010 $\\ast$ & \\underline{0.7479} $\\pm$ 0.0007 $\\circ$ \\\\\n",
      "           Learnable ALiBi (q$\\neq$k) &      0.7976 $\\pm$ 0.0013 $\\circ$       &       0.7476 $\\pm$ 0.0012 $\\circ$       \\\\\n",
      "              Learnable ALiBi (q$=$k) & \\textbf{0.7993} $\\pm$ 0.0012 $\\; \\: $  &  \\textbf{0.7490} $\\pm$ 0.0013 $\\; \\: $  \\\\\n",
      " Learnable ALiBi (q$=$k) decoder-only &       0.7977 $\\pm$ 0.0011 $\\ast$       &       0.7473 $\\pm$ 0.0007 $\\circ$       \\\\\n",
      "\\hline\n",
      "\\end{tabular}\n"
     ]
    }
   ],
   "source": [
    "# Print the LaTeX results table built from `df`.\n",
    "print_latex(df=df)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
