{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import copy \n",
    "import os\n",
    "import re\n",
    "from collections import OrderedDict\n",
    "from pprint import pprint\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset key -> display name shown in the LaTeX table header.\n",
    "dm_identifier = {\n",
    "    'cifar10': 'CIFAR-10',\n",
    "    'mnist': 'MNIST',\n",
    "    'twenty_newsgroups': '20newsgroups',\n",
    "    'tinyimagenet100': 'TinyImageNet-100',\n",
    "}\n",
    "# Dataset keys in table-column order; derived from dm_identifier so the two stay in sync.\n",
    "data_models = list(dm_identifier)\n",
    "# Sub-columns reported for every dataset.\n",
    "sub_metrics = ['Error', 'Coverage']\n",
    "\n",
    "def helper_find_files_and_read_dataframe(root_path, patterns=(r\"cifar10\",)):\n",
    "    \"\"\"Find ``.xlsx`` result files under ``root_path`` and load one DataFrame per pattern.\n",
    "\n",
    "    Args:\n",
    "        root_path: Directory searched recursively with ``os.walk``.\n",
    "        patterns: Iterable of regexes matched case-insensitively (``re.search``) against\n",
    "            file names. Each pattern doubles as the key its DataFrame is stored under.\n",
    "\n",
    "    Returns:\n",
    "        OrderedDict seeded with every key of ``dm_identifier`` mapped to ``None``; a\n",
    "        pattern that is not such a key adds a new entry. A matching file's first sheet\n",
    "        (minus an ``Unnamed: 0`` column, if present) is stored under the first pattern\n",
    "        that matches its name. If several files match the same pattern, the last one\n",
    "        visited wins; traversal is sorted so that choice is reproducible.\n",
    "    \"\"\"\n",
    "    dm_df = OrderedDict((dm, None) for dm in dm_identifier.keys())\n",
    "    for root, dirs, files in os.walk(root_path):\n",
    "        dirs.sort()  # in-place sort makes os.walk descend in a deterministic order\n",
    "        for filename in sorted(files):\n",
    "            filepath = os.path.join(root, filename)\n",
    "            if not (os.path.isfile(filepath) and filepath.endswith(\".xlsx\")):\n",
    "                continue\n",
    "            for pattern in patterns:\n",
    "                if re.search(pattern, filename, re.IGNORECASE):\n",
    "                    # 'Unnamed: 0' is presumably a saved index column; tolerate its absence.\n",
    "                    df = pd.read_excel(filepath, sheet_name=0)\n",
    "                    dm_df[pattern] = df.drop(columns=['Unnamed: 0'], errors='ignore')\n",
    "                    break  # Stop checking patterns for this file\n",
    "    return dm_df\n",
    "\n",
    "def read_and_get_filtered_dataframes(root_path, patterns=(r\"cifar10\",)):\n",
    "    \"\"\"Load each dataset's results and keep the best run per (train-time, post-hoc) pair.\n",
    "\n",
    "    Missing ``calib_conf`` values become the string \"None\" (matching the 'None' key of\n",
    "    the table's post-hoc method list), and for every (``training_conf``, ``calib_conf``)\n",
    "    pair only the row with the highest ``Coverage-Mean`` is kept.\n",
    "\n",
    "    Args:\n",
    "        root_path: Directory searched by ``helper_find_files_and_read_dataframe``.\n",
    "        patterns: Dataset regexes forwarded to ``helper_find_files_and_read_dataframe``.\n",
    "\n",
    "    Returns:\n",
    "        OrderedDict mapping dataset key -> filtered DataFrame, or ``None`` if no file was found.\n",
    "    \"\"\"\n",
    "    dm_df = helper_find_files_and_read_dataframe(root_path, patterns=patterns)\n",
    "    for dm, df in dm_df.items():\n",
    "        if df is None:\n",
    "            continue\n",
    "        # assign() builds a new frame, so the DataFrame read from disk is left untouched.\n",
    "        df_clean = df.assign(calib_conf=df['calib_conf'].fillna(\"None\").astype(str))\n",
    "        # Highest coverage first; calib_conf ascending is the secondary sort key.\n",
    "        df_sorted = df_clean.sort_values([\"Coverage-Mean\", \"calib_conf\"], ascending=[False, True])\n",
    "        # Keep the best row per (training_conf, calib_conf) pair. De-duplicating on calib_conf\n",
    "        # alone kept one training method per post-hoc method, so the table lookup (which filters\n",
    "        # on both columns) found nothing for the other training methods and printed -1.\n",
    "        dm_df[dm] = df_sorted.drop_duplicates(subset=['training_conf', 'calib_conf'], keep='first')\n",
    "    return dm_df\n",
    "\n",
    "# Load and filter the results of every dataset listed in dm_identifier.\n",
    "dm_df = read_and_get_filtered_dataframes(\"../outputs/final_results\", patterns=dm_identifier.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display names of the post-hoc (calibration) methods; one table row each, in this order.\n",
    "cms_ = OrderedDict({'None': '-',\n",
    "                    'auto_label_opt_v0': 'Ours',\n",
    "                    'temp_scaling': 'TS',\n",
    "                    'dirichlet': 'Dirichlet',\n",
    "                    'scaling_binning': 'SB',\n",
    "                    'histogram_binning_top_label': 'Top-HB'})\n",
    "# Display names of the train-time methods; each heads a group of len(cms_) rows.\n",
    "ttms_ = OrderedDict({'std_cross_entropy': 'Vanilla',\n",
    "                     'crl': 'CRL',\n",
    "                     'fmfp': 'FMFP',\n",
    "                     'squentropy': 'Squentropy'})\n",
    "bs = \"\\\\\"  # a single backslash, for building LaTeX commands inside f-strings\n",
    "num_dp = 2  # decimal places for every reported number\n",
    "global_font_size = (8, 11)  # (size, baselineskip) for \\fontsize in the table template\n",
    "std_font_size = (6, 11)  # (size, baselineskip) for the std part of each cell\n",
    "STAT_COLUMNS = ('Auto-Labeling-Err-Mean', 'Auto-Labeling-Err-Std', 'Coverage-Mean', 'Coverage-Std')\n",
    "MISSING = -1  # value printed when a result is not available\n",
    "open_std_font = \"{\" + f\"{bs}fontsize{{{std_font_size[0]}}}{{{std_font_size[1]}}}{bs}selectfont\"\n",
    "closing_std_font = \"}\"\n",
    "\n",
    "def _lookup_stats(df, tm, cm):\n",
    "    \"\"\"Return the STAT_COLUMNS values of the first row matching (tm, cm), else MISSING for each.\"\"\"\n",
    "    if df is None:\n",
    "        return (MISSING,) * len(STAT_COLUMNS)\n",
    "    rows = df[(df[\"calib_conf\"] == cm) & (df[\"training_conf\"] == tm)]\n",
    "    if rows.empty:\n",
    "        return (MISSING,) * len(STAT_COLUMNS)\n",
    "    return tuple(rows[col].values[0] for col in STAT_COLUMNS)\n",
    "\n",
    "def _fmt_mean_std(mean, std):\n",
    "    \"\"\"Format one 'mean +/- std' cell, with std set in the smaller std_font_size font.\"\"\"\n",
    "    return f\" {mean:.{num_dp}f} ${bs}pm$ {open_std_font}{std:.{num_dp}f} {closing_std_font}\"\n",
    "\n",
    "last_tm, last_cm = list(ttms_)[-1], list(cms_)[-1]\n",
    "body_lines = []\n",
    "for tm, tm_name in ttms_.items():\n",
    "    for i, (cm, cm_name) in enumerate(cms_.items()):\n",
    "        if i == 0:\n",
    "            # The first row of each group carries the train-time name, spanning the whole group.\n",
    "            row = rf\"\\multirow{{{len(cms_)}}}{{*}}{{{tm_name}}}                     & {cm_name}\"\n",
    "        else:\n",
    "            row = \"                                 &  \" + cm_name\n",
    "        # Follow data_models (the header's column order), not dm_df, so body and header align.\n",
    "        for dm in data_models:\n",
    "            al_mean, al_std, c_mean, c_std = _lookup_stats(dm_df.get(dm), tm, cm)\n",
    "            row += \" & \" + _fmt_mean_std(al_mean, al_std) + \" & \" + _fmt_mean_std(c_mean, c_std)\n",
    "        # Separate train-time groups with \\hline and close the table with \\bottomrule.\n",
    "        if cm == last_cm:\n",
    "            rule = r\"\\bottomrule\" if tm == last_tm else r\"\\hline\"\n",
    "        else:\n",
    "            rule = \"\"\n",
    "        body_lines.append(row + r\"\\\\\" + rule + \"\\n\")\n",
    "body_txt = \"\".join(body_lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Header cells: one \\multicolumn per dataset, then the sub-metric names beneath each.\n",
    "n_cols = 2 + len(data_models) * len(sub_metrics)  # train-time + post-hoc + one per metric\n",
    "col_spec = \"l\" * n_cols\n",
    "metrics_txt = \" & \".join(\n",
    "    [rf\"\\multicolumn{{1}}{{c}}{{\\textbf{{{sm}}}}}\" for sm in sub_metrics] * len(data_models))\n",
    "data_models_txt = \" & \" + \" & \".join(\n",
    "    rf\"\\multicolumn{{{len(sub_metrics)}}}{{c}}{{\\textbf{{{dm_identifier[dm]}}}}}\" for dm in data_models)\n",
    "caption = \"Example TBAL LaTeX Table\"\n",
    "# \\fontsize{size}{baselineskip} takes its values from global_font_size.\n",
    "template = rf\"\"\"\n",
    "\\begin{{table*}}[t]\n",
    "\\fontsize{{{global_font_size[0]}}}{{{global_font_size[1]}}}\\selectfont\n",
    "\\begin{{tabular}}{{{col_spec}}}\n",
    "\\toprule\n",
    "\\multicolumn{{1}}{{c}}{{\\multirow{{2}}{{*}}{{\\textbf{{Train-time}}}}}} & \\multicolumn{{1}}{{c}}{{\\multirow{{2}}{{*}}{{\\textbf{{Post-hoc}}}}}} {data_models_txt} \\\\ \\cline{{3-{n_cols}}}\n",
    "\\multicolumn{{1}}{{c}}{{}}                      & \\multicolumn{{1}}{{c}}{{}}  & {metrics_txt} \\\\ \\toprule \n",
    "\"\"\" + body_txt + rf\"\"\"\n",
    "\\end{{tabular}}\n",
    "\\caption{{{caption}}}\n",
    "\\end{{table*}}\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the LaTeX table to the current working directory; explicit UTF-8 avoids\n",
    "# depending on the platform's default (locale) encoding.\n",
    "with open(\"./final_table_latex_template.txt\", \"w\", encoding=\"utf-8\") as file:\n",
    "    file.write(template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml-kernel-tbal",
   "language": "python",
   "name": "ml-kernel-tbal"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
