{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "37aa48d8-0dd4-47d0-b680-97ac9b313f3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "import json\n",
    "import re\n",
    "\n",
    "from helper import add_additional_info, filter_df_for_best_runs, get_abs_rel_performance, load_ds_info\n",
    "from constants import ds_info_file,BASE_PATH_PROJECT"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0cd47692-af6c-4026-9626-108276ea9b19",
   "metadata": {},
   "source": [
    "### Global variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f5e5e5fd-4d8e-4adc-837c-1d7fc0cd5aca",
   "metadata": {},
   "outputs": [],
   "source": [
    "project_paths = [\n",
    "    BASE_PATH_PROJECT / \"results_iclr_exp\",\n",
    "    BASE_PATH_PROJECT / \"results_iclr_exp_wd0.1\",\n",
    "]\n",
    "\n",
    "FILTER_FOR_BEST_MODEL = True\n",
    "COMPUTE_RELATIVE_PERFORMANCES = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "99af0df1-cb39-4748-8a6f-12a749cc2e96",
   "metadata": {},
   "outputs": [],
   "source": [
    "storing_path = BASE_PATH_PROJECT / \"results_iclr_exp/aggregated\"\n",
    "\n",
    "OVERWRITE = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae43f7c6-967a-452d-8fe5-ba0f00dc41bb",
   "metadata": {},
   "source": [
    "### Gather all results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ee971517-0a31-494f-a8a4-80a9e6764c06",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = []\n",
    "for project_path in project_paths:\n",
    "    for res_path in project_path.rglob('seed_0/results.json'):\n",
    "        \n",
    "        df = pd.read_json(res_path)\n",
    "        df = add_additional_info(df)\n",
    "        \n",
    "        model_id_n_hopt_slug = \"/\".join(res_path.parts[10:-1])\n",
    "        df['model_id_n_hopt_slug'] = model_id_n_hopt_slug\n",
    "        df['res_folder'] = project_path.name\n",
    "\n",
    "        res.append(df)\n",
    "all_results = pd.concat(res).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "282bfc36-d56e-4140-916e-07d08d544532",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3015, 72)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_results.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6866e9a4-8d85-48d5-bb2a-79d49c5f1963",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_results = all_results[~all_results['attention_dropout'].isin(['[0.1, 0.0]','[0.3, 0.0]'])].reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5354141d-00e0-4654-9772-1d15d801b5f6",
   "metadata": {},
   "source": [
    "### Post process the results\n",
    "- Select for each combination the best run based on the validation accuracy\n",
    "- Compute the relative performance gain compared to one layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fefcd0b8-ea54-4d80-8be5-261256333680",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(2289, 72) (688, 72)\n",
      "Filtering all but baseline runs ...\n",
      "df.shape before filtering for best runs (2289, 72)\n",
      "df.shape after filtering for best runs (1223, 72)\n",
      "Filtering baseline runs ...\n",
      "df.shape before filtering for best runs (513, 72)\n",
      "df.shape after filtering for best runs (171, 72)\n"
     ]
    }
   ],
   "source": [
    "if FILTER_FOR_BEST_MODEL:\n",
    "    all_but_baseline = all_results[all_results['experiment']!='All tokens last layer'].copy().reset_index(drop=True)\n",
    "    baseline = all_results[all_results['experiment']=='All tokens last layer'].copy().reset_index(drop=True)\n",
    "    print(all_but_baseline.shape, baseline.shape)\n",
    "\n",
    "    print(\"Filtering all but baseline runs ...\")\n",
    "    all_but_baseline = filter_df_for_best_runs(\n",
    "        df = all_but_baseline, \n",
    "        metric_col = 'best_val_bal_acc1', \n",
    "        group_cols = ['task', 'experiment', \"dataset\", \"model_ids\"]\n",
    "    )\n",
    "\n",
    "    ## Filter only for correct runs \n",
    "    if len(baseline)>0:\n",
    "        print(\"Filtering baseline runs ...\")\n",
    "        baseline = baseline[baseline['res_folder'] == 'results_iclr_exp_wd0.1'].copy().reset_index(drop=True)\n",
    "        \n",
    "        baseline = filter_df_for_best_runs(\n",
    "            df = baseline, \n",
    "            metric_col = 'best_val_bal_acc1', \n",
    "            group_cols = ['task', 'experiment', \"dataset\", \"model_ids\"]\n",
    "        )\n",
    "        \n",
    "        all_results = pd.concat([all_but_baseline, baseline]).reset_index(drop=True)\n",
    "    else:\n",
    "        all_results = all_but_baseline.reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bd3788c3-5ec8-411d-8366-de2ee67419bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "all_results.shape=(1394, 72) before computing relative performances\n",
      "all_results.shape=(1394, 82) after computing relative performances\n"
     ]
    }
   ],
   "source": [
    "if COMPUTE_RELATIVE_PERFORMANCES:\n",
    "    print(f\"{all_results.shape=} before computing relative performances\")\n",
    "    all_results = all_results.groupby(['dataset', 'base_model']).apply(get_abs_rel_performance, include_groups=False).reset_index()\n",
    "    print(f\"{all_results.shape=} after computing relative performances\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7fe0e1ef-a283-4037-a7e6-4bf333808eed",
   "metadata": {},
   "source": [
    "### Store the aggregated data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "25f9890b-3545-4a4d-a976-c7ef3ef708a5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stored all aggregated results ... \n"
     ]
    }
   ],
   "source": [
    "fn = storing_path / f'all_runs_v11.pkl'\n",
    "if fn.exists() and not OVERWRITE:\n",
    "    raise FileExistsError(f\"File {fn} already exists. No overwriting!!\")\n",
    "else:\n",
    "    all_results.to_pickle(fn)\n",
    "    print(f\"Stored all aggregated results ... \")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
