{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "21c858e6-a03f-479f-a94a-d94fd67d17fb",
   "metadata": {},
   "source": [
    "# Synthetic data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5bd5e5b8-8565-4bd6-a34c-8a21c70a19a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('..') \n",
    "from bilevel.synth_datagen import SynthGenLinear\n",
    "from bilevel.utils import *\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1fd65587-742d-47cb-abab-d82d19927a64",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[3 1 4 2 0] ['green' 'square' 'red' 'triangle' 'circle']\n"
     ]
    }
   ],
   "source": [
    "params = {'samples': 100000, 'dim':20, \n",
    "        'group_dict': {'SHAPE':['circle', 'square', 'triangle'], 'COLOR': ['green', 'red']},\n",
    "        'prob_dict': {'SHAPE': [0.5, 0.3, 0.2], 'COLOR': [0.6, 0.4]},\n",
    "        'feat_lo': 0.0, 'feat_hi': 1.0, 'w_lo': 0.0, 'w_hi': 1.0,\n",
    "        'add_linear_mapping': True, 'add_quad_mapping' : False,\n",
    "        'S_lo': 0.0, 'S_hi':0.0,\n",
    "        'label_noise_width':0.16, 'drop_sensitive':False, 'fixed_seed':21,\n",
    "        }\n",
    "syn_ob = SynthGenLinear(**params) # SEED set to 21, for reproducibility in generation\n",
    "print(syn_ob.dperm, np.array(syn_ob.all_groupnames)[syn_ob.dperm])\n",
    "df = syn_ob.df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fdf69086-4b14-4f75-9ba5-19b21a0eb7d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Index(['x_0', 'x_1', 'x_2', 'x_3', 'x_4', 'x_5', 'x_6', 'x_7', 'x_8', 'x_9',\n",
      "       'x_10', 'x_11', 'x_12', 'x_13', 'x_14', 'x_15', 'x_16', 'x_17', 'x_18',\n",
      "       'x_19', 'g_circle', 'g_square', 'g_triangle', 'g_green', 'g_red',\n",
      "       'y_circle', 'y_square', 'y_triangle', 'y_green', 'y_red',\n",
      "       'y_mean_active', 'y_min_active', 'y_max_active', 'y_dperm_active'],\n",
      "      dtype='object')\n",
      "['x_0', 'x_1', 'x_2', 'x_3', 'x_4', 'x_5', 'x_6', 'x_7', 'x_8', 'x_9', 'x_10', 'x_11', 'x_12', 'x_13', 'x_14', 'x_15', 'x_16', 'x_17', 'x_18', 'x_19'] ['y_circle', 'y_square', 'y_triangle', 'y_green', 'y_red', 'y_mean_active', 'y_min_active', 'y_max_active', 'y_dperm_active'] ['g_circle', 'g_square', 'g_triangle', 'g_green', 'g_red']\n"
     ]
    }
   ],
   "source": [
    "print(df.columns)\n",
    "filter_feature = [col for col in df if col.startswith('x')]\n",
    "filter_label = [col for col in df if col.startswith('y')]\n",
    "filter_group = [col for col in df if col.startswith('g')]\n",
    "print(filter_feature, filter_label, filter_group)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b2b6a3ae-c6a2-40bf-aea5-910a1416785c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([49857., 30044., 20099., 59985., 40015.]),\n",
       " ['circle', 'square', 'triangle', 'green', 'red'])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "syn_ob.A_t.sum(axis=0), syn_ob.all_groupnames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2ab0a4f0-0890-4bcd-960a-9c164fac8ec1",
   "metadata": {},
   "outputs": [],
   "source": [
    "A_t = pd.DataFrame(syn_ob.A_t, columns = syn_ob.all_groupnames)\n",
    "A_t['always_on'] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2878c39f-53db-472c-a2eb-caab22cd96b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_mean = df[filter_feature+filter_group + ['y_mean_active']]\n",
    "df_min = df[filter_feature+filter_group + ['y_min_active']]\n",
    "df_max = df[filter_feature+filter_group + ['y_max_active']]\n",
    "df_dperm = df[filter_feature+filter_group + ['y_dperm_active']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1b198af8-9fcc-4db8-8ab3-edaa7a648fc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from bilevel.Groupwise_seedruns import BuildGroupwise_diffseeds"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14a61257-f102-4811-99ac-1ec62f1bb910",
   "metadata": {},
   "source": [
    "## y_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "51036fa3-90ca-4cee-9654-5b52c902a3f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100000/100000 [00:02<00:00, 35402.93it/s]\n",
      "100%|██████████| 100000/100000 [00:15<00:00, 6626.04it/s]\n",
      "100%|██████████| 100000/100000 [00:02<00:00, 36555.33it/s]\n",
      "100%|██████████| 100000/100000 [00:15<00:00, 6629.86it/s]\n",
      "100%|██████████| 100000/100000 [00:03<00:00, 31758.85it/s]\n",
      "100%|██████████| 100000/100000 [00:14<00:00, 6868.54it/s]\n",
      "100%|██████████| 100000/100000 [00:02<00:00, 36556.46it/s]\n",
      "100%|██████████| 100000/100000 [00:14<00:00, 6802.24it/s]\n",
      "100%|██████████| 100000/100000 [00:02<00:00, 36581.09it/s]\n",
      "100%|██████████| 100000/100000 [00:14<00:00, 6762.17it/s]\n",
      "100%|██████████| 100000/100000 [00:02<00:00, 36925.79it/s]\n",
      "100%|██████████| 100000/100000 [00:15<00:00, 6656.12it/s]\n",
      "100%|██████████| 100000/100000 [00:03<00:00, 31276.56it/s]\n",
      "100%|██████████| 100000/100000 [00:15<00:00, 6631.63it/s]\n",
      "100%|██████████| 100000/100000 [00:02<00:00, 37337.39it/s]\n",
      "100%|██████████| 100000/100000 [00:14<00:00, 6929.97it/s]\n",
      "100%|██████████| 100000/100000 [00:02<00:00, 35187.55it/s]\n",
      "100%|██████████| 100000/100000 [00:14<00:00, 7018.76it/s]\n",
      "100%|██████████| 100000/100000 [00:02<00:00, 36162.88it/s]\n",
      "100%|██████████| 100000/100000 [00:14<00:00, 6786.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 4min 37s, sys: 13min 37s, total: 18min 14s\n",
      "Wall time: 4min 40s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "ds_ymean = BuildGroupwise_diffseeds(df_mean, 'y_mean_active', A_t) # different seeds object\n",
    "ds_ymean.build_all_seeds()\n",
    "ds_ymean.build_df_res()\n",
    "ds_ymean.build_regret_curve()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7b84f86-957b-499c-8503-665576c5fd64",
   "metadata": {},
   "source": [
    "## y_min"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43af2c1d-2c34-45d8-8738-cf4b4d0a846d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100000/100000 [00:02<00:00, 34890.44it/s]\n",
      "100%|██████████| 100000/100000 [00:14<00:00, 6696.81it/s]\n",
      "100%|██████████| 100000/100000 [00:02<00:00, 33424.26it/s]\n",
      "100%|██████████| 100000/100000 [00:14<00:00, 6742.08it/s]\n",
      "100%|██████████| 100000/100000 [00:02<00:00, 36595.75it/s]\n",
      "100%|██████████| 100000/100000 [00:14<00:00, 6765.56it/s]\n",
      "100%|██████████| 100000/100000 [00:02<00:00, 36098.58it/s]\n",
      "100%|██████████| 100000/100000 [00:15<00:00, 6555.79it/s]\n",
      "100%|██████████| 100000/100000 [00:02<00:00, 36280.44it/s]\n",
      "100%|██████████| 100000/100000 [00:15<00:00, 6588.85it/s]\n",
      "100%|██████████| 100000/100000 [00:02<00:00, 35289.30it/s]\n",
      "100%|██████████| 100000/100000 [00:15<00:00, 6574.78it/s]\n",
      "100%|██████████| 100000/100000 [00:03<00:00, 33127.92it/s]\n",
      "100%|██████████| 100000/100000 [00:24<00:00, 4150.67it/s]\n",
      "100%|██████████| 100000/100000 [00:04<00:00, 21827.42it/s]\n",
      "100%|██████████| 100000/100000 [00:25<00:00, 3872.76it/s]\n",
      "100%|██████████| 100000/100000 [00:06<00:00, 16508.18it/s]\n",
      "100%|██████████| 100000/100000 [00:25<00:00, 3947.76it/s]\n",
      "100%|██████████| 100000/100000 [00:05<00:00, 18955.41it/s]\n",
      "100%|██████████| 100000/100000 [00:23<00:00, 4346.00it/s]\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "ds_ymin = BuildGroupwise_diffseeds(df_min, 'y_min_active', A_t) # different seeds object\n",
    "ds_ymin.build_all_seeds()\n",
    "ds_ymin.build_df_res()\n",
    "ds_ymin.build_regret_curve()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f1787bd-6749-4750-972a-697fbe08f29c",
   "metadata": {},
   "source": [
    "## y_max"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa150df8-dcdd-4271-97cc-0aa9622e75a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_ymax = BuildGroupwise_diffseeds(df_max, 'y_max_active', A_t) # different seeds object\n",
    "ds_ymax.build_all_seeds()\n",
    "ds_ymax.build_df_res()\n",
    "ds_ymax.build_regret_curve()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6dd3d66-d8db-405e-9d8b-06d52b265172",
   "metadata": {},
   "source": [
    "## y_dperm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf1a8802-d435-4f84-af58-b27874b026fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_ydperm = BuildGroupwise_diffseeds(df_dperm, 'y_dperm_active', A_t) # different seeds object\n",
    "ds_ydperm.build_all_seeds()\n",
    "ds_ydperm.build_df_res()\n",
    "ds_ydperm.build_regret_curve()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f90c1e75-0b10-4353-9435-a05355ccc42b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_reg_sidebyside(gwise_obj: BuildGroupwise_diffseeds, dir_name:str):\n",
    "    for g_ind, gname in enumerate(gwise_obj.group_names):\n",
    "        gwise_obj.regret_Anh_groupwise_array[g_ind] = np.array(gwise_obj.regret_Anh_groupwise_array[g_ind]) # all 10 values in the row have same dim, so can make np array\n",
    "        gwise_obj.regret_Base_groupwise_array[g_ind] = np.array(gwise_obj.regret_Base_groupwise_array[g_ind])\n",
    "        print(gname, gwise_obj.group_sizes[g_ind])\n",
    "        mean_reg_Anh, sd_reg_Anh = gwise_obj.regret_Anh_groupwise_array[g_ind].mean(axis = 0), gwise_obj.regret_Anh_groupwise_array[g_ind].std(axis = 0)\n",
    "        mean_reg_Base, sd_reg_Base = gwise_obj.regret_Base_groupwise_array[g_ind].mean(axis = 0), gwise_obj.regret_Base_groupwise_array[g_ind].std(axis = 0)\n",
    "        plt.figure(figsize=[12.8, 4.8]) # 2x default figure size\n",
    "        plt.subplot(121)\n",
    "        plt.plot(gwise_obj.pos[g_ind], mean_reg_Base, color = 'C0', label = 'Baseline')\n",
    "        plt.fill_between(gwise_obj.pos[g_ind], mean_reg_Base - sd_reg_Base, mean_reg_Base + sd_reg_Base, alpha = 0.5, color = 'C0')\n",
    "        plt.legend()\n",
    "        plt.xlabel('time')\n",
    "        plt.ylabel('Regret')\n",
    "        plt.title(gname)\n",
    "        \n",
    "        plt.subplot(122)\n",
    "        plt.plot(gwise_obj.pos[g_ind], mean_reg_Anh, color = 'C1', label = 'Our algorithm')\n",
    "        plt.fill_between(gwise_obj.pos[g_ind], mean_reg_Anh - sd_reg_Anh, mean_reg_Anh + sd_reg_Anh, alpha = 0.5, color = 'C1')\n",
    "        plt.legend()\n",
    "        plt.xlabel('time')\n",
    "        plt.ylabel('Regret')\n",
    "        plt.title(gname)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05896643-1be6-4aab-bdac-9c613f37fa32",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_reg_sidebyside(ds_ymean, \"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "022542bf-6200-4f73-8ff7-9971aeb784b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from bilevel.Groupwise_seedruns import get_end_regret_gw_df\n",
    "df_regend_ymin = get_end_regret_gw_df(ds_ymin)\n",
    "df_regend_ymean = get_end_regret_gw_df(ds_ymean)\n",
    "df_regend_ymax = get_end_regret_gw_df(ds_ymax)\n",
    "df_regend_ydperm = get_end_regret_gw_df(ds_ydperm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48dd2a88-4913-4a2b-b52b-1a3c049c043d",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_regend_ymin.to_csv('./tables/synth_ymin.csv')\n",
    "df_regend_ymean.to_csv('./tables/synth_ymean.csv')\n",
    "df_regend_ymax.to_csv('./tables/synth_ymax.csv')\n",
    "df_regend_ydperm.to_csv('./tables/synth_ydperm.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a0d5683-6895-4781-92b6-3cc8861321a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_regend_ymean['mean_hindsight'].mean(axis=0), (df_regend_ymean['mean_regend_Base'] - df_regend_ymean['mean_regend_Anh']).mean(axis=0) # rough values mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "532bc34e-ba89-4645-8e83-54f5ace48d21",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_regend_ymin['mean_hindsight'].mean(axis=0), (df_regend_ymin['mean_regend_Base'] - df_regend_ymin['mean_regend_Anh']).mean(axis=0) # rough values  min "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e91140b-16ee-4e21-b249-cb7cbaf2e99c",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_regend_ymax['mean_hindsight'].mean(axis=0), (df_regend_ymax['mean_regend_Base'] - df_regend_ymax['mean_regend_Anh']).mean(axis=0) # rough values  max"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbdf0c3c-8855-4dc9-b111-58ab1c194724",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_regend_ydperm['mean_hindsight'].mean(axis=0), (df_regend_ydperm['mean_regend_Base'] - df_regend_ydperm['mean_regend_Anh']).mean(axis=0) # rough values  dperm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "120c1884-54e1-434a-ade1-82c58a2097f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from bilevel.Groupwise_seedruns import plot_regret_curve_with_std\n",
    "plot_regret_curve_with_std(ds_ymean, './plots/synth_mean')\n",
    "\n",
    "plot_regret_curve_with_std(ds_ymin, './plots/synth_min')\n",
    "\n",
    "plot_regret_curve_with_std(ds_ymax, './plots/synth_max')\n",
    "\n",
    "plot_regret_curve_with_std(ds_ydperm, './plots/synth_perm')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "greg",
   "language": "python",
   "name": "greg"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "422a1ee675848ad7ee73ac736eae01a8698556098f797c947729d7d9d67832dc"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
