{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%load_ext line_profiler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from baselines.dp_dg import dp_dg_experiment\n",
    "from argparse import Namespace\n",
    "from torch import device\n",
    "from convertors import get_train_test_index, write_dp_dg_format_train_test_data, DPDG_ADULT_COLUMN_ORDER\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import os\n",
    "import re\n",
    "from IPython.display import display\n",
    "from fairpate_tabular.utils import get_disparity\n",
    "from baselines.run_dp_dg import find_best_epoch_number, fill_in_preds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single-file copy of Adult, re-indexed from 0.\n",
    "adult_dpfermi = pd.read_csv(\"Datasets/Adult/adult_original_purified.csv\").reset_index(drop=True)\n",
    "\n",
    "# DP-DG backup copy: train and test CSVs stacked (in that order) and re-indexed from 0.\n",
    "dpdg_backup_parts = [\n",
    "    pd.read_csv(f\"baselines/dp_dg/data/backup/{part}.csv\") for part in (\"train\", \"test\")\n",
    "]\n",
    "adult_dpdg = pd.concat(dpdg_backup_parts).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column names present in one frame but absent from the other, in both directions.\n",
    "dpdg_cols, dpfermi_cols = adult_dpdg.columns, adult_dpfermi.columns\n",
    "print(dpdg_cols.difference(dpfermi_cols))\n",
    "print(dpfermi_cols.difference(dpdg_cols))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Workclass does not match"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two columns carry a different name in the DP-DG frame; all others are looked up unchanged.\n",
    "dpdg_name_for = {\">50K\": \"income\", \"workclass\": \"work-class\"}\n",
    "\n",
    "for col in adult_dpfermi.columns:\n",
    "    dpdg_col = dpdg_name_for.get(col, col)\n",
    "    print(f\"{col}:\")\n",
    "    print(\"dpdg:\", adult_dpdg[dpdg_col].describe())\n",
    "    print(\"dpfermi:\", adult_dpfermi[col].describe())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adult_dpdg[\"work-class\"].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adult_dpfermi[\"workclass\"].value_counts()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### DP-Fermi Dataset has more samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bring adult_dpdg onto the column layout of adult_dpfermi:\n",
    "#   * \"income\" becomes a \"Yes\"/\"No\" column named \">50K\",\n",
    "#   * \"work-class\" is carried over as \"workclass\",\n",
    "#   * columns are put in the same order as adult_dpfermi.\n",
    "adult_dpdg = (\n",
    "    adult_dpdg\n",
    "    .assign(**{\n",
    "        \">50K\": lambda df: df[\"income\"].apply(lambda x: \"Yes\" if x == \">50K\" else \"No\"),\n",
    "        \"workclass\": lambda df: df[\"work-class\"],\n",
    "    })\n",
    "    .drop(columns=[\"income\", \"work-class\"])\n",
    "    .loc[:, adult_dpfermi.columns]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn every row into a tuple key so whole rows can be compared set-wise.\n",
    "df1mi, df2mi = (pd.MultiIndex.from_frame(frame) for frame in (adult_dpfermi, adult_dpdg))\n",
    "\n",
    "# Rows of adult_dpdg that have no exact match in adult_dpfermi.\n",
    "rows_only_in_dpdg = df2mi.difference(df1mi)\n",
    "dfdiff = rows_only_in_dpdg.to_frame().reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_diff, n_dpdg, n_dpfermi = (len(frame) for frame in (dfdiff, adult_dpdg, adult_dpfermi))\n",
    "print(n_diff, n_dpdg, n_dpfermi)\n",
    "\n",
    "# NOTE(review): dfdiff is built from rows of adult_dpdg only, so its length looks bounded by\n",
    "# len(adult_dpdg); this equality would then hold only for an empty adult_dpfermi.\n",
    "# TODO confirm the intended invariant.\n",
    "assert n_diff == n_dpdg + n_dpfermi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_name = \"run_3\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fairpate_config = Namespace(dataset='adult', list_dataset=None, num_classes=2, output_col_name='>50K', split=0.75, dem_disparity_interpretation='max_vs_min', teacher_query_set_split=0.7, num_teachers=4, list_num_teachers=None, threshold=2, list_threshold=None, fairness_threshold=0.2, list_fairness_threshold=None, sigma_threshold=60, list_sigma_threshold=None, sigma_fair_threshold=0, sigma_gnmax=25, list_sigma_gnmax=None, budget=1000, list_budget=None, delta=1e-05, verbose=True, seed=0, list_seed=None, data_path='./fairpate_tabular/data/', min_group_count=50, results_dir='.', use_optuna=False, num_optuna_trials=1000, use_stratification=False, fairness_metric='DemParity', list_fairness_metric=None, num_calib=100, pate_based_model='fairpate', use_inference_time_postprocessing=False, undersampling_ratio=None, optuna_db_path='.', path='./Datasets/Adult/adult_original_purified.csv', num_inp_attr=102, cols_to_norm=['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week'], sensitive_attributes=['sex'], results_db_path='./fairpate_adult_DemParity_results.parquet', gt_fairness=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# this needs to be done at least once. The train/test split (the seed, etc.) is not important, but the pre-processing and ordering is.\n",
    "# write_dp_dg_format_train_test_data(\n",
    "#         seed=fairpate_config.seed,\n",
    "#         split=fairpate_config.split,\n",
    "#         undersampling_ratio=fairpate_config.undersampling_ratio,\n",
    "#\n",
    "#         path='./Datasets/Adult/adult_original_purified.csv',\n",
    "#         data_path='./fairpate_tabular/data/',\n",
    "#         save_path='./baselines/dp_dg/data/adult_v1.0',\n",
    "#\n",
    "#         dataset=fairpate_config.dataset,\n",
    "#         output_col_name=fairpate_config.output_col_name,\n",
    "#\n",
    "#         cols_to_norm=fairpate_config.cols_to_norm,\n",
    "#         sensitive_attributes=fairpate_config.sensitive_attributes,\n",
    "#         column_order=DPDG_ADULT_COLUMN_ORDER,\n",
    "# )\n",
    "\n",
    "# Arguments for the train/test index lookup, taken from the FairPATE config where available.\n",
    "split_kwargs = dict(\n",
    "    dataset=fairpate_config.dataset,\n",
    "    output_col_name=fairpate_config.output_col_name,\n",
    "    path='./Datasets/Adult/adult_original_purified.csv',\n",
    "    data_path='./fairpate_tabular/data/',\n",
    "    cols_to_norm=fairpate_config.cols_to_norm,\n",
    "    sensitive_attributes=fairpate_config.sensitive_attributes,\n",
    "    seed=fairpate_config.seed,\n",
    "    split=fairpate_config.split,\n",
    "    undersampling_ratio=fairpate_config.undersampling_ratio,\n",
    ")\n",
    "train_test_index = get_train_test_index(**split_kwargs)\n",
    "train_index, test_index = train_test_index\n",
    "\n",
    "# Carve a validation set out of the training indices, using the same ratio as in the dp-dg experiment.\n",
    "val_to_all_train_ratio = 3000.0 / 43000.0\n",
    "train_index, validation_index = sklearn.model_selection.train_test_split(\n",
    "    train_index,\n",
    "    test_size=val_to_all_train_ratio,\n",
    "    random_state=fairpate_config.seed,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dpdg_config = Namespace(\n",
    "    dataset='adult', \n",
    "    algorithm='ERM', \n",
    "    root_dir='./baselines/dp_dg/data', \n",
    "    enable_privacy=True, \n",
    "    enable_fair_privacy=False, \n",
    "    apply_noise=False, \n",
    "    split_scheme='official', \n",
    "    dataset_kwargs={}, \n",
    "    download=False, \n",
    "    subsample=False, \n",
    "    frac=1.0, \n",
    "    version=None, \n",
    "    loader_kwargs={'num_workers': 1, 'pin_memory': True}, \n",
    "    train_loader='standard', \n",
    "    uniform_over_groups=False, \n",
    "    distinct_groups=None, \n",
    "    n_groups_per_batch=4, \n",
    "    batch_size=1024,  #instead of 256\n",
    "    eval_loader='standard', \n",
    "    weighted_uniform_iid=True, \n",
    "    uniform_iid=None, \n",
    "    sample_rate=0.005, \n",
    "    clip_sample_rate=None,\n",
    "    model='logistic_regression', \n",
    "    model_kwargs={'in_features': 85}, # 85 instead of 86 since work-class does not have \"Never-worked\" in this version of the dataset\n",
    "    transform=None, \n",
    "    target_resolution=None, \n",
    "    resize_scale=None, \n",
    "    max_token_length=None, \n",
    "    loss_function='cross_entropy', \n",
    "    loss_kwargs={}, \n",
    "    groupby_fields=['sex', 'y'], \n",
    "    group_dro_step_size=None, \n",
    "    coral_penalty_weight=None, \n",
    "    irm_lambda=None, \n",
    "    irm_penalty_anneal_iters=None, \n",
    "    algo_log_metric='accuracy', \n",
    "    val_metric='acc_wg', \n",
    "    val_metric_decreasing=False, \n",
    "    n_epochs=20, # instead of 20s\n",
    "    optimizer='SGD', \n",
    "    lr=0.22360679774997896, \n",
    "    weight_decay=0.01, \n",
    "    max_grad_norm=None, \n",
    "    optimizer_kwargs={'momentum': 0.9}, sigma=5.0, \n",
    "    max_per_sample_grad_norm=0.5, \n",
    "    delta=1e-05, \n",
    "    sigma2=1.0, \n",
    "    C0=1.0, \n",
    "    scheduler=None, \n",
    "    scheduler_kwargs={}, \n",
    "    scheduler_metric_split='val',\n",
    "    scheduler_metric_name=None, \n",
    "    process_outputs_function='multiclass_logits_to_pred', \n",
    "    evaluate_all_splits=True, \n",
    "    eval_splits=[], \n",
    "    eval_only=True, \n",
    "    eval_epoch=None, \n",
    "    device=device(type='cpu'), \n",
    "    seed=1, \n",
    "    log_dir=f'./baselines/dp_dg/logs/adult/{experiment_name}',\n",
    "    log_every=50, \n",
    "    save_step=None, \n",
    "    save_best=True, \n",
    "    save_last=True, \n",
    "    save_pred=True, \n",
    "    no_group_logging=False, \n",
    "    use_wandb=False, \n",
    "    progress_bar=False, \n",
    "    resume=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "dp_dg_experiment(dpdg_config, train_val_test_index=(train_index, validation_index, test_index))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_log_dir = f\"baselines/dp_dg/logs/adult/{experiment_name}\"\n",
    "\n",
    "best_epoch = find_best_epoch_number(results_log_dir)\n",
    "best_epoch_val_disparity = pd.read_csv(f\"{results_log_dir}/adult_split:val_seed:1_epoch:best_disparity.csv\")\n",
    "best_epoch_test_disparity = pd.read_csv(f\"{results_log_dir}/adult_split:test_seed:1_epoch:best_disparity.csv\")\n",
    "\n",
    "# `query` keeps the eval CSV's original row label for the best epoch, whereas `read_csv`\n",
    "# labels the disparity frames from 0. `join` aligns on those labels, so the labels are reset\n",
    "# first; otherwise the outer join produces unmatched, NaN-padded rows whenever the best\n",
    "# epoch is not stored at label 0.\n",
    "results_df = pd.concat([\n",
    "    pd.read_csv(f\"{results_log_dir}/{eval_file}\")\n",
    "    .query(f\"epoch == {best_epoch}\")\n",
    "    .reset_index(drop=True)\n",
    "    .assign(split=split_name)\n",
    "    .join(split_disparity, how=\"outer\")\n",
    "    for eval_file, split_name, split_disparity in [\n",
    "        (\"val_eval.csv\", \"validation\", best_epoch_val_disparity),\n",
    "        (\"test_eval.csv\", \"test\", best_epoch_test_disparity),\n",
    "    ]\n",
    "]).set_index([\"split\"])\n",
    "\n",
    "# Keep only the columns that are reported.\n",
    "results_df = results_df[[\n",
    "        \"acc_avg\",\n",
    "        \"acc_wg\",\n",
    "        \"epsilon\",\n",
    "        \"best_alpha\",\n",
    "        \"dem_disparity\",\n",
    "        \"eo_disparity\",\n",
    "        \"ep_disparity\"\n",
    "    ]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Accuracy Sanity Check and Disparity Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work on a deep copy of adult_dpfermi; both prediction columns start out as NaN.\n",
    "adult_reuslts = adult_dpfermi.copy(deep=True)\n",
    "for pred_col in (\"best_pred\", \"last_pred\"):\n",
    "    adult_reuslts[pred_col] = np.nan\n",
    "\n",
    "# Tag every row with the name of the split whose index contains it.\n",
    "split_labels = [\n",
    "    pd.Series(split_name, index=split_index)\n",
    "    for split_name, split_index in (\n",
    "        (\"train\", train_index),\n",
    "        (\"validation\", validation_index),\n",
    "        (\"test\", test_index),\n",
    "    )\n",
    "]\n",
    "adult_reuslts[\"split\"] = pd.concat(split_labels, axis=0).sort_index()\n",
    "\n",
    "# Fill the best/last prediction columns from the saved CSVs: validation rows first, then test rows.\n",
    "for split_tag, split_index in ((\"val\", validation_index), (\"test\", test_index)):\n",
    "    for which_model in (\"best\", \"last\"):\n",
    "        adult_reuslts = fill_in_preds(\n",
    "            adult_reuslts,\n",
    "            split_index,\n",
    "            f\"{which_model}_pred\",\n",
    "            f\"baselines/dp_dg/logs/adult/{experiment_name}/adult_split:{split_tag}_seed:1_epoch:{which_model}_pred.csv\",\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adult_reuslts.query(\"split in ['test', 'validation']\").groupby([\"split\", \">50K\", \"sex\"]).apply(lambda x: len(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_path = f\"baselines/dp_dg/logs/adult/{experiment_name}/log.txt\"\n",
    "\n",
    "# Read the log once and reuse the lines for both the search and the printout.\n",
    "with open(log_path) as f:\n",
    "    log_lines = f.readlines()\n",
    "\n",
    "# Positions of all lines that mention \"Validation\".\n",
    "validation_lines = [line_idx for line_idx, line in enumerate(log_lines) if \"Validation\" in line]\n",
    "\n",
    "# Print from the second-to-last \"Validation\" line onwards. With fewer than two such lines,\n",
    "# start at the only one, or at the top of the file when there is none, instead of raising\n",
    "# an IndexError.\n",
    "if len(validation_lines) >= 2:\n",
    "    tail_start = validation_lines[-2]\n",
    "elif validation_lines:\n",
    "    tail_start = validation_lines[0]\n",
    "else:\n",
    "    tail_start = 0\n",
    "\n",
    "print(\"\".join(log_lines[tail_start:]), end=\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the train split: predictions were only filled in for the validation and test rows.\n",
    "# `.copy()` detaches the filtered frame so the column assignments below do not write into\n",
    "# a slice of the previous frame (pandas' SettingWithCopyWarning).\n",
    "adult_reuslts = adult_reuslts.query(\"split != 'train'\").copy()\n",
    "\n",
    "# Binary encodings: sensitive = 1 for \"Male\", label = 1 for \"Yes\".\n",
    "adult_reuslts[\"sensitive\"] = adult_reuslts[\"sex\"].eq(\"Male\").astype(int)\n",
    "adult_reuslts[\"label\"] = adult_reuslts[\">50K\"].eq(\"Yes\").astype(int)\n",
    "\n",
    "results_by_split = adult_reuslts.groupby(\"split\")\n",
    "prediction_columns = {\"best model\": \"best_pred\", \"last model\": \"last_pred\"}\n",
    "\n",
    "print(\"= Accuracy\")\n",
    "for model_name, pred_col in prediction_columns.items():\n",
    "    print(f\"== {model_name}:\")\n",
    "    print(results_by_split.apply(lambda x: (x[pred_col].astype(int) == x[\"label\"]).astype(float).sum()/len(x)))\n",
    "\n",
    "# (section title, metric name handed to get_disparity, whether the labels are passed as well)\n",
    "disparity_sections = [\n",
    "    (\"= Demographic Parity\", \"DemParity\", False),\n",
    "    (\"= Equality of Odds\", \"EqualityOfOdds\", True),\n",
    "    (\"= Error Parity\", \"ErrorParity\", True),\n",
    "]\n",
    "for section_idx, (section_title, metric, needs_label) in enumerate(disparity_sections):\n",
    "    if section_idx:\n",
    "        # Blank line between consecutive disparity sections.\n",
    "        print(\"\")\n",
    "    print(section_title)\n",
    "    for model_name, pred_col in prediction_columns.items():\n",
    "        print(f\"== {model_name}:\")\n",
    "        print(results_by_split.apply(\n",
    "            lambda x: get_disparity(metric, x[pred_col], x[\"sensitive\"], *([x[\"label\"]] if needs_label else []))\n",
    "        ))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fairpate_tabular.fairpate import process_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column lists for the Adult export, named so the call below stays readable.\n",
    "adult_cols_to_norm = ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']\n",
    "adult_dpdg_column_order = [\n",
    "    'age', 'work-class', 'fnlwgt', 'education', 'education-num',\n",
    "    'marital-status', 'occupation', 'relationship', 'race', 'sex',\n",
    "    'capital-gain', 'capital-loss', 'hours-per-week', 'native-country',\n",
    "    'income',\n",
    "]\n",
    "\n",
    "write_dp_dg_format_train_test_data(\n",
    "    dataset='adult',\n",
    "    output_col_name='>50K',\n",
    "    path='./Datasets/Adult/adult_original_purified.csv',\n",
    "    data_path='./fairpate_tabular/data/',\n",
    "    save_path='./baselines/dp_dg/data/adult_v1.0',\n",
    "    rng=0,\n",
    "    seed=0,\n",
    "    split=0.75,\n",
    "    undersampling_ratio=None,\n",
    "    cols_to_norm=adult_cols_to_norm,\n",
    "    sensitive_attributes=['sex'],\n",
    "    column_order=adult_dpdg_column_order,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Line-profile process_data. The call previously passed `args=args`, but no `args` is defined\n",
    "# in this notebook, so it raised NameError on a fresh kernel; pass the config Namespace instead.\n",
    "# NOTE(review): assumes process_data expects the `fairpate_config` Namespace as `args` - confirm.\n",
    "%lprun -f process_data process_data(rng=3, args=fairpate_config, log=print, return_train_test_index=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Processing Speedup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Other Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataloader import GeneralData"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parkinsons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "fairpate_config = Namespace(dataset='parkinsons', list_dataset=None, num_classes=2, output_col_name='total_UPDRS', split=0.75, dem_disparity_interpretation='max_vs_min', teacher_query_set_split=0.7, num_teachers=4, list_num_teachers=None, threshold=2, list_threshold=None, fairness_threshold=0.2, list_fairness_threshold=None, sigma_threshold=60, list_sigma_threshold=None, sigma_fair_threshold=0, sigma_gnmax=25, list_sigma_gnmax=None, budget=1000, list_budget=None, delta=1e-05, verbose=False, seed=0, list_seed=None, data_path='./fairpate_tabular/data/', min_group_count=50, results_dir='.', use_optuna=False, num_optuna_trials=1000, use_stratification=False, fairness_metric='DemParity', list_fairness_metric=None, num_calib=100, pate_based_model='fairpate', use_inference_time_postprocessing=False, undersampling_ratio=None, optuna_db_path='.', path='./Datasets/Parkinsons/parkinsons_updrs_processed.csv', num_inp_attr=19, cols_to_norm=['age', 'test_time', 'motor_UPDRS', 'Jitter(%)', 'Jitter(Abs)', 'Jitter:RAP', 'Jitter:PPQ5', 'Jitter:DDP', 'Shimmer', 'Shimmer(dB)', 'Shimmer:APQ3', 'Shimmer:APQ5', 'Shimmer:APQ11', 'Shimmer:DDA', 'NHR', 'HNR', 'RPDE', 'DFA', 'PPE'], sensitive_attributes=['sex'], results_db_path='./fairpate_parkinsons_DemParity_results.parquet', gt_fairness=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_name = fairpate_config.dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "full_data = GeneralData(\n",
    "    path=fairpate_config.path,\n",
    "    rng=None,\n",
    "    sensitive_attributes=fairpate_config.sensitive_attributes,\n",
    "    cols_to_norm=fairpate_config.cols_to_norm,\n",
    "    output_col_name=fairpate_config.output_col_name,\n",
    "    split=fairpate_config.split,\n",
    ")\n",
    "\n",
    "# Create the per-dataset folder first: `to_csv` does not create missing parent directories.\n",
    "dpdg_data_dir = f\"baselines/dp_dg/data/{fairpate_config.dataset}_v1.0\"\n",
    "os.makedirs(dpdg_data_dir, exist_ok=True)\n",
    "\n",
    "# Write the preprocessed train and test frames as train.csv / test.csv, without the index column.\n",
    "full_data.getTrain(return_tensor=False).get_preprocessed_df().to_csv(f\"{dpdg_data_dir}/train.csv\", index=False)\n",
    "full_data.getTest(return_tensor=False).get_preprocessed_df().to_csv(f\"{dpdg_data_dir}/test.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arguments for the train/test index lookup, all but `data_path` taken from the FairPATE config.\n",
    "split_kwargs = dict(\n",
    "    dataset=fairpate_config.dataset,\n",
    "    output_col_name=fairpate_config.output_col_name,\n",
    "    path=fairpate_config.path,\n",
    "    data_path='./baselines/data/',\n",
    "    cols_to_norm=fairpate_config.cols_to_norm,\n",
    "    sensitive_attributes=fairpate_config.sensitive_attributes,\n",
    "    seed=fairpate_config.seed,\n",
    "    split=fairpate_config.split,\n",
    "    undersampling_ratio=fairpate_config.undersampling_ratio,\n",
    ")\n",
    "train_test_index = get_train_test_index(**split_kwargs)\n",
    "train_index, test_index = train_test_index\n",
    "\n",
    "# Carve a validation set out of the training indices, using the same ratio as in the dp-dg experiment.\n",
    "val_to_all_train_ratio = 3000.0 / 43000.0\n",
    "train_index, validation_index = sklearn.model_selection.train_test_split(\n",
    "    train_index,\n",
    "    test_size=val_to_all_train_ratio,\n",
    "    random_state=fairpate_config.seed,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "dpdg_config = Namespace(\n",
    "    dataset='preprocessed', # this is to force tabular data (like adult) datasets that includes `parkinsons` and `credit-card`\n",
    "    algorithm='ERM', \n",
    "    root_dir='./baselines/dp_dg/data/',\n",
    "    enable_privacy=True, \n",
    "    enable_fair_privacy=False, \n",
    "    apply_noise=False, \n",
    "    split_scheme='official', \n",
    "    dataset_kwargs= dict(\n",
    "            dataset_name=fairpate_config.dataset,\n",
    "            csv_file_names=['train.csv', 'test.csv'],\n",
    "            seed=fairpate_config.seed,\n",
    "            y_col=\"y\",\n",
    "            sensitive_col=\"z\",\n",
    "            feat_cols=None), \n",
    "    download=False, \n",
    "    subsample=False, \n",
    "    frac=1.0, \n",
    "    version=None, \n",
    "    loader_kwargs={'num_workers': 1, 'pin_memory': True}, \n",
    "    train_loader='standard', \n",
    "    uniform_over_groups=False, \n",
    "    distinct_groups=None, \n",
    "    n_groups_per_batch=4, \n",
    "    batch_size=1024,  #instead of 256\n",
    "    eval_loader='standard', \n",
    "    weighted_uniform_iid=True, \n",
    "    uniform_iid=None, \n",
    "    sample_rate=0.005, \n",
    "    clip_sample_rate=None,\n",
    "    model='logistic_regression', \n",
    "    model_kwargs={'in_features': 20}, #changed\n",
    "    transform=None, \n",
    "    target_resolution=None, \n",
    "    resize_scale=None, \n",
    "    max_token_length=None, \n",
    "    loss_function='cross_entropy', \n",
    "    loss_kwargs={}, \n",
    "    groupby_fields=['z', 'y'],\n",
    "    group_dro_step_size=None, \n",
    "    coral_penalty_weight=None, \n",
    "    irm_lambda=None, \n",
    "    irm_penalty_anneal_iters=None, \n",
    "    algo_log_metric='accuracy', \n",
    "    val_metric='acc_wg', \n",
    "    val_metric_decreasing=False, \n",
    "    n_epochs=20, # instead of 20s\n",
    "    optimizer='SGD', \n",
    "    lr=0.22360679774997896, \n",
    "    weight_decay=0.01, \n",
    "    max_grad_norm=None, \n",
    "    optimizer_kwargs={'momentum': 0.9}, sigma=5.0, \n",
    "    max_per_sample_grad_norm=0.5, \n",
    "    delta=1e-05, \n",
    "    sigma2=1.0, \n",
    "    C0=1.0, \n",
    "    scheduler=None, \n",
    "    scheduler_kwargs={}, \n",
    "    scheduler_metric_split='val',\n",
    "    scheduler_metric_name=None, \n",
    "    process_outputs_function='multiclass_logits_to_pred', \n",
    "    evaluate_all_splits=True, \n",
    "    eval_splits=[], \n",
    "#     eval_only=False, \n",
    "    eval_only=True,\n",
    "    eval_epoch=None, \n",
    "    device=device(type='cpu'), \n",
    "    seed=1, \n",
    "    log_dir=f'./baselines/dp_dg/logs/{experiment_name}',\n",
    "    log_every=50, \n",
    "    save_step=None, \n",
    "    save_best=True, \n",
    "    save_last=True, \n",
    "    save_pred=True, \n",
    "    no_group_logging=False, \n",
    "    use_wandb=False, \n",
    "    progress_bar=False, \n",
    "    resume=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%capture\n",
    "dp_dg_experiment(dpdg_config, train_val_test_index=(train_index, validation_index, test_index))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>acc_avg</th>\n",
       "      <th>acc_wg</th>\n",
       "      <th>epsilon</th>\n",
       "      <th>best_alpha</th>\n",
       "      <th>dem_disparity</th>\n",
       "      <th>eo_disparity</th>\n",
       "      <th>ep_disparity</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>split</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>validation</th>\n",
       "      <td>0.905844</td>\n",
       "      <td>0.870370</td>\n",
       "      <td>0.109377</td>\n",
       "      <td>63.0</td>\n",
       "      <td>0.090802</td>\n",
       "      <td>0.058587</td>\n",
       "      <td>0.255556</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>validation</th>\n",
       "      <td>0.905844</td>\n",
       "      <td>0.870370</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>test</th>\n",
       "      <td>0.899932</td>\n",
       "      <td>0.887805</td>\n",
       "      <td>0.109377</td>\n",
       "      <td>63.0</td>\n",
       "      <td>0.079812</td>\n",
       "      <td>0.023004</td>\n",
       "      <td>0.030719</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>test</th>\n",
       "      <td>0.899932</td>\n",
       "      <td>0.887805</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             acc_avg    acc_wg   epsilon  best_alpha  dem_disparity  \\\n",
       "split                                                                 \n",
       "validation  0.905844  0.870370  0.109377        63.0       0.090802   \n",
       "validation  0.905844  0.870370  0.000000         0.0            NaN   \n",
       "test        0.899932  0.887805  0.109377        63.0       0.079812   \n",
       "test        0.899932  0.887805  0.000000         0.0            NaN   \n",
       "\n",
       "            eo_disparity  ep_disparity  \n",
       "split                                   \n",
       "validation      0.058587      0.255556  \n",
       "validation           NaN           NaN  \n",
       "test            0.023004      0.030719  \n",
       "test                 NaN           NaN  "
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_log_dir = f\"baselines/dp_dg/logs/{experiment_name}\"\n",
    "\n",
    "best_epoch = find_best_epoch_number(results_log_dir)\n",
    "best_epoch_val_disparity = pd.read_csv(f\"{results_log_dir}/parkinsons_split:val_seed:1_epoch:best_disparity.csv\")\n",
    "best_epoch_test_disparity = pd.read_csv(f\"{results_log_dir}/parkinsons_split:test_seed:1_epoch:best_disparity.csv\")\n",
    "\n",
    "# `query` keeps the eval CSV's original row label for the best epoch, whereas `read_csv`\n",
    "# labels the disparity frames from 0. `join` aligns on those labels, so the labels are reset\n",
    "# first; otherwise the outer join produces unmatched, NaN-padded rows whenever the best\n",
    "# epoch is not stored at label 0.\n",
    "results_df = pd.concat([\n",
    "    pd.read_csv(f\"{results_log_dir}/{eval_file}\")\n",
    "    .query(f\"epoch == {best_epoch}\")\n",
    "    .reset_index(drop=True)\n",
    "    .assign(split=split_name)\n",
    "    .join(split_disparity, how=\"outer\")\n",
    "    for eval_file, split_name, split_disparity in [\n",
    "        (\"val_eval.csv\", \"validation\", best_epoch_val_disparity),\n",
    "        (\"test_eval.csv\", \"test\", best_epoch_test_disparity),\n",
    "    ]\n",
    "]).set_index([\"split\"])\n",
    "\n",
    "# Keep only the columns that are reported.\n",
    "results_df = results_df[[\n",
    "        \"acc_avg\",\n",
    "        \"acc_wg\",\n",
    "        \"epsilon\",\n",
    "        \"best_alpha\",\n",
    "        \"dem_disparity\",\n",
    "        \"eo_disparity\",\n",
    "        \"ep_disparity\"\n",
    "    ]]\n",
    "\n",
    "results_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Credit-Card"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "fairpate_config = Namespace(dataset='credit-card', list_dataset=None, num_classes=2, output_col_name='default payment next month', split=0.75, dem_disparity_interpretation='max_vs_min', teacher_query_set_split=0.7, num_teachers=4, list_num_teachers=None, threshold=2, list_threshold=None, fairness_threshold=0.2, list_fairness_threshold=None, sigma_threshold=60, list_sigma_threshold=None, sigma_fair_threshold=0, sigma_gnmax=25, list_sigma_gnmax=None, budget=1000, list_budget=None, delta=1e-05, verbose=False, seed=0, list_seed=None, data_path='./fairpate_tabular/data/', min_group_count=50, results_dir='.', use_optuna=False, num_optuna_trials=1000, use_stratification=False, fairness_metric='DemParity', list_fairness_metric=None, num_calib=100, pate_based_model='fairpate', use_inference_time_postprocessing=False, undersampling_ratio=None, optuna_db_path='.', path='./Datasets/CreditCard/credit-card-defaulters_processed.csv', num_inp_attr=85, cols_to_norm=['LIMIT_BAL', 'AGE', 'BILL_AMT1', 'BILL_AMT2', 'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6', 'PAY_AMT1', 'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6'], sensitive_attributes=['SEX'], results_db_path='./fairpate_credit-card_DemParity_results.parquet', gt_fairness=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_name = fairpate_config.dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "full_data = GeneralData(\n",
    "    path=fairpate_config.path,\n",
    "    rng=None,\n",
    "    sensitive_attributes=fairpate_config.sensitive_attributes,\n",
    "    cols_to_norm=fairpate_config.cols_to_norm,\n",
    "    output_col_name=fairpate_config.output_col_name,\n",
    "    split=fairpate_config.split,\n",
    ")\n",
    "\n",
    "# Create the per-dataset folder first: `to_csv` does not create missing parent directories.\n",
    "dpdg_data_dir = f\"baselines/dp_dg/data/{fairpate_config.dataset}_v1.0\"\n",
    "os.makedirs(dpdg_data_dir, exist_ok=True)\n",
    "\n",
    "# Write the preprocessed train and test frames as train.csv / test.csv, without the index column.\n",
    "full_data.getTrain(return_tensor=False).get_preprocessed_df().to_csv(f\"{dpdg_data_dir}/train.csv\", index=False)\n",
    "full_data.getTest(return_tensor=False).get_preprocessed_df().to_csv(f\"{dpdg_data_dir}/test.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the (train, test) index pair from `get_train_test_index`, driven by the FairPATE config values.\n",
    "train_test_index = get_train_test_index(\n",
    "    # data location\n",
    "    dataset=fairpate_config.dataset,\n",
    "    path=fairpate_config.path,\n",
    "    data_path='./baselines/data/',\n",
    "    # preprocessing\n",
    "    cols_to_norm=fairpate_config.cols_to_norm,\n",
    "    sensitive_attributes=fairpate_config.sensitive_attributes,\n",
    "    output_col_name=fairpate_config.output_col_name,\n",
    "    undersampling_ratio=fairpate_config.undersampling_ratio,\n",
    "    # splitting\n",
    "    split=fairpate_config.split,\n",
    "    seed=fairpate_config.seed,\n",
    ")\n",
    "train_index, test_index = train_test_index\n",
    "\n",
    "# Carve a validation subset out of the training indices, using the same\n",
    "# validation-to-train ratio as in the dp-dg experiment (3000 of 43000).\n",
    "val_to_all_train_ratio = 3000 / 43000\n",
    "train_index, validation_index = sklearn.model_selection.train_test_split(\n",
    "    train_index,\n",
    "    test_size=val_to_all_train_ratio,\n",
    "    random_state=fairpate_config.seed,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Argument namespace handed to `dp_dg_experiment`; dataset name and dataset seed come from `fairpate_config`,\n",
    "# everything else is set literally here.\n",
    "dpdg_config = Namespace(\n",
    "    dataset='preprocessed', # this forces the tabular-data (adult-like) dataset handling, which includes `parkinsons` and `credit-card`\n",
    "    algorithm='ERM', \n",
    "    root_dir='./baselines/dp_dg/data/',\n",
    "    enable_privacy=True, \n",
    "    enable_fair_privacy=False, \n",
    "    apply_noise=False, \n",
    "    split_scheme='official', \n",
    "    dataset_kwargs= dict(\n",
    "            dataset_name=fairpate_config.dataset,\n",
    "            csv_file_names=['train.csv', 'test.csv'],\n",
    "            seed=fairpate_config.seed,\n",
    "            y_col=\"y\",\n",
    "            sensitive_col=\"z\",\n",
    "            feat_cols=None), \n",
    "    download=False, \n",
    "    subsample=False, \n",
    "    frac=1.0, \n",
    "    version=None, \n",
    "    loader_kwargs={'num_workers': 1, 'pin_memory': True}, \n",
    "    train_loader='standard', \n",
    "    uniform_over_groups=False, \n",
    "    distinct_groups=None, \n",
    "    n_groups_per_batch=4, \n",
    "    batch_size=1024,  # instead of 256\n",
    "    eval_loader='standard', \n",
    "    weighted_uniform_iid=True, \n",
    "    uniform_iid=None, \n",
    "    sample_rate=0.005, \n",
    "    clip_sample_rate=None,\n",
    "    model='logistic_regression', \n",
    "    model_kwargs={'in_features': 86}, # changed; NOTE(review): fairpate_config.num_inp_attr is 85 - confirm why 86 is needed here\n",
    "    transform=None, \n",
    "    target_resolution=None, \n",
    "    resize_scale=None, \n",
    "    max_token_length=None, \n",
    "    loss_function='cross_entropy', \n",
    "    loss_kwargs={}, \n",
    "    groupby_fields=['z', 'y'],\n",
    "    group_dro_step_size=None, \n",
    "    coral_penalty_weight=None, \n",
    "    irm_lambda=None, \n",
    "    irm_penalty_anneal_iters=None, \n",
    "    algo_log_metric='accuracy', \n",
    "    val_metric='acc_wg', \n",
    "    val_metric_decreasing=False, \n",
    "    n_epochs=20, # NOTE(review): the earlier comment read 'instead of 20s' although the value is 20 - looks stale, confirm the intended epoch count\n",
    "    optimizer='SGD', \n",
    "    lr=0.22360679774997896, \n",
    "    weight_decay=0.01, \n",
    "    max_grad_norm=None, \n",
    "    optimizer_kwargs={'momentum': 0.9}, sigma=5.0, \n",
    "    max_per_sample_grad_norm=0.5, \n",
    "    delta=1e-05, \n",
    "    sigma2=1.0, \n",
    "    C0=1.0, \n",
    "    scheduler=None, \n",
    "    scheduler_kwargs={}, \n",
    "    scheduler_metric_split='val',\n",
    "    scheduler_metric_name=None, \n",
    "    process_outputs_function='multiclass_logits_to_pred', \n",
    "    evaluate_all_splits=True, \n",
    "    eval_splits=[], \n",
    "    # NOTE(review): eval_only is currently True; the commented-out alternative is kept below - presumably False is needed to train, confirm.\n",
    "#     eval_only=False, \n",
    "    eval_only=True,\n",
    "    eval_epoch=None, \n",
    "    device=device(type='cpu'), \n",
    "    # NOTE(review): this seed (1) differs from dataset_kwargs['seed'] (fairpate_config.seed); the results cell hard-codes 'seed:1' in its file names.\n",
    "    seed=1, \n",
    "    log_dir=f'./baselines/dp_dg/logs/{experiment_name}',\n",
    "    log_every=50, \n",
    "    save_step=None, \n",
    "    save_best=True, \n",
    "    save_last=True, \n",
    "    save_pred=True, \n",
    "    no_group_logging=False, \n",
    "    use_wandb=False, \n",
    "    progress_bar=False, \n",
    "    resume=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%capture\n",
    "# Run the dp-dg experiment with `dpdg_config` on the explicit (train, validation, test) index split.\n",
    "# NOTE(review): dpdg_config.eval_only is True, so this presumably only evaluates previously saved results/checkpoints - confirm.\n",
    "dp_dg_experiment(dpdg_config, train_val_test_index=(train_index, validation_index, test_index))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>acc_avg</th>\n",
       "      <th>acc_wg</th>\n",
       "      <th>epsilon</th>\n",
       "      <th>best_alpha</th>\n",
       "      <th>dem_disparity</th>\n",
       "      <th>eo_disparity</th>\n",
       "      <th>ep_disparity</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>split</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>validation</th>\n",
       "      <td>0.741401</td>\n",
       "      <td>0.535519</td>\n",
       "      <td>0.109377</td>\n",
       "      <td>63.0</td>\n",
       "      <td>0.028904</td>\n",
       "      <td>0.041278</td>\n",
       "      <td>0.062768</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>test</th>\n",
       "      <td>0.744000</td>\n",
       "      <td>0.507075</td>\n",
       "      <td>0.109377</td>\n",
       "      <td>63.0</td>\n",
       "      <td>0.011947</td>\n",
       "      <td>0.019004</td>\n",
       "      <td>0.031519</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             acc_avg    acc_wg   epsilon  best_alpha  dem_disparity  \\\n",
       "split                                                                 \n",
       "validation  0.741401  0.535519  0.109377        63.0       0.028904   \n",
       "test        0.744000  0.507075  0.109377        63.0       0.011947   \n",
       "\n",
       "            eo_disparity  ep_disparity  \n",
       "split                                   \n",
       "validation      0.041278      0.062768  \n",
       "test            0.019004      0.031519  "
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Collect the best-epoch accuracy/privacy metrics and the disparity metrics of the\n",
    "# validation and test splits into one table indexed by split.\n",
    "dpdg_log_dir = f\"baselines/dp_dg/logs/{experiment_name}\"\n",
    "best_epoch = find_best_epoch_number(dpdg_log_dir)\n",
    "best_epoch_val_disparity = pd.read_csv(f\"{dpdg_log_dir}/{experiment_name}_split:val_seed:1_epoch:best_disparity.csv\")\n",
    "best_epoch_test_disparity = pd.read_csv(f\"{dpdg_log_dir}/{experiment_name}_split:test_seed:1_epoch:best_disparity.csv\")\n",
    "\n",
    "\n",
    "def _best_epoch_row(eval_csv_name, split_label, disparity_df):\n",
    "    \"\"\"Return the best-epoch rows of one eval CSV, labelled with `split_label` and joined with `disparity_df`.\"\"\"\n",
    "    return (\n",
    "        pd.read_csv(f\"{dpdg_log_dir}/{eval_csv_name}\")\n",
    "        .query(f\"epoch == {best_epoch}\")\n",
    "        # `query` keeps the original row label of the selected epoch; reset it so the\n",
    "        # index-based join below lines up with the disparity frame's default 0-based index\n",
    "        # instead of producing NaN-padded rows whenever the best epoch is not the first row.\n",
    "        .reset_index(drop=True)\n",
    "        .assign(split=split_label)\n",
    "        .join(disparity_df, how=\"outer\")\n",
    "    )\n",
    "\n",
    "\n",
    "results_df = pd.concat([\n",
    "    _best_epoch_row(\"val_eval.csv\", \"validation\", best_epoch_val_disparity),\n",
    "    _best_epoch_row(\"test_eval.csv\", \"test\", best_epoch_test_disparity),\n",
    "]).set_index([\"split\"])\n",
    "\n",
    "# Keep only the columns reported in the summary table.\n",
    "results_df = results_df[[\n",
    "    \"acc_avg\",\n",
    "    \"acc_wg\",\n",
    "    \"epsilon\",\n",
    "    \"best_alpha\",\n",
    "    \"dem_disparity\",\n",
    "    \"eo_disparity\",\n",
    "    \"ep_disparity\",\n",
    "]]\n",
    "\n",
    "results_df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dpdg",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
