{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import tensor\n",
    "import pandas as pd\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from ast import literal_eval\n",
    "from py_experimenter.experimenter import PyExperimenter\n",
    "import torch\n",
    "import mysql.connector\n",
    "import openml\n",
    "\n",
    "\n",
    "from py_experimenter.database_connector_mysql import DatabaseConnectorMYSQL\n",
    "\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "experimenter = PyExperimenter(\n",
    "    experiment_configuration_file_path=\"./experiments/config/cfg_simple_debug.yml\",\n",
    ")\n",
    "\n",
    "# Keep every run except those on OpenML dataset 41, and only master seeds 1 through 10.\n",
    "all_runs = experimenter.get_table()\n",
    "is_kept_dataset = all_runs[\"openml_id\"] != 41\n",
    "is_kept_seed = all_runs[\"master_seed\"].isin(list(range(1, 11)))\n",
    "exp_frame = all_runs[is_kept_dataset & is_kept_seed]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Lookups keyed by OpenML id. An id is added to all three only when its dataset\n",
    "# could be fetched and its class count computed.\n",
    "datasets = {}\n",
    "num_classes = {}\n",
    "dataset_renamer = {}\n",
    "for openml_id in exp_frame.openml_id.unique():\n",
    "    try:\n",
    "        dataset = openml.datasets.get_dataset(openml_id.item())\n",
    "        target_attribute = dataset.default_target_attribute\n",
    "        _, y, _, _ = dataset.get_data(target=target_attribute)\n",
    "        n_classes = len(np.unique(y))\n",
    "    except Exception as exc:\n",
    "        # Still best effort: a failing dataset is skipped, but the failure is reported\n",
    "        # instead of silently swallowed, and KeyboardInterrupt is no longer caught.\n",
    "        print(f\"Skipping OpenML dataset {openml_id}: {exc!r}\")\n",
    "        continue\n",
    "    datasets[openml_id] = dataset\n",
    "    dataset_renamer[openml_id] = dataset.name\n",
    "    num_classes[openml_id] = n_classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inner-join every row of the \"results\" log table onto the experiment that produced it.\n",
    "logtable = experimenter.get_logtable(\"results\")\n",
    "result_df = exp_frame.merge(logtable, left_on=\"ID\", right_on=\"experiment_id\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the reported conformity scores. The explicit copy makes the column\n",
    "# assignments in later cells act on an independent frame rather than on a slice,\n",
    "# which avoids pandas' SettingWithCopyWarning.\n",
    "result_df = result_df[result_df.conformity_score.isin([\"rand_aps\", \"aps\", \"thr\", \"ranker_vanilla\"])].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cast the coverage_rate column to float.\n",
    "result_df[\"coverage_rate\"] = result_df[\"coverage_rate\"].astype(float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the column names of result_df.\n",
    "result_df.columns\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = result_df\n",
    "\n",
    "# Dataset label: the OpenML name in LaTeX typewriter font followed by the class count.\n",
    "result_df[\"dataset\"] = result_df[\"openml_id\"].replace(dataset_renamer)\n",
    "result_df[\"num_classes\"] = result_df[\"openml_id\"].replace(num_classes)\n",
    "class_suffix = \" (\" + result_df[\"num_classes\"].astype(str) + \")\"\n",
    "result_df[\"dataset\"] = r\"$\\texttt{\" + result_df[\"dataset\"].astype(str) + \"}$\" + class_suffix\n",
    "\n",
    "# Display names for conformity scores and model types.\n",
    "cs_renamer = {\n",
    "    \"aps\": \"APS\",\n",
    "    \"rand_aps\": \"APS (rand)\",\n",
    "    \"thr\": \"LAC\",\n",
    "    \"ranker_vanilla\": \" \",\n",
    "}\n",
    "model_renamer = {\"ranker\": \"Ranker\", \"classifier\": \"Classifier\"}\n",
    "for column, renamer in ((\"conformity_score\", cs_renamer), (\"model\", model_renamer)):\n",
    "    result_df[column] = result_df[column].replace(renamer)\n",
    "\n",
    "# A method is the model name followed by its conformity score.\n",
    "result_df[\"method\"] = result_df[\"model\"] + \" \" + result_df[\"conformity_score\"]\n",
    "\n",
    "group_cols = [\"dataset\", \"alpha\", \"method\"]\n",
    "metrics_max = [\"acc\", \"coverage_rate\"]  # best value is the maximum\n",
    "metrics_min = [\"average_size\"]  # best value is the minimum\n",
    "metrics = [*metrics_max, *metrics_min]\n",
    "\n",
    "# Mean of every metric per (dataset, alpha, method).\n",
    "grouped_df = result_df.groupby(group_cols).agg(dict.fromkeys(metrics, \"mean\")).reset_index()\n",
    "\n",
    "# Best mean per (dataset, alpha): maximum for metrics_max, minimum for metrics_min.\n",
    "agg_dict = {**dict.fromkeys(metrics_max, \"max\"), **dict.fromkeys(metrics_min, \"min\")}\n",
    "best_values = grouped_df.groupby([\"dataset\", \"alpha\"]).agg(agg_dict).reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_latex_table_with_lines(df, group_cols, value_cols, best_values):\n",
    "    \"\"\"Build a LaTeX tabular with multirow group cells, bold best values and cline separators.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with one row per table line, ordered so that rows sharing the same\n",
    "            leading ``group_cols`` values are adjacent. Any index is accepted.\n",
    "        group_cols: Ordered list of grouping columns, outermost first.\n",
    "        value_cols: Metric columns, printed with four decimals.\n",
    "        best_values: DataFrame with columns ``dataset``, ``alpha`` and one column per\n",
    "            metric holding the value to print in bold for that (dataset, alpha) pair.\n",
    "\n",
    "    Returns:\n",
    "        The complete ``tabular`` environment as a string.\n",
    "    \"\"\"\n",
    "\n",
    "    def format_value(value, group, metric, max=True):\n",
    "        \"\"\"Format ``value`` with four decimals, in bold when it equals the best value.\n",
    "\n",
    "        ``max`` is unused and kept only so existing callers keep working; whether the\n",
    "        best value is the largest or the smallest is decided by ``best_values``.\n",
    "        \"\"\"\n",
    "        best_value = best_values.loc[\n",
    "            (best_values['dataset'] == group['dataset']) &\n",
    "            (best_values['alpha'] == group['alpha']), metric].values[0]\n",
    "        if value == best_value:\n",
    "            return f\"\\\\textbf{{{value:.4f}}}\"\n",
    "        return f\"{value:.4f}\"\n",
    "\n",
    "    def recursive_build(df, group_cols):\n",
    "        \"\"\"Emit one LaTeX row per DataFrame row (iterative, despite the historical name).\"\"\"\n",
    "        last_col_id = len(group_cols) - 1\n",
    "        total_cols = len(group_cols) + len(value_cols)\n",
    "        latex_str = \"\"\n",
    "        prev_row = None\n",
    "        for _, row in df.iterrows():\n",
    "            latex_str_tmp = \"\"\n",
    "            first_changed = None  # position of the outermost group column that starts a new group\n",
    "            for col_id, col in enumerate(group_cols):\n",
    "                # A column starts a new group when its own value changes or when an outer\n",
    "                # column already changed on this row (otherwise the inner cell stays empty).\n",
    "                if prev_row is None or first_changed is not None or row[col] != prev_row[col]:\n",
    "                    if first_changed is None:\n",
    "                        first_changed = col_id\n",
    "                    # Count the rows sharing this row's values in every column up to ``col``.\n",
    "                    # The mask is built on df's own index so it aligns for any index.\n",
    "                    mask = pd.Series(True, index=df.index)\n",
    "                    for fcol in group_cols[:col_id + 1]:\n",
    "                        mask = mask & (df[fcol] == row[fcol])\n",
    "                    length = int(mask.sum())\n",
    "                    latex_str_tmp += f\" \\\\multirow{{{length}}}{{*}}{{{row[col]}}} & \"\n",
    "                else:\n",
    "                    latex_str_tmp += \" & \"\n",
    "                if col_id == last_col_id:\n",
    "                    latex_str_tmp += \" & \".join([format_value(row[metric], row, metric) for metric in value_cols]) + \" \\\\\\\\\\n\"\n",
    "            # Separate groups with a rule from the first changed column to the last column.\n",
    "            # Skipped on the first row and on rows that repeat the previous group values.\n",
    "            if prev_row is not None and first_changed is not None and group_cols[first_changed] != \"model\":\n",
    "                latex_str += \"\\\\cline{%d-%d}\" % (first_changed + 1, total_cols)\n",
    "            latex_str += latex_str_tmp\n",
    "            prev_row = row\n",
    "        return latex_str\n",
    "\n",
    "    latex_body = recursive_build(df, group_cols)\n",
    "\n",
    "    # Complete LaTeX table: group columns left-aligned, value columns right-aligned.\n",
    "    col_format = 'l' * len(group_cols) + 'r' * len(value_cols)\n",
    "    latex_table = f\"\"\"\n",
    "\\\\begin{{tabular}}{{{col_format}}}\n",
    "\\\\toprule\n",
    "{' & '.join(group_cols)} & {' & '.join(value_cols)} \\\\\\\\\n",
    "\\\\midrule\n",
    "{latex_body}\n",
    "\\\\bottomrule\n",
    "\\\\end{{tabular}}\n",
    "\"\"\"\n",
    "    return latex_table\n",
    "\n",
    "# Generate the LaTeX table with bold formatting and lines\n",
    "latex_table = generate_latex_table_with_lines(grouped_df, group_cols, metrics, best_values)\n",
    "\n",
    "# Escape underscores for LaTeX and write the table to out_table.tex.\n",
    "# The backslash in the replacement is itself escaped: a bare backslash-underscore is an\n",
    "# invalid escape sequence that newer Python versions warn about.\n",
    "with open('out_table.tex', 'w') as f:\n",
    "    print(latex_table.replace(\"_\", \"\\\\_\"), file=f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# result_df = result_df[result_df.method != \"Classifier APS\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the distinct alpha values present in result_df.\n",
    "result_df.alpha.unique()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.colors import Colormap\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "\n",
    "sns.set(font_scale=1.8,rc={'text.usetex' : True})\n",
    "sns.set_style(\"whitegrid\")\n",
    "plt.rc('font', **{'family': 'serif'})\n",
    "plt.rcParams[\"figure.figsize\"] = (12, 6)\n",
    "\n",
    "# Define columns\n",
    "group_cols = ['dataset', 'alpha', 'method']\n",
    "metrics_max = ['acc', 'coverage_rate']\n",
    "metrics_min = ['average_size']\n",
    "metrics = metrics_max + metrics_min\n",
    "\n",
    "# Drop the \"Classifier APS\" method and OpenML datasets 182, 307 and 11, then keep\n",
    "# only the alpha levels shown in the figure.\n",
    "result_df = result_df[result_df.method != \"Classifier APS\"]\n",
    "result_df = result_df[~result_df.openml_id.isin([182,307,11])]\n",
    "alpha_df = result_df[result_df.alpha.isin([0.02,0.05,0.1, 0.2])]\n",
    "\n",
    "# Melt to long format (one row per original row and metric) for seaborn.\n",
    "melted_df = alpha_df.melt(id_vars=group_cols, value_vars=metrics, var_name='metric', value_name='value')\n",
    "melted_df = melted_df.replace({\"metric\": {\"acc\": \"Accuracy\", \"coverage_rate\": \"Coverage Rate\", \"average_size\": \"Average Set Size\"}})\n",
    "\n",
    "# Fix the category orders once. Each facet is drawn from its own data subset, so without\n",
    "# explicit orders a facet missing a level would shift x positions or hue colours, and the\n",
    "# shared legend taken from the first facet would no longer match the other facets.\n",
    "alpha_order = sorted(melted_df[\"alpha\"].unique())\n",
    "method_order = list(melted_df[\"method\"].unique())\n",
    "\n",
    "# Create FacetGrid with a row per metric and a column per dataset\n",
    "g = sns.FacetGrid(\n",
    "    melted_df, col=\"dataset\", row=\"metric\", margin_titles=True,\n",
    "    despine=False, sharey=False\n",
    ")\n",
    "g.set_titles(col_template=\"{col_name}\", row_template=\"{row_name}\")\n",
    "\n",
    "# Map boxplots with hue for method separation\n",
    "g.map_dataframe(\n",
    "    sns.boxplot, x=\"alpha\", y=\"value\", hue=\"method\",\n",
    "    order=alpha_order, hue_order=method_order,\n",
    "    dodge=True, showfliers=False, palette=\"Set2\",\n",
    ")\n",
    "g.set_axis_labels(r\"$\\alpha$\", \" \")\n",
    "\n",
    "# Single figure-level legend below the grid, built from the first facet's handles.\n",
    "handles, labels = g.axes[0, 0].get_legend_handles_labels()\n",
    "g.fig.legend(handles, labels, loc='upper center', ncol=3, bbox_to_anchor=(0.5, 0.07), frameon=False)\n",
    "\n",
    "# Rotate the x tick labels on every facet.\n",
    "for ax in g.axes.flat:\n",
    "    for label in ax.get_xticklabels():\n",
    "        label.set_rotation(90)\n",
    "\n",
    "# Leave room at the bottom for the legend before saving.\n",
    "plt.subplots_adjust(bottom=0.15)\n",
    "plt.savefig(\"boxplots.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display result_df as left by the filtering in the plotting cell.\n",
    "result_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show each distinct (dataset label, OpenML id) pair in result_df.\n",
    "result_df[[\"dataset\", \"openml_id\"]].drop_duplicates()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "plnet",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
