{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m2024-08-31 19:30:35.889\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtabicl.results.results_sweep\u001b[0m:\u001b[36mload\u001b[0m:\u001b[36m44\u001b[0m - \u001b[1mLoaded results from outputs/saved/tabscratch/foundation-default-tabzilla_has_completed_runs-zeroshot/results_sweep.nc\u001b[0m\n",
      "\u001b[32m2024-08-31 19:30:35.897\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtabicl.results.results_sweep\u001b[0m:\u001b[36mload\u001b[0m:\u001b[36m44\u001b[0m - \u001b[1mLoaded results from outputs/saved/tabscratch/foundation-default-tabzilla_has_completed_runs-finetune/results_sweep.nc\u001b[0m\n",
      "\u001b[32m2024-08-31 19:30:35.903\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtabicl.results.results_sweep\u001b[0m:\u001b[36mload\u001b[0m:\u001b[36m44\u001b[0m - \u001b[1mLoaded results from outputs/saved/tabpfn_original/tabpfn-default-tabzilla_has_completed_runs-zeroshot/results_sweep.nc\u001b[0m\n",
      "\u001b[32m2024-08-31 19:30:35.908\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtabicl.results.results_sweep\u001b[0m:\u001b[36mload\u001b[0m:\u001b[36m44\u001b[0m - \u001b[1mLoaded results from outputs/saved/tabpfn_original/tabpfn-default-tabzilla_has_completed_runs-finetune/results_sweep.nc\u001b[0m\n",
      "\u001b[32m2024-08-31 19:30:35.915\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtabicl.results.results_sweep\u001b[0m:\u001b[36mload\u001b[0m:\u001b[36m44\u001b[0m - \u001b[1mLoaded results from outputs/saved/tabpfn/test_tabzilla_has_completed_runs_zeroshot/results_sweep.nc\u001b[0m\n",
      "\u001b[32m2024-08-31 19:30:35.921\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtabicl.results.results_sweep\u001b[0m:\u001b[36mload\u001b[0m:\u001b[36m44\u001b[0m - \u001b[1mLoaded results from outputs/saved/tabpfn/test_tabzilla_has_completed_runs_finetune/results_sweep.nc\u001b[0m\n",
      "\u001b[32m2024-08-31 19:30:35.928\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtabicl.results.results_sweep\u001b[0m:\u001b[36mload\u001b[0m:\u001b[36m44\u001b[0m - \u001b[1mLoaded results from outputs/saved/tabforest/test_tabzilla_has_completed_runs_zeroshot/results_sweep.nc\u001b[0m\n",
      "\u001b[32m2024-08-31 19:30:35.934\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtabicl.results.results_sweep\u001b[0m:\u001b[36mload\u001b[0m:\u001b[36m44\u001b[0m - \u001b[1mLoaded results from outputs/saved/tabforest/test_tabzilla_has_completed_runs_finetune/results_sweep.nc\u001b[0m\n",
      "\u001b[32m2024-08-31 19:30:35.940\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtabicl.results.results_sweep\u001b[0m:\u001b[36mload\u001b[0m:\u001b[36m44\u001b[0m - \u001b[1mLoaded results from outputs/saved/tabforestpfn/test_tabzilla_has_completed_runs_zeroshot/results_sweep.nc\u001b[0m\n",
      "\u001b[32m2024-08-31 19:30:35.946\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mtabicl.results.results_sweep\u001b[0m:\u001b[36mload\u001b[0m:\u001b[36m44\u001b[0m - \u001b[1mLoaded results from outputs/saved/tabforestpfn/test_tabzilla_has_completed_runs_finetune/results_sweep.nc\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "import xarray as xr\n",
    "\n",
    "from tabicl.core.enums import BenchmarkOrigin\n",
    "from tabicl.results.dataset_manipulations import add_model_plot_names, average_out_the_cv_split, create_normalized_score, only_use_models_and_datasets_specified_in_cfg, select_only_the_first_default_run_of_every_model_and_dataset, take_run_with_best_validation_loss\n",
    "from tabicl.results.merge_ds_bench_tabzilla_and_sweeps import merge_ds_bench_tabzilla_and_sweeps\n",
    "from tabicl.results.normalize_scores import create_normalization_info\n",
    "from tabicl.results.ranking_table import make_ranking_table_, process_benchmark_results, process_sweep_results\n",
    "from tabicl.results.reformat_results_get import get_reformatted_results\n",
    "from tabicl.results.results_sweep import ResultsSweep\n",
    "from tabicl.config.config_benchmark_sweep import ConfigBenchmarkSweep\n",
    "\n",
    "# Every run directory used below lives under this root.\n",
    "RESULTS_DIR = Path(\"outputs/saved\")\n",
    "\n",
    "# Shared benchmark-sweep config, taken from the TabForestPFN fine-tune run.\n",
    "cfg_general = ConfigBenchmarkSweep.load(RESULTS_DIR / \"tabforestpfn/test_tabzilla_has_completed_runs_finetune/config_benchmark_sweep.yaml\")\n",
    "\n",
    "\n",
    "def prepare_ds(txt: str, path: Path):\n",
    "    \"\"\"Load ``results_sweep.nc`` from a run directory and tag it with a display name.\n",
    "\n",
    "    Args:\n",
    "        txt: stored in ``ds.attrs['model_plot_name']``; the ranking tables below\n",
    "            show it as the row label for this run.\n",
    "        path: run directory that contains ``results_sweep.nc``.\n",
    "\n",
    "    Returns:\n",
    "        The ``.ds`` attribute of the loaded ``ResultsSweep``, with the name attached.\n",
    "    \"\"\"\n",
    "    ds = ResultsSweep.load(path / \"results_sweep.nc\").ds\n",
    "    ds.attrs['model_plot_name'] = txt\n",
    "    return ds\n",
    "\n",
    "\n",
    "# Display name -> run directory (relative to RESULTS_DIR). Insertion order is the\n",
    "# load order, which is also the row order of the per-model table further down.\n",
    "RUNS = {\n",
    "    \"TabScratch - Zero-shot\": \"tabscratch/foundation-default-tabzilla_has_completed_runs-zeroshot\",\n",
    "    \"TabScratch - Fine-tuned\": \"tabscratch/foundation-default-tabzilla_has_completed_runs-finetune\",\n",
    "    \"TabPFN (original) - Zero-shot\": \"tabpfn_original/tabpfn-default-tabzilla_has_completed_runs-zeroshot\",\n",
    "    \"TabPFN (original) - Fine-tuned\": \"tabpfn_original/tabpfn-default-tabzilla_has_completed_runs-finetune\",\n",
    "    \"TabPFN (retrained) - Zero-shot\": \"tabpfn/test_tabzilla_has_completed_runs_zeroshot\",\n",
    "    \"TabPFN (retrained) - Fine-tuned\": \"tabpfn/test_tabzilla_has_completed_runs_finetune\",\n",
    "    \"TabForest - Zero-shot\": \"tabforest/test_tabzilla_has_completed_runs_zeroshot\",\n",
    "    \"TabForest - Fine-tuned\": \"tabforest/test_tabzilla_has_completed_runs_finetune\",\n",
    "    \"TabForestPFN - Zero-shot\": \"tabforestpfn/test_tabzilla_has_completed_runs_zeroshot\",\n",
    "    \"TabForestPFN - Fine-tuned\": \"tabforestpfn/test_tabzilla_has_completed_runs_finetune\",\n",
    "}\n",
    "\n",
    "dss = [prepare_ds(name, RESULTS_DIR / run_dir) for name, run_dir in RUNS.items()]\n",
    "\n",
    "# One ranking table with all pretrained models and the benchmark baselines together.\n",
    "ds = merge_ds_bench_tabzilla_and_sweeps(cfg_general, dss)\n",
    "df = make_ranking_table_(cfg_general, ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrrr}\n",
      "\\toprule\n",
      " & rank_min & rank_max & rank_mean & rank_median & acc_mean & acc_median \\\\\n",
      "\\midrule\n",
      "TabPFN (retrained) - Fine-tuned & 1 & 26 & 8.4 & 7.5 & 0.843 & 0.891 \\\\\n",
      "TabForestPFN - Fine-tuned & 1 & 27 & 8.4 & 7.0 & 0.840 & 0.902 \\\\\n",
      "CatBoost & 1 & 23 & 9.6 & 9.0 & 0.842 & 0.874 \\\\\n",
      "TabPFN (original) - Fine-tuned & 1 & 26 & 9.6 & 10.0 & 0.834 & 0.902 \\\\\n",
      "TabForestPFN - Zero-shot & 1 & 25 & 9.7 & 9.2 & 0.819 & 0.883 \\\\\n",
      "XGBoost & 1 & 23 & 9.8 & 9.8 & 0.836 & 0.899 \\\\\n",
      "TabForest - Fine-tuned & 1 & 27 & 10.9 & 10.0 & 0.806 & 0.873 \\\\\n",
      "TabPFN (retrained) - Zero-shot & 1 & 26 & 11.2 & 10.5 & 0.797 & 0.858 \\\\\n",
      "LightGBM & 1 & 27 & 11.8 & 12.2 & 0.787 & 0.867 \\\\\n",
      "TabPFN (original) - Zero-shot & 1 & 26 & 12.1 & 12.0 & 0.777 & 0.841 \\\\\n",
      "RandomForest & 1 & 26 & 12.1 & 12.0 & 0.792 & 0.851 \\\\\n",
      "Resnet & 1 & 27 & 12.8 & 12.0 & 0.727 & 0.837 \\\\\n",
      "NODE & 1 & 27 & 12.9 & 13.5 & 0.751 & 0.833 \\\\\n",
      "SAINT & 1 & 27 & 13.1 & 13.8 & 0.729 & 0.803 \\\\\n",
      "SVM & 1 & 26 & 13.3 & 14.0 & 0.710 & 0.801 \\\\\n",
      "FT-Transformer & 1 & 24 & 13.6 & 13.2 & 0.733 & 0.805 \\\\\n",
      "TabScratch - Fine-tuned & 1 & 25 & 13.6 & 13.0 & 0.752 & 0.819 \\\\\n",
      "DANet & 2 & 27 & 15.6 & 16.0 & 0.718 & 0.769 \\\\\n",
      "TabForest - Zero-shot & 3 & 26 & 15.7 & 16.0 & 0.709 & 0.822 \\\\\n",
      "MLP-rtdl & 1 & 27 & 16.9 & 19.0 & 0.619 & 0.736 \\\\\n",
      "STG & 1 & 27 & 17.1 & 19.0 & 0.592 & 0.664 \\\\\n",
      "LinearRegression & 1 & 27 & 18.5 & 21.0 & 0.564 & 0.593 \\\\\n",
      "MLP & 2 & 27 & 18.8 & 21.0 & 0.570 & 0.586 \\\\\n",
      "TabNet & 2 & 27 & 19.1 & 20.2 & 0.579 & 0.666 \\\\\n",
      "DecisionTree & 1 & 27 & 19.7 & 21.5 & 0.502 & 0.551 \\\\\n",
      "KNN & 2 & 27 & 20.5 & 23.0 & 0.473 & 0.478 \\\\\n",
      "VIME & 3 & 27 & 22.9 & 25.0 & 0.343 & 0.241 \\\\\n",
      "TabScratch - Zero-shot & 28 & 28 & 28.0 & 28.0 & 0.000 & 0.000 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Per-column formatters for the LaTeX export, given as one-argument callables\n",
    "# (the form documented for DataFrame.to_latex(formatters=...)).\n",
    "# rank_min / rank_max are not always whole numbers (values such as 18.5 occur,\n",
    "# presumably from tied ranks), so they use '{:g}' (1.0 -> '1', 18.5 -> '18.5')\n",
    "# instead of astype(int), which would silently truncate them.\n",
    "format_mapping = {\n",
    "    'rank_min': '{:g}'.format,\n",
    "    'rank_max': '{:g}'.format,\n",
    "    'rank_mean': '{:.1f}'.format,\n",
    "    'rank_median': '{:.1f}'.format,\n",
    "    'acc_mean': '{:.3f}'.format,\n",
    "    'acc_median': '{:.3f}'.format,\n",
    "}\n",
    "\n",
    "# `df` itself is left unmodified, so this cell can be re-run safely.\n",
    "print(df.to_latex(formatters=format_mapping))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                 rank_min  rank_max  rank_mean  rank_median  \\\n",
      "TabScratch - Zero-shot               19.0      19.0       19.0         19.0   \n",
      "TabScratch - Fine-tuned               1.0      18.0        8.3          7.0   \n",
      "TabPFN (original) - Zero-shot         1.0      19.0        7.7          7.0   \n",
      "TabPFN (original) - Fine-tuned        1.0      18.0        6.4          6.5   \n",
      "TabPFN (retrained) - Zero-shot        1.0      18.0        7.0          6.0   \n",
      "TabPFN (retrained) - Fine-tuned       1.0      18.0        5.6          5.0   \n",
      "TabForest - Zero-shot                 1.0      19.0        9.8         10.0   \n",
      "TabForest - Fine-tuned                1.0      19.0        7.0          6.0   \n",
      "TabForestPFN - Zero-shot              1.0      18.0        6.3          6.0   \n",
      "TabForestPFN - Fine-tuned             1.0      19.0        5.6          4.5   \n",
      "\n",
      "                                 acc_mean  acc_median  \n",
      "TabScratch - Zero-shot              0.000       0.000  \n",
      "TabScratch - Fine-tuned             0.757       0.819  \n",
      "TabPFN (original) - Zero-shot       0.780       0.841  \n",
      "TabPFN (original) - Fine-tuned      0.838       0.909  \n",
      "TabPFN (retrained) - Zero-shot      0.803       0.860  \n",
      "TabPFN (retrained) - Fine-tuned     0.847       0.891  \n",
      "TabForest - Zero-shot               0.714       0.822  \n",
      "TabForest - Fine-tuned              0.810       0.885  \n",
      "TabForestPFN - Zero-shot            0.824       0.900  \n",
      "TabForestPFN - Fine-tuned           0.846       0.910  \n",
      "\n",
      "\\begin{tabular}{lrrrrrr}\n",
      "\\toprule\n",
      " & rank_min & rank_max & rank_mean & rank_median & acc_mean & acc_median \\\\\n",
      "\\midrule\n",
      "TabScratch - Zero-shot & 19 & 19 & 19.0 & 19.0 & 0.000 & 0.000 \\\\\n",
      "TabScratch - Fine-tuned & 1 & 18 & 8.3 & 7.0 & 0.757 & 0.819 \\\\\n",
      "TabPFN (original) - Zero-shot & 1 & 19 & 7.7 & 7.0 & 0.780 & 0.841 \\\\\n",
      "TabPFN (original) - Fine-tuned & 1 & 18 & 6.4 & 6.5 & 0.838 & 0.909 \\\\\n",
      "TabPFN (retrained) - Zero-shot & 1 & 18 & 7.0 & 6.0 & 0.803 & 0.860 \\\\\n",
      "TabPFN (retrained) - Fine-tuned & 1 & 18 & 5.6 & 5.0 & 0.847 & 0.891 \\\\\n",
      "TabForest - Zero-shot & 1 & 19 & 9.8 & 10.0 & 0.714 & 0.822 \\\\\n",
      "TabForest - Fine-tuned & 1 & 19 & 7.0 & 6.0 & 0.810 & 0.885 \\\\\n",
      "TabForestPFN - Zero-shot & 1 & 18 & 6.3 & 6.0 & 0.824 & 0.900 \\\\\n",
      "TabForestPFN - Fine-tuned & 1 & 19 & 5.6 & 4.5 & 0.846 & 0.910 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Rank every pretrained model on its own against the benchmark baselines (one\n",
    "# merge per model), then keep only that model's row.\n",
    "# The loop uses names distinct from `ds` / `df` so the combined results built in\n",
    "# the first cell are not overwritten.\n",
    "individual_performance = []\n",
    "for ds_model in dss:\n",
    "    ds_model_merged = merge_ds_bench_tabzilla_and_sweeps(cfg_general, [ds_model])\n",
    "    df_model = make_ranking_table_(cfg_general, ds_model_merged)\n",
    "    # Pretrained-model rows are labelled '... - Zero-shot' / '... - Fine-tuned'; baseline rows are dropped.\n",
    "    individual_performance.append(df_model.filter(regex='Zero-shot|Fine-tune', axis=0))\n",
    "\n",
    "df_indiv = pd.concat(individual_performance)\n",
    "print(df_indiv)\n",
    "print()\n",
    "\n",
    "# rank_min / rank_max can be fractional; '{:g}' prints 1.0 as '1' but keeps 18.5,\n",
    "# whereas astype(int) would silently truncate it.\n",
    "latex_formatters = {**format_mapping, 'rank_min': '{:g}'.format, 'rank_max': '{:g}'.format}\n",
    "print(df_indiv.to_latex(formatters=latex_formatters))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                           rank_min  rank_max  rank_mean  rank_median  \\\n",
      "TabForestPFN - Fine-tuned       1.0      19.0        5.6         4.50   \n",
      "CatBoost                        1.0      15.0        6.2         5.00   \n",
      "XGBoost                         1.0      16.0        6.3         5.00   \n",
      "LightGBM                        1.0      19.0        7.7         6.00   \n",
      "RandomForest                    1.0      18.0        7.9         8.00   \n",
      "NODE                            1.0      19.0        8.4         8.00   \n",
      "Resnet                          1.0      19.0        8.4         8.00   \n",
      "SAINT                           1.0      19.0        8.5         8.00   \n",
      "SVM                             1.0      18.5        8.6         8.00   \n",
      "FT-Transformer                  1.0      16.0        8.9         8.50   \n",
      "DANet                           2.0      19.0       10.0        10.00   \n",
      "MLP-rtdl                        1.0      19.0       11.5        12.00   \n",
      "STG                             1.0      19.0       11.5        12.00   \n",
      "LinearRegression                1.0      19.0       12.3        13.75   \n",
      "MLP                             1.0      19.0       12.6        14.00   \n",
      "TabNet                          2.0      19.0       12.8        13.25   \n",
      "DecisionTree                    1.0      19.0       13.3        14.00   \n",
      "KNN                             2.5      19.0       13.9        15.00   \n",
      "VIME                            1.0      19.0       15.7        17.00   \n",
      "\n",
      "                           acc_mean  acc_median  \n",
      "TabForestPFN - Fine-tuned     0.846       0.910  \n",
      "CatBoost                      0.848       0.876  \n",
      "XGBoost                       0.841       0.901  \n",
      "LightGBM                      0.792       0.871  \n",
      "RandomForest                  0.797       0.852  \n",
      "NODE                          0.754       0.839  \n",
      "Resnet                        0.729       0.837  \n",
      "SAINT                         0.733       0.817  \n",
      "SVM                           0.713       0.801  \n",
      "FT-Transformer                0.737       0.805  \n",
      "DANet                         0.721       0.768  \n",
      "MLP-rtdl                      0.622       0.737  \n",
      "STG                           0.594       0.672  \n",
      "LinearRegression              0.567       0.592  \n",
      "MLP                           0.572       0.588  \n",
      "TabNet                        0.583       0.672  \n",
      "DecisionTree                  0.504       0.551  \n",
      "KNN                           0.475       0.484  \n",
      "VIME                          0.345       0.240  \n",
      "\n",
      "\\begin{tabular}{lrrrrrr}\n",
      "\\toprule\n",
      " & rank_min & rank_max & rank_mean & rank_median & acc_mean & acc_median \\\\\n",
      "\\midrule\n",
      "TabForestPFN - Fine-tuned & 1 & 19 & 5.6 & 4.5 & 0.846 & 0.910 \\\\\n",
      "CatBoost & 1 & 15 & 6.2 & 5.0 & 0.848 & 0.876 \\\\\n",
      "XGBoost & 1 & 16 & 6.3 & 5.0 & 0.841 & 0.901 \\\\\n",
      "LightGBM & 1 & 19 & 7.7 & 6.0 & 0.792 & 0.871 \\\\\n",
      "RandomForest & 1 & 18 & 7.9 & 8.0 & 0.797 & 0.852 \\\\\n",
      "NODE & 1 & 19 & 8.4 & 8.0 & 0.754 & 0.839 \\\\\n",
      "Resnet & 1 & 19 & 8.4 & 8.0 & 0.729 & 0.837 \\\\\n",
      "SAINT & 1 & 19 & 8.5 & 8.0 & 0.733 & 0.817 \\\\\n",
      "SVM & 1 & 18 & 8.6 & 8.0 & 0.713 & 0.801 \\\\\n",
      "FT-Transformer & 1 & 16 & 8.9 & 8.5 & 0.737 & 0.805 \\\\\n",
      "DANet & 2 & 19 & 10.0 & 10.0 & 0.721 & 0.768 \\\\\n",
      "MLP-rtdl & 1 & 19 & 11.5 & 12.0 & 0.622 & 0.737 \\\\\n",
      "STG & 1 & 19 & 11.5 & 12.0 & 0.594 & 0.672 \\\\\n",
      "LinearRegression & 1 & 19 & 12.3 & 13.8 & 0.567 & 0.592 \\\\\n",
      "MLP & 1 & 19 & 12.6 & 14.0 & 0.572 & 0.588 \\\\\n",
      "TabNet & 2 & 19 & 12.8 & 13.2 & 0.583 & 0.672 \\\\\n",
      "DecisionTree & 1 & 19 & 13.3 & 14.0 & 0.504 & 0.551 \\\\\n",
      "KNN & 2 & 19 & 13.9 & 15.0 & 0.475 & 0.484 \\\\\n",
      "VIME & 1 & 19 & 15.7 & 17.0 & 0.345 & 0.240 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Full ranking table of TabForestPFN (fine-tuned) against the benchmark baselines.\n",
    "# Computed explicitly so this cell does not depend on state left over from earlier cells.\n",
    "focus_model = 'TabForestPFN - Fine-tuned'\n",
    "ds_focus = next(d for d in dss if d.attrs.get('model_plot_name') == focus_model)\n",
    "df_focus = make_ranking_table_(cfg_general, merge_ds_bench_tabzilla_and_sweeps(cfg_general, [ds_focus]))\n",
    "print(df_focus)\n",
    "print()\n",
    "\n",
    "# rank_min / rank_max can be fractional (e.g. 18.5); '{:g}' keeps that instead of truncating.\n",
    "latex_formatters = {**format_mapping, 'rank_min': '{:g}'.format, 'rank_max': '{:g}'.format}\n",
    "print(df_focus.to_latex(formatters=latex_formatters))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tab",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
