{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from copy import deepcopy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 171,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_model = 'lr'\n",
    "data = 'correction'\n",
    "\n",
    "df_results = pd.read_pickle(f'../results/comparison_{base_model}_{data}.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 172,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>alg</th>\n",
       "      <th>seed</th>\n",
       "      <th>alpha</th>\n",
       "      <th>lambda</th>\n",
       "      <th>Cost</th>\n",
       "      <th>M1 Validity</th>\n",
       "      <th>WC Validity</th>\n",
       "      <th>M1 Expectation</th>\n",
       "      <th>WC Expectation</th>\n",
       "      <th>J</th>\n",
       "      <th>Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>OPT</td>\n",
       "      <td>0</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>4.626025</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.875000</td>\n",
       "      <td>0.716056</td>\n",
       "      <td>0.528740</td>\n",
       "      <td>1.564183</td>\n",
       "      <td>0.638978</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>OPT</td>\n",
       "      <td>1</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>5.568962</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.833333</td>\n",
       "      <td>0.709723</td>\n",
       "      <td>0.546689</td>\n",
       "      <td>1.722250</td>\n",
       "      <td>0.608458</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>OPT</td>\n",
       "      <td>2</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>5.714758</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.454545</td>\n",
       "      <td>0.713369</td>\n",
       "      <td>0.482079</td>\n",
       "      <td>1.903395</td>\n",
       "      <td>0.760443</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>OPT</td>\n",
       "      <td>3</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>4.525096</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.533333</td>\n",
       "      <td>0.705098</td>\n",
       "      <td>0.507814</td>\n",
       "      <td>1.589541</td>\n",
       "      <td>0.684522</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>OPT</td>\n",
       "      <td>4</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>5.685610</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.857143</td>\n",
       "      <td>0.736234</td>\n",
       "      <td>0.543768</td>\n",
       "      <td>1.755108</td>\n",
       "      <td>0.617986</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>ROAR</td>\n",
       "      <td>0</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>2.594445</td>\n",
       "      <td>0.687500</td>\n",
       "      <td>0.062500</td>\n",
       "      <td>0.583903</td>\n",
       "      <td>0.268573</td>\n",
       "      <td>2.332159</td>\n",
       "      <td>1.813270</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>ROAR</td>\n",
       "      <td>1</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>2.557872</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.538501</td>\n",
       "      <td>0.228037</td>\n",
       "      <td>2.825758</td>\n",
       "      <td>2.314184</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>ROAR</td>\n",
       "      <td>2</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>3.928376</td>\n",
       "      <td>0.909091</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.617516</td>\n",
       "      <td>0.333619</td>\n",
       "      <td>2.287244</td>\n",
       "      <td>1.501569</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>ROAR</td>\n",
       "      <td>3</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>3.666644</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.066667</td>\n",
       "      <td>0.652781</td>\n",
       "      <td>0.392463</td>\n",
       "      <td>1.700974</td>\n",
       "      <td>0.967645</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>ROAR</td>\n",
       "      <td>4</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>4.827631</td>\n",
       "      <td>0.928571</td>\n",
       "      <td>0.071429</td>\n",
       "      <td>0.659122</td>\n",
       "      <td>0.367373</td>\n",
       "      <td>2.054141</td>\n",
       "      <td>1.088615</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    alg  seed  alpha  lambda      Cost  M1 Validity  WC Validity  \\\n",
       "0   OPT     0    0.5     0.2  4.626025     1.000000     0.875000   \n",
       "2   OPT     1    0.5     0.2  5.568962     1.000000     0.833333   \n",
       "4   OPT     2    0.5     0.2  5.714758     1.000000     0.454545   \n",
       "6   OPT     3    0.5     0.2  4.525096     1.000000     0.533333   \n",
       "8   OPT     4    0.5     0.2  5.685610     1.000000     0.857143   \n",
       "1  ROAR     0    0.5     0.2  2.594445     0.687500     0.062500   \n",
       "3  ROAR     1    0.5     0.2  2.557872     0.666667     0.000000   \n",
       "5  ROAR     2    0.5     0.2  3.928376     0.909091     0.000000   \n",
       "7  ROAR     3    0.5     0.2  3.666644     1.000000     0.066667   \n",
       "9  ROAR     4    0.5     0.2  4.827631     0.928571     0.071429   \n",
       "\n",
       "   M1 Expectation  WC Expectation         J      Loss  \n",
       "0        0.716056        0.528740  1.564183  0.638978  \n",
       "2        0.709723        0.546689  1.722250  0.608458  \n",
       "4        0.713369        0.482079  1.903395  0.760443  \n",
       "6        0.705098        0.507814  1.589541  0.684522  \n",
       "8        0.736234        0.543768  1.755108  0.617986  \n",
       "1        0.583903        0.268573  2.332159  1.813270  \n",
       "3        0.538501        0.228037  2.825758  2.314184  \n",
       "5        0.617516        0.333619  2.287244  1.501569  \n",
       "7        0.652781        0.392463  1.700974  0.967645  \n",
       "9        0.659122        0.367373  2.054141  1.088615  "
      ]
     },
     "execution_count": 172,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_results.sort_values('alg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "correction, lr\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>seed</th>\n",
       "      <th>alpha</th>\n",
       "      <th>lambda</th>\n",
       "      <th>Cost</th>\n",
       "      <th>M1 Validity</th>\n",
       "      <th>WC Validity</th>\n",
       "      <th>M1 Expectation</th>\n",
       "      <th>WC Expectation</th>\n",
       "      <th>J</th>\n",
       "      <th>Loss</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>alg</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>OPT</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>5.22±0.6</td>\n",
       "      <td>1.0±0.0</td>\n",
       "      <td>0.71±0.2</td>\n",
       "      <td>0.72±0.01</td>\n",
       "      <td>0.52±0.03</td>\n",
       "      <td>1.71±0.14</td>\n",
       "      <td>0.66±0.06</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ROAR</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>3.51±0.96</td>\n",
       "      <td>0.84±0.15</td>\n",
       "      <td>0.04±0.04</td>\n",
       "      <td>0.61±0.05</td>\n",
       "      <td>0.32±0.07</td>\n",
       "      <td>2.24±0.41</td>\n",
       "      <td>1.54±0.55</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      seed  alpha  lambda       Cost M1 Validity WC Validity M1 Expectation  \\\n",
       "alg                                                                           \n",
       "OPT    2.0    0.5     0.2   5.22±0.6     1.0±0.0    0.71±0.2      0.72±0.01   \n",
       "ROAR   2.0    0.5     0.2  3.51±0.96   0.84±0.15   0.04±0.04      0.61±0.05   \n",
       "\n",
       "     WC Expectation          J       Loss  \n",
       "alg                                        \n",
       "OPT       0.52±0.03  1.71±0.14  0.66±0.06  \n",
       "ROAR      0.32±0.07  2.24±0.41  1.54±0.55  "
      ]
     },
     "execution_count": 152,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_results_avg = df_results.groupby('alg').mean()\n",
    "df_results_avg[['Cost', 'M1 Validity', 'WC Validity', 'M1 Expectation', 'WC Expectation',  'J', 'Loss']] = df_results_avg[['Cost', 'M1 Validity', 'WC Validity', 'M1 Expectation', 'WC Expectation', 'J', 'Loss']].round(2).astype(str) + '±' + df_results.groupby('alg').std()[['Cost', 'M1 Validity', 'WC Validity', 'M1 Expectation', 'WC Expectation', 'J', 'Loss']].round(2).astype(str)\n",
    "# df_results_avg[['Cost', 'M1 Validity', 'WC Validity',  'J', 'Loss']] = df_results_avg[['Cost', 'M1 Validity', 'WC Validity', 'J', 'Loss']].round(2).astype(str) + '±' + df_results.groupby('alg').std()[['Cost', 'M1 Validity', 'WC Validity', 'J', 'Loss']].round(2).astype(str)\n",
    "\n",
    "print(f'{data}, {base_model}')\n",
    "df_results_avg"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
