{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "\n",
    "import os\n",
    "import sys\n",
    "from functools import reduce\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# Put the parent directory on sys.path so the `floral` package can be imported below.\n",
    "sys.path.insert(0, os.pardir)\n",
    "from floral.utils.plotting import PLOTS_DIR  # noqa: E402\n",
    "\n",
    "# LaTeX tables exported by this notebook are written here (created if it does not exist).\n",
    "TABLES_DIR = os.path.join(os.pardir, \"tables\")\n",
    "os.makedirs(TABLES_DIR, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================ CHOOSE EXPERIMENT BUNDLE ================ #\n",
    "# Must be one of the names listed in AVAILABLE_EXPERIMENTS (asserted when the metrics are loaded).\n",
    "EXPERIMENT: str = \"run_methods\"\n",
    "# ================ CHOOSE EXPERIMENT BUNDLE ================ #"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "AVAILABLE_EXPERIMENTS = [\n",
    "    \"run_methods\",\n",
    "    \"run_methods_reduced\",\n",
    "    \"run_methods_synthetic\",\n",
    "    \"run_methods_general\",\n",
    "    \"ab_floral\",\n",
    "    \"ab_floral_cifar100\",\n",
    "    \"ab_normlora\",  # NOTE: no datasets, variables or metric configured yet\n",
    "    \"hp_floral\",\n",
    "    \"hp_floral_cifar100\",\n",
    "    \"hp_convlora\",\n",
    "    \"hp_batchnormlora_synthetic\",\n",
    "    \"hp_batchnormlora_cifar100\",\n",
    "]\n",
    "\n",
    "# ==================== DECLARE EXPERIMENT DATASETS, EXPERIMENT VARIABLES, AND DATASET METRICS ==================== #\n",
    "# ----- Methods Performances ------ #\n",
    "# Suffixes stripped from an experiment name to get the base of its per-dataset result folders \"<base>_<dataset>\",\n",
    "# e.g. \"run_methods_reduced\" + \"mnist_rotate_reduced\" -> \"run_methods_mnist_rotate_reduced\".\n",
    "EXPERIMENT_SUFFIXES = [\"_reduced\", \"_synthetic\", \"_general\", \"_cifar100\", \"_others\"]\n",
    "\n",
    "# Datasets whose metrics are collected for each experiment bundle (entries marked XXX are disabled).\n",
    "EXPERIMENT_DATASETS = {\n",
    "    \"run_methods\": [\n",
    "        \"mnist_rotate\",\n",
    "        \"mnist_label_shift\",\n",
    "        \"cifar10_rotate\",\n",
    "        \"cifar10_label_shift\",\n",
    "        \"cifar100\",\n",
    "    ],\n",
    "    \"run_methods_reduced\": [\n",
    "        \"mnist_rotate_reduced\",\n",
    "        \"mnist_label_shift_reduced\",\n",
    "        \"cifar10_rotate_reduced\",\n",
    "        \"cifar10_label_shift_reduced\",\n",
    "        \"cifar100_reduced\",\n",
    "    ],\n",
    "    \"run_methods_synthetic\": [\n",
    "        \"synthetic_linear\",\n",
    "        \"synthetic_mlp\",\n",
    "    ],\n",
    "    \"run_methods_general\": [\n",
    "        # \"emnist\",  # XXX\n",
    "        # \"shakespeare\",  # XXX\n",
    "        # \"stackoverflow\",  # XXX\n",
    "    ],\n",
    "    \"ab_floral\": [\n",
    "        \"mnist_rotate\",\n",
    "        \"mnist_label_shift\",\n",
    "        \"cifar10_rotate\",\n",
    "        \"cifar10_label_shift\",\n",
    "    ],\n",
    "    \"ab_floral_cifar100\": [\n",
    "        \"cifar100\",\n",
    "    ],\n",
    "    \"ab_normlora\": [\n",
    "        # \"cifar100\",  # XXX\n",
    "        # \"emnist\",  # XXX\n",
    "        # \"stackoverflow\",  # XXX\n",
    "    ],\n",
    "    \"hp_floral\": [\n",
    "        \"mnist_rotate\",\n",
    "        \"mnist_label_shift\",\n",
    "        \"cifar10_rotate\",\n",
    "        \"cifar10_label_shift\",\n",
    "    ],\n",
    "    \"hp_floral_cifar100\": [\n",
    "        \"cifar100\",\n",
    "    ],\n",
    "    \"hp_convlora\": [\n",
    "        \"cifar10_rotate\",\n",
    "        \"cifar10_label_shift\",\n",
    "        \"cifar100\",\n",
    "        # \"emnist\",  # XXX\n",
    "    ],\n",
    "    \"hp_batchnormlora_synthetic\": [\n",
    "        \"synthetic_mlp_bn\",\n",
    "    ],\n",
    "    \"hp_batchnormlora_cifar100\": [\n",
    "        \"cifar100_bn\",\n",
    "    ]\n",
    "}\n",
    "\n",
    "# metrics.csv columns identifying one configuration; they lead the table and key the merge across datasets.\n",
    "EXPERIMENT_VARIABLES = {\n",
    "    \"run_methods\": [\"method\", \"optimal_router\"],\n",
    "    \"run_methods_reduced\": [\"method\", \"optimal_router\"],\n",
    "    \"run_methods_synthetic\": [\"method\", \"optimal_router\"],\n",
    "    \"run_methods_general\": [\"method\"],\n",
    "    \"ab_floral\": [\"active_loras\", \"bias\"],\n",
    "    \"ab_floral_cifar100\": [\"active_loras\", \"bias\"],\n",
    "    \"hp_convlora\": [\"convlora_method\"],\n",
    "    \"hp_floral\": [\"num_clusters\", \"rank\"],\n",
    "    \"hp_floral_cifar100\": [\"num_clusters\", \"rank\"],\n",
    "    \"hp_batchnormlora_synthetic\": [\"batchnorm_adaptor\", \"batchnorm_stats\"],\n",
    "    \"hp_batchnormlora_cifar100\": [\"batchnorm_adaptor\", \"batchnorm_stats\"],\n",
    "}\n",
    "\n",
    "# metrics.csv column reported for every dataset of an experiment.\n",
    "EXPERIMENT_METRIC = {\n",
    "    \"run_methods\": \"acc_distributed\",\n",
    "    \"run_methods_reduced\": \"acc_distributed\",\n",
    "    \"run_methods_synthetic\": \"loss_distributed\",\n",
    "    \"run_methods_general\": None,  # lookup dataset metrics\n",
    "    \"ab_floral\": \"acc_distributed\",\n",
    "    \"ab_floral_cifar100\": \"acc_distributed\",\n",
    "    \"hp_convlora\": \"acc_distributed\",\n",
    "    \"hp_floral\": \"acc_distributed\",\n",
    "    \"hp_floral_cifar100\": \"acc_distributed\",\n",
    "    \"hp_batchnormlora_synthetic\": \"loss_distributed\",\n",
    "    \"hp_batchnormlora_cifar100\": \"acc_distributed\",\n",
    "}\n",
    "\n",
    "# Per-dataset metrics, for experiments whose EXPERIMENT_METRIC is None.\n",
    "DATASET_METRICS = {\n",
    "    \"synthetic_linear\": [\"loss_distributed\"],\n",
    "    \"synthetic_mlp\": [\"loss_distributed\"],\n",
    "    \"mnist_rotate\": [\"acc_distributed\"],\n",
    "    \"mnist_label_shift\": [\"acc_distributed\"],\n",
    "    \"cifar10_rotate\": [\"acc_distributed\"],\n",
    "    \"cifar10_label_shift\": [\"acc_distributed\"],\n",
    "    \"cifar100\": [\"acc_distributed\"],\n",
    "    \"mnist_rotate_reduced\": [\"acc_distributed\"],\n",
    "    \"mnist_label_shift_reduced\": [\"acc_distributed\"],\n",
    "    \"cifar10_rotate_reduced\": [\"acc_distributed\"],\n",
    "    \"cifar10_label_shift_reduced\": [\"acc_distributed\"],\n",
    "    \"cifar100_reduced\": [\"acc_distributed\"],\n",
    "    \"synthetic_mlp_bn\": [\"loss_distributed\"],\n",
    "    \"cifar100_bn\": [\"acc_distributed\"],\n",
    "    \"emnist\": [\"acc_distributed\"],\n",
    "    \"shakespeare\": [\"accuracy_top1_distributed\", \"accuracy_top5_distributed\"],\n",
    "    \"stackoverflow\": [\"accuracy_top1_distributed\", \"accuracy_top3_distributed\",\n",
    "                      \"accuracy_top5_distributed\", \"accuracy_top10_distributed\"],\n",
    "}\n",
    "\n",
    "\n",
    "# ==================== DEFINE REPORT NAMES ==================== #\n",
    "EXPERIMENT_TO_REPORT_NAME = {\n",
    "    \"run_methods\": \"Methods comparison\",\n",
    "    \"run_methods_reduced\": \"Methods comparison\",\n",
    "    \"run_methods_synthetic\": \"Methods comparison\",\n",
    "    \"run_methods_general\": \"Methods comparison\",\n",
    "    \"ab_floral\": \"FLoRAL adaptors ablation\",\n",
    "    \"ab_floral_cifar100\": \"FLoRAL adaptors ablation\",\n",
    "    \"hp_convlora\": \"ConvLoRA types comparison\",\n",
    "    \"hp_floral\": \"Number of Adaptors and their effective rank\",\n",
    "    \"hp_floral_cifar100\": \"Number of Adaptors and their effective rank\",\n",
    "    \"hp_batchnormlora\": \"Batch-Norm adaptors\",\n",
    "    \"hp_batchnormlora_synthetic\": \"Batch-Norm adaptors\",\n",
    "    \"hp_batchnormlora_cifar100\": \"Batch-Norm adaptors\",\n",
    "}\n",
    "\n",
    "DATASET_TO_REPORT_NAME = {\n",
    "    \"synthetic_linear\": \"Synthetic Linear\",\n",
    "    \"synthetic_mlp\": \"Synthetic MLP\",\n",
    "    \"mnist_rotate\": \"MNIST-Rotate\",\n",
    "    \"mnist_label_shift\": \"MNIST-Label-Shift\",\n",
    "    \"cifar10_rotate\": \"CIFAR-10-Rotate\",\n",
    "    \"cifar10_label_shift\": \"CIFAR-10-Label-Shift\",\n",
    "    \"cifar100\": \"CIFAR-100\",\n",
    "    \"mnist_rotate_reduced\": r\"MNIST-Rotate(5\\%)\",\n",
    "    \"mnist_label_shift_reduced\": r\"MNIST-Label-Shift(5\\%)\",\n",
    "    \"cifar10_rotate_reduced\": r\"CIFAR-10-Rotate(5\\%)\",\n",
    "    \"cifar10_label_shift_reduced\": r\"CIFAR-10-Label-Shift(5\\%)\",\n",
    "    \"cifar100_reduced\": r\"CIFAR-100(5\\%)\",\n",
    "    \"synthetic_mlp_bn\": \"Synthetic MLP\",\n",
    "    \"cifar100_bn\": \"CIFAR-100\",\n",
    "    \"emnist\": \"FEMNIST\",\n",
    "    \"shakespeare\": \"Shakespeare\",\n",
    "    \"stackoverflow\": \"Stack Overflow\",\n",
    "}\n",
    "\n",
    "VARIABLE_TO_REPORT_NAME = {\n",
    "    \"method\": \"Method\",\n",
    "    \"optimal_router\": \"Optimal Router\",\n",
    "    \"convlora_method\": \"ConvLoRA Type\",\n",
    "    \"active_loras\": \"Active LoRAs\",\n",
    "    \"bias\": \"Adaptive Bias\",\n",
    "    \"num_clusters\": r\"$C$\",\n",
    "    \"rank\": r\"$\\rho$\",\n",
    "    \"batchnorm_adaptor\": \"Adaptor\",\n",
    "    \"batchnorm_stats\": \"Stats\",\n",
    "}\n",
    "\n",
    "VARIABLE_VALUES_TO_REPORT_NAME = {\n",
    "    \"fedavg\": \"FedAvg\",\n",
    "    \"floral\": r\"FLoRAL($\\rho=1\\%$)\",\n",
    "    \"floral_optimalrouter\": r\"FLoRAL($\\rho=1\\%$)\",\n",
    "    \"floral_10\": r\"FLoRAL($\\rho=10\\%$)\",\n",
    "    \"floral_10_optimalrouter\": r\"FLoRAL($\\rho=10\\%$)\",\n",
    "    \"locallora\": \"Local Adaptor\",\n",
    "    \"ensemble\": \"Ensemble\",\n",
    "    \"ensemble_optimalrouter\": \"Ensemble\",\n",
    "\n",
    "    \"linear\": \"LoRA\",\n",
    "    \"conv\": \"ConvLoRA\",\n",
    "    \"linear+conv\": \"LoRA + ConvLoRA\",\n",
    "\n",
    "    \"balanced\": \"Balanced\",\n",
    "    \"balanced_2d\": \"Balanced 2D\",\n",
    "    \"in\": \"In Layer\",\n",
    "    \"out\": \"Out Layer\",\n",
    "\n",
    "    \"regular\": \"Regular\",\n",
    "    \"reparameterized\": \"Reparameterized\",\n",
    "    \"local\": \"Local\",\n",
    "    \"federated\": \"Federated\",\n",
    "\n",
    "    # NOTE: True/False compare and hash equal to 1/0, so numeric values (e.g. num_clusters=1)\n",
    "    # would also match these keys in a plain `.get` lookup.\n",
    "    True: \"Yes\",\n",
    "    False: \"No\",\n",
    "    \"none\": \"None\",\n",
    "}\n",
    "\n",
    "METRIC_TO_REPORT_NAME = {\n",
    "    \"loss_distributed\": \"Loss\",\n",
    "    \"loss_in_vocab_distributed\": \"Loss in Vocab\",\n",
    "    \"acc_distributed\": \"Accuracy\",\n",
    "    \"accuracy_distributed\": \"Accuracy\",\n",
    "    \"accuracy_top1_distributed\": \"Accuracy\",\n",
    "    \"accuracy_top3_distributed\": \"Accuracy (Top-3)\",\n",
    "    \"accuracy_top5_distributed\": \"Accuracy (Top-5)\",\n",
    "    \"accuracy_top10_distributed\": \"Accuracy (Top-10)\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Method</th>\n",
       "      <th>Optimal Router</th>\n",
       "      <th>MNIST-Rotate</th>\n",
       "      <th>MNIST-Label-Shift</th>\n",
       "      <th>CIFAR-10-Rotate</th>\n",
       "      <th>CIFAR-10-Label-Shift</th>\n",
       "      <th>CIFAR-100</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Ensemble</td>\n",
       "      <td>No</td>\n",
       "      <td>91.891892</td>\n",
       "      <td>93.593593</td>\n",
       "      <td>73.290000</td>\n",
       "      <td>56.710000</td>\n",
       "      <td>35.174</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Ensemble</td>\n",
       "      <td>Yes</td>\n",
       "      <td>94.994993</td>\n",
       "      <td>95.295295</td>\n",
       "      <td>73.830000</td>\n",
       "      <td>72.720000</td>\n",
       "      <td>77.062</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>FLoRAL($\\rho=10\\%$)</td>\n",
       "      <td>No</td>\n",
       "      <td>93.093092</td>\n",
       "      <td>94.494494</td>\n",
       "      <td>72.470001</td>\n",
       "      <td>57.490000</td>\n",
       "      <td>57.350</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>FLoRAL($\\rho=10\\%$)</td>\n",
       "      <td>Yes</td>\n",
       "      <td>95.195194</td>\n",
       "      <td>94.794794</td>\n",
       "      <td>72.689999</td>\n",
       "      <td>75.730000</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>FLoRAL($\\rho=1\\%$)</td>\n",
       "      <td>No</td>\n",
       "      <td>91.691691</td>\n",
       "      <td>93.193193</td>\n",
       "      <td>70.210000</td>\n",
       "      <td>74.150000</td>\n",
       "      <td>51.716</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>FLoRAL($\\rho=1\\%$)</td>\n",
       "      <td>Yes</td>\n",
       "      <td>92.892893</td>\n",
       "      <td>92.992993</td>\n",
       "      <td>71.040000</td>\n",
       "      <td>73.069999</td>\n",
       "      <td>50.280</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>FedAvg</td>\n",
       "      <td>No</td>\n",
       "      <td>90.990991</td>\n",
       "      <td>24.924925</td>\n",
       "      <td>64.400001</td>\n",
       "      <td>22.070000</td>\n",
       "      <td>12.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>Local Adaptor</td>\n",
       "      <td>No</td>\n",
       "      <td>86.486487</td>\n",
       "      <td>84.684685</td>\n",
       "      <td>66.040000</td>\n",
       "      <td>69.080000</td>\n",
       "      <td>52.600</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                Method Optimal Router  MNIST-Rotate  MNIST-Label-Shift  \\\n",
       "0             Ensemble             No     91.891892          93.593593   \n",
       "1             Ensemble            Yes     94.994993          95.295295   \n",
       "2  FLoRAL($\\rho=10\\%$)             No     93.093092          94.494494   \n",
       "3  FLoRAL($\\rho=10\\%$)            Yes     95.195194          94.794794   \n",
       "4   FLoRAL($\\rho=1\\%$)             No     91.691691          93.193193   \n",
       "5   FLoRAL($\\rho=1\\%$)            Yes     92.892893          92.992993   \n",
       "6               FedAvg             No     90.990991          24.924925   \n",
       "7        Local Adaptor             No     86.486487          84.684685   \n",
       "\n",
       "   CIFAR-10-Rotate  CIFAR-10-Label-Shift  CIFAR-100  \n",
       "0        73.290000             56.710000     35.174  \n",
       "1        73.830000             72.720000     77.062  \n",
       "2        72.470001             57.490000     57.350  \n",
       "3        72.689999             75.730000        NaN  \n",
       "4        70.210000             74.150000     51.716  \n",
       "5        71.040000             73.069999     50.280  \n",
       "6        64.400001             22.070000     12.000  \n",
       "7        66.040000             69.080000     52.600  "
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "assert EXPERIMENT in AVAILABLE_EXPERIMENTS, f\"Unknown experiment '{EXPERIMENT}'\"\n",
    "\n",
    "# metrics.csv columns identifying one configuration; also the merge keys across datasets.\n",
    "variables = EXPERIMENT_VARIABLES[EXPERIMENT]\n",
    "\n",
    "# Per-dataset results live in \"<base>_<dataset>\", where <base> is EXPERIMENT minus its bundle suffixes.\n",
    "experiment_base = EXPERIMENT\n",
    "for suffix in EXPERIMENT_SUFFIXES:\n",
    "    experiment_base = experiment_base.removesuffix(suffix)\n",
    "\n",
    "\n",
    "def _value_to_report_name(val):\n",
    "    \"\"\"Return the report name of a raw variable value, or the value itself if it is unmapped.\n",
    "\n",
    "    Only strings and booleans are looked up: ``1 == True`` and ``0 == False`` hash alike, so a plain\n",
    "    dict lookup would render numeric values such as ``num_clusters=1`` as \"Yes\".\n",
    "    \"\"\"\n",
    "    if pd.api.types.is_bool(val):\n",
    "        return VARIABLE_VALUES_TO_REPORT_NAME.get(bool(val), val)\n",
    "    if isinstance(val, str):\n",
    "        return VARIABLE_VALUES_TO_REPORT_NAME.get(val, val)\n",
    "    return val\n",
    "\n",
    "\n",
    "def _metric_columns(dataset):\n",
    "    \"\"\"Return the ``(metric, column_name)`` pairs to report for ``dataset``.\n",
    "\n",
    "    Uses the experiment-wide metric when one is declared; otherwise (None or no entry) falls back to\n",
    "    the dataset's DATASET_METRICS, appending the metric name to the column when there are several.\n",
    "    \"\"\"\n",
    "    dataset_name = DATASET_TO_REPORT_NAME[dataset]\n",
    "    experiment_metric = EXPERIMENT_METRIC.get(EXPERIMENT)\n",
    "    if experiment_metric is not None:\n",
    "        return [(experiment_metric, dataset_name)]\n",
    "    dataset_metrics = DATASET_METRICS[dataset]\n",
    "    if len(dataset_metrics) == 1:\n",
    "        return [(dataset_metrics[0], dataset_name)]\n",
    "    return [(m, f\"{dataset_name} ({METRIC_TO_REPORT_NAME.get(m, m)})\") for m in dataset_metrics]\n",
    "\n",
    "\n",
    "df_list = []\n",
    "for dataset in EXPERIMENT_DATASETS[EXPERIMENT]:\n",
    "    experiment_name = f\"{experiment_base}_{dataset}\"\n",
    "    metrics_file = os.path.join(\"..\", PLOTS_DIR, experiment_name, \"metrics.csv\")\n",
    "    if not os.path.exists(metrics_file):\n",
    "        print(f\"metrics csv for experiment '{experiment_name}' does not exist! File: {metrics_file}\")\n",
    "        continue\n",
    "    metrics_df = pd.read_csv(metrics_file)  # TODO: might be outdated, get from history.pkl directly\n",
    "    variables_df = metrics_df[variables].apply(lambda col: col.apply(_value_to_report_name))\n",
    "    for metric, column_name in _metric_columns(dataset):\n",
    "        variables_df[column_name] = metrics_df[metric]\n",
    "    df_list.append(variables_df)\n",
    "\n",
    "if not df_list:\n",
    "    raise ValueError(f\"No metrics loaded for experiment '{EXPERIMENT}'; nothing to tabulate.\")\n",
    "\n",
    "# NOTE: rows are merged as-is; if a metrics.csv holds several runs per configuration, the outer merge\n",
    "# pairs them up combinatorially (consider `.groupby(by=variables).mean()` before merging).\n",
    "experiment_metrics_df = (\n",
    "    reduce(lambda left, right: pd.merge(left, right, how=\"outer\", on=variables), df_list)\n",
    "    .sort_values(by=variables)\n",
    "    .rename(VARIABLE_TO_REPORT_NAME, axis=\"columns\")\n",
    ")\n",
    "experiment_metrics_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy tables get 2 decimals, anything else (e.g. losses) 4.\n",
    "experiment_metric = EXPERIMENT_METRIC.get(EXPERIMENT)\n",
    "if experiment_metric is not None:\n",
    "    reported_metrics = [experiment_metric]\n",
    "else:  # no experiment-wide metric declared: fall back to the per-dataset metrics\n",
    "    reported_metrics = [m for ds in EXPERIMENT_DATASETS[EXPERIMENT] for m in DATASET_METRICS[ds]]\n",
    "float_format = \"%.2f\" if all(\"acc\" in m for m in reported_metrics) else \"%.4f\"\n",
    "experiment_metrics_df.to_latex(os.path.join(TABLES_DIR, EXPERIMENT + \".tex\"), float_format=float_format)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "floral",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
