{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy import stats\n",
    "from tabulate import tabulate\n",
    "\n",
    "from pkg.utils.analyze import summarize_outputs_in_df\n",
    "\n",
    "# Directories passed as `outputs_dir` to summarize_outputs_in_df in the cells below.\n",
    "# NOTE(review): absolute, machine-specific paths; adjust them when running elsewhere.\n",
    "OUTPUT_PATH_KTST_SET_DENSE = \"/home/knowledge-tracing/ktst/benchmark/eval\"\n",
    "OUTPUT_PATH_KTST_SET_DENSE_WO_INIT = \"/home/knowledge-tracing/ktst/ablation_initialization/eval\"\n",
    "\n",
    "# Maps the identifier matched against the `data.dataset` column of the run\n",
    "# summaries to the short label used for result columns and LaTeX table headers.\n",
    "# Commented-out entries are left out of the tables.\n",
    "DATASETS = {\n",
    "    \"ednet\": \"Ednet\",\n",
    "    \"algebra2005\": \"AL2005\",\n",
    "    \"assist2009\": \"AS2009\",\n",
    "    \"nips_task34\": \"NIPS34\",\n",
    "    \"bridge2algebra2006\": \"BD2006\",\n",
    "    # \"statics2011\": \"Statics2011\",\n",
    "    # \"assist2015\": \"AS2015\",\n",
    "    # \"poj\": \"POJ\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two empty, identically shaped result tables (AUC and accuracy); the cells\n",
    "# below append one row per model to each of them.\n",
    "df_auc, df_acc = (\n",
    "    pd.DataFrame(data=[], columns=[\"Index\", \"Model\"]).set_index(\"Index\")\n",
    "    for _ in range(2)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Add model without 0-init"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_wo_init = summarize_outputs_in_df(outputs_dir=OUTPUT_PATH_KTST_SET_DENSE_WO_INIT)\n",
    "assert len(df_wo_init) == 40, f\"expected 40 rows, found {len(df_wo_init)}\"\n",
    "\n",
    "# One entry per dataset label, each holding the per-fold test scores of this model.\n",
    "out_auc, out_acc = {\"Model\": \"KTST (w/o 0-init)\"}, {\"Model\": \"KTST (w/o 0-init)\"}\n",
    "for k, v in DATASETS.items():\n",
    "    # Order by validation fold so the score arrays line up fold-by-fold with the\n",
    "    # other model's arrays (the significance test in a later cell is paired).\n",
    "    _df = df_wo_init[df_wo_init[\"data.dataset\"] == k].sort_values(\n",
    "        by=[\"data.dataset\", \"data.val_fold_idx\"]\n",
    "    )\n",
    "    out_auc[v] = _df[\"test_auc\"].values\n",
    "    out_acc[v] = _df[\"test_accuracy\"].values\n",
    "\n",
    "# Drop a row left behind by an earlier run of this cell, so that re-executing it\n",
    "# replaces the entry instead of appending a second row with the same index label.\n",
    "df_auc = pd.concat(\n",
    "    [\n",
    "        df_auc.drop(index=\"ktst_wo_init\", errors=\"ignore\"),\n",
    "        pd.DataFrame.from_dict({\"ktst_wo_init\": out_auc}, orient=\"index\"),\n",
    "    ]\n",
    ")\n",
    "df_acc = pd.concat(\n",
    "    [\n",
    "        df_acc.drop(index=\"ktst_wo_init\", errors=\"ignore\"),\n",
    "        pd.DataFrame.from_dict({\"ktst_wo_init\": out_acc}, orient=\"index\"),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Add model with 0-init"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_sota = summarize_outputs_in_df(outputs_dir=OUTPUT_PATH_KTST_SET_DENSE)\n",
    "# Keep only the runs with the `set_dense` data format and `q_mean_c` aggregation.\n",
    "df_w_init = df_sota[\n",
    "    (df_sota[\"data.format\"] == \"set_dense\")\n",
    "    & (df_sota[\"model.aggregation\"] == \"q_mean_c\")\n",
    "]\n",
    "assert len(df_w_init) == 40, f\"expected 40 rows, found {len(df_w_init)}\"\n",
    "\n",
    "# One entry per dataset label, each holding the per-fold test scores of this model.\n",
    "out_auc, out_acc = {\"Model\": \"KTST\"}, {\"Model\": \"KTST\"}\n",
    "for k, v in DATASETS.items():\n",
    "    # Order by validation fold so the score arrays line up fold-by-fold with the\n",
    "    # other model's arrays (the significance test in a later cell is paired).\n",
    "    _df = df_w_init[df_w_init[\"data.dataset\"] == k].sort_values(\n",
    "        by=[\"data.dataset\", \"data.val_fold_idx\"]\n",
    "    )\n",
    "    out_auc[v] = _df[\"test_auc\"].values\n",
    "    out_acc[v] = _df[\"test_accuracy\"].values\n",
    "\n",
    "# Drop a row left behind by an earlier run of this cell, so that re-executing it\n",
    "# replaces the entry instead of appending a second row with the same index label.\n",
    "df_auc = pd.concat(\n",
    "    [\n",
    "        df_auc.drop(index=\"ktst_w_init\", errors=\"ignore\"),\n",
    "        pd.DataFrame.from_dict({\"ktst_w_init\": out_auc}, orient=\"index\"),\n",
    "    ]\n",
    ")\n",
    "df_acc = pd.concat(\n",
    "    [\n",
    "        df_acc.drop(index=\"ktst_w_init\", errors=\"ignore\"),\n",
    "        pd.DataFrame.from_dict({\"ktst_w_init\": out_acc}, orient=\"index\"),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute columns of results table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "REFERENCE_ROW = \"ktst_w_init\"  # index label of the model every row is tested against\n",
    "SIGNIFICANCE_LEVEL = 0.01  # p-value threshold of the paired t-test\n",
    "\n",
    "\n",
    "def _fold_mean(scores):\n",
    "    \"\"\"Return the mean of `scores`, or None when `scores` is empty.\"\"\"\n",
    "    return np.array(scores).mean() if len(scores) > 0 else None\n",
    "\n",
    "\n",
    "def _fold_std(scores):\n",
    "    \"\"\"Return the standard deviation of `scores`, or None when `scores` is empty.\"\"\"\n",
    "    return np.array(scores).std() if len(scores) > 0 else None\n",
    "\n",
    "\n",
    "def _significance_marker(scores, reference):\n",
    "    \"\"\"Return the LaTeX marker comparing `scores` with `reference`.\n",
    "\n",
    "    - empty string: `scores` is empty\n",
    "    - blank math spacer: `scores` is the reference entry itself\n",
    "    - circ: paired t-test p-value above SIGNIFICANCE_LEVEL\n",
    "    - ast: p-value at or below the level and the reference mean is higher\n",
    "    - bullet: p-value at or below the level and the reference mean is not higher\n",
    "    \"\"\"\n",
    "    if len(scores) == 0:\n",
    "        return \"\"\n",
    "    # Whole-array equality: two different models that merely share one fold\n",
    "    # value must still be tested against each other.\n",
    "    if np.array_equal(scores, reference):\n",
    "        return \"$\\; \\: $\"\n",
    "    if stats.ttest_rel(scores, reference).pvalue > SIGNIFICANCE_LEVEL:\n",
    "        return \"$\\circ$\"\n",
    "    return \"$\\ast$\" if (_fold_mean(reference) - _fold_mean(scores) > 0) else \"$\\bullet$\"\n",
    "\n",
    "\n",
    "for results in (df_auc, df_acc):\n",
    "    for label in DATASETS.values():\n",
    "        # Per-fold scores of the reference model for this dataset.\n",
    "        reference = results.loc[results.index == REFERENCE_ROW, label].iloc[0]\n",
    "        results[f\"{label}_mean\"] = results[label].apply(_fold_mean)\n",
    "        results[f\"{label}_std\"] = results[label].apply(_fold_std)\n",
    "        results[f\"{label}_sig\"] = results[label].apply(\n",
    "            lambda scores: _significance_marker(scores, reference)\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Display results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>Ednet_mean</th>\n",
       "      <th>Ednet_std</th>\n",
       "      <th>AL2005_mean</th>\n",
       "      <th>AL2005_std</th>\n",
       "      <th>AS2009_mean</th>\n",
       "      <th>AS2009_std</th>\n",
       "      <th>NIPS34_mean</th>\n",
       "      <th>NIPS34_std</th>\n",
       "      <th>BD2006_mean</th>\n",
       "      <th>BD2006_std</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>ktst_wo_init</th>\n",
       "      <td>KTST (w/o 0-init)</td>\n",
       "      <td>0.709168</td>\n",
       "      <td>0.001305</td>\n",
       "      <td>0.823917</td>\n",
       "      <td>0.000763</td>\n",
       "      <td>0.734625</td>\n",
       "      <td>0.001231</td>\n",
       "      <td>0.735153</td>\n",
       "      <td>0.000675</td>\n",
       "      <td>0.857201</td>\n",
       "      <td>0.000289</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ktst_w_init</th>\n",
       "      <td>KTST</td>\n",
       "      <td>0.715404</td>\n",
       "      <td>0.001112</td>\n",
       "      <td>0.828660</td>\n",
       "      <td>0.000546</td>\n",
       "      <td>0.749032</td>\n",
       "      <td>0.001292</td>\n",
       "      <td>0.735552</td>\n",
       "      <td>0.000276</td>\n",
       "      <td>0.860820</td>\n",
       "      <td>0.000530</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                          Model  Ednet_mean  Ednet_std  AL2005_mean  \\\n",
       "ktst_wo_init  KTST (w/o 0-init)    0.709168   0.001305     0.823917   \n",
       "ktst_w_init                KTST    0.715404   0.001112     0.828660   \n",
       "\n",
       "              AL2005_std  AS2009_mean  AS2009_std  NIPS34_mean  NIPS34_std  \\\n",
       "ktst_wo_init    0.000763     0.734625    0.001231     0.735153    0.000675   \n",
       "ktst_w_init     0.000546     0.749032    0.001292     0.735552    0.000276   \n",
       "\n",
       "              BD2006_mean  BD2006_std  \n",
       "ktst_wo_init     0.857201    0.000289  \n",
       "ktst_w_init      0.860820    0.000530  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Accuracy table: the model name followed by each dataset's mean and std columns.\n",
    "df_acc[\n",
    "    [\"Model\"]\n",
    "    + [f\"{label}_{stat}\" for label in DATASETS.values() for stat in (\"mean\", \"std\")]\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>Ednet_mean</th>\n",
       "      <th>Ednet_std</th>\n",
       "      <th>AL2005_mean</th>\n",
       "      <th>AL2005_std</th>\n",
       "      <th>AS2009_mean</th>\n",
       "      <th>AS2009_std</th>\n",
       "      <th>NIPS34_mean</th>\n",
       "      <th>NIPS34_std</th>\n",
       "      <th>BD2006_mean</th>\n",
       "      <th>BD2006_std</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>ktst_wo_init</th>\n",
       "      <td>KTST (w/o 0-init)</td>\n",
       "      <td>0.725120</td>\n",
       "      <td>0.003174</td>\n",
       "      <td>0.842516</td>\n",
       "      <td>0.000603</td>\n",
       "      <td>0.776194</td>\n",
       "      <td>0.000985</td>\n",
       "      <td>0.806659</td>\n",
       "      <td>0.000301</td>\n",
       "      <td>0.814193</td>\n",
       "      <td>0.000760</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ktst_w_init</th>\n",
       "      <td>KTST</td>\n",
       "      <td>0.739416</td>\n",
       "      <td>0.000234</td>\n",
       "      <td>0.852205</td>\n",
       "      <td>0.000450</td>\n",
       "      <td>0.799346</td>\n",
       "      <td>0.001224</td>\n",
       "      <td>0.807055</td>\n",
       "      <td>0.000040</td>\n",
       "      <td>0.826365</td>\n",
       "      <td>0.000359</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                          Model  Ednet_mean  Ednet_std  AL2005_mean  \\\n",
       "ktst_wo_init  KTST (w/o 0-init)    0.725120   0.003174     0.842516   \n",
       "ktst_w_init                KTST    0.739416   0.000234     0.852205   \n",
       "\n",
       "              AL2005_std  AS2009_mean  AS2009_std  NIPS34_mean  NIPS34_std  \\\n",
       "ktst_wo_init    0.000603     0.776194    0.000985     0.806659    0.000301   \n",
       "ktst_w_init     0.000450     0.799346    0.001224     0.807055    0.000040   \n",
       "\n",
       "              BD2006_mean  BD2006_std  \n",
       "ktst_wo_init     0.814193    0.000760  \n",
       "ktst_w_init      0.826365    0.000359  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# AUC table: the model name followed by each dataset's mean and std columns.\n",
    "df_auc[\n",
    "    [\"Model\"]\n",
    "    + [f\"{label}_{stat}\" for label in DATASETS.values() for stat in (\"mean\", \"std\")]\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Format results to LaTeX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_latex(df):\n",
    "    \"\"\"Print the results in `df` as a LaTeX tabular with one row per model.\n",
    "\n",
    "    For every label in `DATASETS`, `df` must provide the columns\n",
    "    `<label>_mean`, `<label>_std` and `<label>_sig`. Each table cell reads\n",
    "    `mean +/- std marker` with four decimals; a missing mean is rendered as\n",
    "    `---`. No best/second-best highlighting is applied.\n",
    "\n",
    "    The formatted cells are written to a separate table, so the caller's\n",
    "    frame (including its `<label>` columns) is left unmodified.\n",
    "    \"\"\"\n",
    "    fmt = \"{0:.4f}\".format\n",
    "    datasets = list(DATASETS.values())\n",
    "\n",
    "    def format_cell(row, dataset):\n",
    "        \"\"\"Format one model/dataset entry of the table.\"\"\"\n",
    "        mean = row[f\"{dataset}_mean\"]\n",
    "        if pd.isnull(mean):\n",
    "            return \"---\"\n",
    "        return (\n",
    "            fmt(mean)\n",
    "            + \" $\\pm$ \"\n",
    "            + fmt(row[f\"{dataset}_std\"])\n",
    "            + f\" {row[f'{dataset}_sig']}\"\n",
    "        )\n",
    "\n",
    "    # Explicit copy: assigning into a slice of `df` would raise\n",
    "    # SettingWithCopyWarning and could write through to the caller's data.\n",
    "    df_out = df[[\"Model\"]].copy()\n",
    "    df_out[\"Model\"] = df_out[\"Model\"].apply(lambda name: f\"\\textbf{{{name}}}\")\n",
    "    for dataset in datasets:\n",
    "        df_out[dataset] = df.apply(lambda row: format_cell(row, dataset), axis=1)\n",
    "    df_out = df_out.set_index(\"Model\")\n",
    "\n",
    "    headers = [f\"\\textbf{{{dataset}}}\" for dataset in datasets]\n",
    "    print(\n",
    "        tabulate(\n",
    "            df_out,\n",
    "            headers=headers,\n",
    "            tablefmt=\"latex_raw\",\n",
    "            colalign=[\"right\"] + [\"center\"] * len(datasets),\n",
    "        )\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{rccccc}\n",
      "\\hline\n",
      "                            &        \\textbf{Ednet}        &       \\textbf{AL2005}        &       \\textbf{AS2009}        &       \\textbf{NIPS34}        &       \\textbf{BD2006}        \\\\\n",
      "\\hline\n",
      " \\textbf{KTST (w/o 0-init)} &  0.7251 $\\pm$ 0.0032 $\\ast$  &  0.8425 $\\pm$ 0.0006 $\\ast$  &  0.7762 $\\pm$ 0.0010 $\\ast$  & 0.8067 $\\pm$ 0.0003 $\\circ$  &  0.8142 $\\pm$ 0.0008 $\\ast$  \\\\\n",
      "              \\textbf{KTST} & 0.7394 $\\pm$ 0.0002 $\\; \\: $ & 0.8522 $\\pm$ 0.0004 $\\; \\: $ & 0.7993 $\\pm$ 0.0012 $\\; \\: $ & 0.8071 $\\pm$ 0.0000 $\\; \\: $ & 0.8264 $\\pm$ 0.0004 $\\; \\: $ \\\\\n",
      "\\hline\n",
      "\\end{tabular}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2852133/716310536.py:24: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df_out[\"Model\"] = df_out[\"Model\"].apply(lambda x: f\"\\\\textbf{{{x}}}\")\n"
     ]
    }
   ],
   "source": [
    "print_latex(df_auc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{rccccc}\n",
      "\\hline\n",
      "                            &        \\textbf{Ednet}        &       \\textbf{AL2005}        &       \\textbf{AS2009}        &       \\textbf{NIPS34}        &       \\textbf{BD2006}        \\\\\n",
      "\\hline\n",
      " \\textbf{KTST (w/o 0-init)} &  0.7092 $\\pm$ 0.0013 $\\ast$  &  0.8239 $\\pm$ 0.0008 $\\ast$  &  0.7346 $\\pm$ 0.0012 $\\ast$  & 0.7352 $\\pm$ 0.0007 $\\circ$  &  0.8572 $\\pm$ 0.0003 $\\ast$  \\\\\n",
      "              \\textbf{KTST} & 0.7154 $\\pm$ 0.0011 $\\; \\: $ & 0.8287 $\\pm$ 0.0005 $\\; \\: $ & 0.7490 $\\pm$ 0.0013 $\\; \\: $ & 0.7356 $\\pm$ 0.0003 $\\; \\: $ & 0.8608 $\\pm$ 0.0005 $\\; \\: $ \\\\\n",
      "\\hline\n",
      "\\end{tabular}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2852133/716310536.py:24: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df_out[\"Model\"] = df_out[\"Model\"].apply(lambda x: f\"\\\\textbf{{{x}}}\")\n"
     ]
    }
   ],
   "source": [
    "print_latex(df_acc)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ktst",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
