{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup: work from the parent of the notebook's directory so that the `bilevel`\n",
    "# package and the relative ./data, ./plots and ./tables paths used below resolve.\n",
    "import os\n",
    "\n",
    "# Remember the directory the kernel started in and always change to its parent.\n",
    "# A bare os.chdir('..') climbs one more level every time this cell is re-run,\n",
    "# silently breaking every relative path that follows.\n",
    "_NOTEBOOK_DIR = globals().get('_NOTEBOOK_DIR', os.getcwd())\n",
    "os.chdir(os.path.dirname(os.path.abspath(_NOTEBOOK_DIR)))\n",
    "\n",
    "# NOTE(review): star imports hide where helpers such as numeric_scaler / one_hot\n",
    "# come from; replace with explicit imports once their module is confirmed.\n",
    "from bilevel.build_all_models import *\n",
    "from bilevel.utils import *\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 79.1 ms, sys: 18.9 ms, total: 98 ms\n",
      "Wall time: 95.9 ms\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>young</th>\n",
       "      <th>middle</th>\n",
       "      <th>old</th>\n",
       "      <th>HighSchool&amp;less</th>\n",
       "      <th>College&amp;more</th>\n",
       "      <th>Male</th>\n",
       "      <th>Female</th>\n",
       "      <th>White</th>\n",
       "      <th>Asian-Pac-Islander</th>\n",
       "      <th>Amer-Indian-Eskimo</th>\n",
       "      <th>Other</th>\n",
       "      <th>Black</th>\n",
       "      <th>always_on</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49526</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49527</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49528</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49529</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49530</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>49531 rows × 13 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       young  middle  old  HighSchool&less  College&more  Male  Female  White  \\\n",
       "0          0       1    0                0             1     0       1      1   \n",
       "1          1       0    0                0             1     1       0      1   \n",
       "2          1       0    0                1             0     1       0      1   \n",
       "3          0       0    1                1             0     1       0      0   \n",
       "4          1       0    0                0             1     1       0      1   \n",
       "...      ...     ...  ...              ...           ...   ...     ...    ...   \n",
       "49526      1       0    0                0             1     1       0      1   \n",
       "49527      0       1    0                0             1     1       0      0   \n",
       "49528      1       0    0                0             1     1       0      1   \n",
       "49529      1       0    0                0             1     0       1      1   \n",
       "49530      0       1    0                1             0     0       1      0   \n",
       "\n",
       "       Asian-Pac-Islander  Amer-Indian-Eskimo  Other  Black  always_on  \n",
       "0                       0                   0      0      0          1  \n",
       "1                       0                   0      0      0          1  \n",
       "2                       0                   0      0      0          1  \n",
       "3                       1                   0      0      0          1  \n",
       "4                       0                   0      0      0          1  \n",
       "...                   ...                 ...    ...    ...        ...  \n",
       "49526                   0                   0      0      0          1  \n",
       "49527                   1                   0      0      0          1  \n",
       "49528                   0                   0      0      0          1  \n",
       "49529                   0                   0      0      0          1  \n",
       "49530                   1                   0      0      0          1  \n",
       "\n",
       "[49531 rows x 13 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "df_adult = pd.read_csv('./data/adult_reconstruction.csv')\n",
    "\n",
    "# Group-membership indicators for the group-wise analysis, built from the raw\n",
    "# (unscaled) columns. Dict insertion order fixes the column order of the result.\n",
    "age_col, edu_col = df_adult['age'], df_adult['education-num']\n",
    "group_masks = {\n",
    "    'young': age_col <= 35,\n",
    "    'middle': (age_col > 35) & (age_col <= 50),\n",
    "    'old': age_col > 50,\n",
    "    'HighSchool&less': edu_col <= 9,\n",
    "    'College&more': edu_col >= 10,\n",
    "    'Male': df_adult['sex'] == 'Male',\n",
    "    'Female': df_adult['sex'] == 'Female',\n",
    "}\n",
    "group_masks.update({race: df_adult['race'] == race\n",
    "                    for race in ['White', 'Asian-Pac-Islander', 'Amer-Indian-Eskimo', 'Other', 'Black']})\n",
    "\n",
    "# 'always_on' is a catch-all group containing every row; '* 1' turns the booleans into 0/1 ints.\n",
    "A_t_adult_groups = pd.DataFrame(group_masks).assign(always_on=True) * 1\n",
    "A_t_adult_groups"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>age</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>education-num</th>\n",
       "      <th>income</th>\n",
       "      <th>workclass_?</th>\n",
       "      <th>workclass_Federal-gov</th>\n",
       "      <th>workclass_Local-gov</th>\n",
       "      <th>workclass_Never-worked</th>\n",
       "      <th>...</th>\n",
       "      <th>race_Black</th>\n",
       "      <th>race_Other</th>\n",
       "      <th>race_White</th>\n",
       "      <th>sex_Female</th>\n",
       "      <th>sex_Male</th>\n",
       "      <th>young</th>\n",
       "      <th>middle</th>\n",
       "      <th>old</th>\n",
       "      <th>HighSchool&amp;less</th>\n",
       "      <th>College&amp;more</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.193878</td>\n",
       "      <td>0.315068</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.800000</td>\n",
       "      <td>0.490460</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.397959</td>\n",
       "      <td>0.054795</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.600000</td>\n",
       "      <td>0.114053</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.091837</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.400000</td>\n",
       "      <td>0.024957</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.465753</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.533333</td>\n",
       "      <td>0.389320</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.377551</td>\n",
       "      <td>0.150685</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.800000</td>\n",
       "      <td>0.413376</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 97 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   hours-per-week       age  capital-gain  capital-loss  education-num  \\\n",
       "0        0.193878  0.315068           0.0           0.0       0.800000   \n",
       "1        0.397959  0.054795           0.0           0.0       0.600000   \n",
       "2        0.091837  0.000000           0.0           0.0       0.400000   \n",
       "3        0.500000  0.465753           0.0           0.0       0.533333   \n",
       "4        0.377551  0.150685           0.0           0.0       0.800000   \n",
       "\n",
       "     income  workclass_?  workclass_Federal-gov  workclass_Local-gov  \\\n",
       "0  0.490460          0.0                    0.0                  0.0   \n",
       "1  0.114053          0.0                    0.0                  0.0   \n",
       "2  0.024957          0.0                    0.0                  0.0   \n",
       "3  0.389320          0.0                    0.0                  0.0   \n",
       "4  0.413376          0.0                    0.0                  0.0   \n",
       "\n",
       "   workclass_Never-worked  ...  race_Black  race_Other  race_White  \\\n",
       "0                     0.0  ...         0.0         0.0         1.0   \n",
       "1                     0.0  ...         0.0         0.0         1.0   \n",
       "2                     0.0  ...         0.0         0.0         1.0   \n",
       "3                     0.0  ...         0.0         0.0         0.0   \n",
       "4                     0.0  ...         0.0         0.0         1.0   \n",
       "\n",
       "   sex_Female  sex_Male  young  middle  old  HighSchool&less  College&more  \n",
       "0         1.0       0.0      0       1    0                0             1  \n",
       "1         0.0       1.0      1       0    0                0             1  \n",
       "2         0.0       1.0      1       0    0                1             0  \n",
       "3         0.0       1.0      0       0    1                1             0  \n",
       "4         0.0       1.0      1       0    0                0             1  \n",
       "\n",
       "[5 rows x 97 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Model inputs: scale the numeric columns (including 'income'), one-hot encode the\n",
    "# categoricals, and append the coarse age/education group indicators from A_t_adult_groups.\n",
    "added_indicators = ['young', 'middle', 'old', 'HighSchool&less', 'College&more'] # adding group indicators to dataframe, so baseline knows the groups too\n",
    "numeric_all = ['hours-per-week', 'age', 'capital-gain', 'capital-loss', 'education-num', 'income']\n",
    "cat_feat =  ['workclass', 'marital-status', 'relationship', 'native-country', 'occupation', 'race', 'sex']\n",
    "\n",
    "# Drop without inplace: mutating df_adult made this cell raise KeyError('education')\n",
    "# on a second run (the column was already gone) and altered the frame loaded above.\n",
    "df_adult_feat = df_adult.drop(columns=['education'])\n",
    "df_adult_mm = numeric_scaler(df_adult_feat, numeric_all)\n",
    "# one_hot keeps the original categorical columns next to the dummies; drop them.\n",
    "df_adult_mm_oh = one_hot(df_adult_mm, cat_feat).drop(columns=cat_feat) * 1.0\n",
    "# Column assignment aligns on the index; guard against silent NaNs if the\n",
    "# preprocessing helpers ever reindex the frame.\n",
    "df_adult_mm_oh[added_indicators] = A_t_adult_groups[added_indicators]\n",
    "assert df_adult_mm_oh[added_indicators].notna().all().all(), 'group indicators misaligned with features'\n",
    "df_adult_mm_oh.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 49531/49531 [00:05<00:00, 8884.29it/s]\n",
      "100%|██████████| 49531/49531 [00:32<00:00, 1531.53it/s]\n",
      "100%|██████████| 49531/49531 [00:05<00:00, 8863.26it/s]\n",
      "100%|██████████| 49531/49531 [00:33<00:00, 1473.24it/s]\n",
      "100%|██████████| 49531/49531 [00:05<00:00, 8803.61it/s]\n",
      "100%|██████████| 49531/49531 [00:37<00:00, 1333.41it/s]\n",
      "100%|██████████| 49531/49531 [00:06<00:00, 7100.08it/s]\n",
      "100%|██████████| 49531/49531 [00:34<00:00, 1418.92it/s]\n",
      "100%|██████████| 49531/49531 [00:06<00:00, 7413.39it/s]\n",
      "100%|██████████| 49531/49531 [00:36<00:00, 1355.88it/s]\n",
      "100%|██████████| 49531/49531 [00:06<00:00, 7877.58it/s]\n",
      "100%|██████████| 49531/49531 [00:35<00:00, 1412.75it/s]\n",
      "100%|██████████| 49531/49531 [00:07<00:00, 6605.50it/s]\n",
      "100%|██████████| 49531/49531 [01:17<00:00, 639.32it/s] \n",
      "100%|██████████| 49531/49531 [00:13<00:00, 3581.74it/s]\n",
      "100%|██████████| 49531/49531 [00:49<00:00, 994.02it/s] \n",
      "100%|██████████| 49531/49531 [00:05<00:00, 8378.00it/s]\n",
      "100%|██████████| 49531/49531 [00:34<00:00, 1443.23it/s]\n",
      "100%|██████████| 49531/49531 [00:06<00:00, 7812.04it/s]\n",
      "100%|██████████| 49531/49531 [00:35<00:00, 1399.87it/s]\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Long-running (several minutes, per the progress output). The method names\n",
    "# suggest one run per seed, followed by result and regret-curve aggregation.\n",
    "from bilevel.Groupwise_seedruns import BuildGroupwise_diffseeds\n",
    "\n",
    "TARGET_COL = 'income'  # presumably the column of df_adult_mm_oh being predicted\n",
    "ds_obj = BuildGroupwise_diffseeds(df_adult_mm_oh, TARGET_COL, A_t_adult_groups)\n",
    "\n",
    "ds_obj.build_all_seeds()\n",
    "ds_obj.build_df_res()\n",
    "ds_obj.build_regret_curve()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bilevel.Groupwise_seedruns import plot_regret_curve_with_std\n",
    "\n",
    "# Path argument for the regret-curve plot (presumably where the figure is saved).\n",
    "PLOT_PATH = './plots/adult'\n",
    "plot_regret_curve_with_std(ds_obj, PLOT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary table built from ds_obj; later cells read its 'mean_hindsight',\n",
    "# 'mean_regend_Base' and 'mean_regend_Anh' columns.\n",
    "from bilevel.Groupwise_seedruns import get_end_regret_gw_df\n",
    "df_regend_adult = get_end_regret_gw_df(ds_obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the summary table computed above.\n",
    "df_regend_adult"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough headline numbers: mean hindsight cumulative loss across rows, and the mean\n",
    "# end-of-run regret gap between the baseline and our algorithm (Base - Anh).\n",
    "mean_hindsight_loss = df_regend_adult['mean_hindsight'].mean(axis=0)\n",
    "regret_gap = df_regend_adult['mean_regend_Base'] - df_regend_adult['mean_regend_Anh']\n",
    "mean_hindsight_loss, regret_gap.mean(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# to_csv does not create missing parent directories; make sure ./tables exists\n",
    "# so the export does not fail on a fresh checkout.\n",
    "os.makedirs('./tables', exist_ok=True)\n",
    "df_regend_adult.to_csv('./tables/adultincome.csv')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "greg",
   "language": "python",
   "name": "greg"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "422a1ee675848ad7ee73ac736eae01a8698556098f797c947729d7d9d67832dc"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
