{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library imports\n",
    "import os\n",
    "import sys\n",
    "import random\n",
    "\n",
    "# Third party imports\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import sklearn.metrics as metrics\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from itertools import combinations\n",
    "from tqdm import tqdm\n",
    "\n",
    "import folktables\n",
    "from folktables import ACSDataSource\n",
    "\n",
    "from sdv.metadata import SingleTableMetadata\n",
    "from sdv.single_table import GaussianCopulaSynthesizer\n",
    "\n",
    "# Local imports\n",
    "import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data: binary income task (PINCP > 50000) built from ACS person records\n",
    "ACStask = folktables.BasicProblem(\n",
    "    features=[\n",
    "        'AGEP',\n",
    "        'COW',\n",
    "        'SCHL',\n",
    "        'MAR',\n",
    "        'OCCP',\n",
    "        'POBP',\n",
    "        'RELP',\n",
    "        'WKHP',\n",
    "        'SEX',\n",
    "        'RAC1P',\n",
    "    ],\n",
    "    target='PINCP',\n",
    "    target_transform=lambda x: x > 50000,\n",
    "    group='RAC1P',\n",
    "    preprocess=utils.adult_filter,\n",
    "    # `nan` has to be passed by keyword: the second positional parameter of\n",
    "    # np.nan_to_num is `copy`, so `np.nan_to_num(x, -1)` replaced NaNs with\n",
    "    # the default 0.0 instead of the -1 sentinel.\n",
    "    postprocess=lambda x: np.nan_to_num(x, nan=-1),\n",
    ")\n",
    "\n",
    "# 2018 1-Year person survey for a single state (downloaded if not cached locally)\n",
    "state = \"NJ\"\n",
    "data_source = ACSDataSource(survey_year='2018', horizon='1-Year', survey='person')\n",
    "acs_data = data_source.get_data(states=[state], download=True)\n",
    "\n",
    "# x: features, y: transformed target, a: group column (RAC1P)\n",
    "x, y, a = ACStask.df_to_pandas(acs_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 81.16%\n"
     ]
    }
   ],
   "source": [
    "# Hold out 20% of the records; the RAC1P column is split alongside the features and labels\n",
    "x_train, x_test, y_train, y_test, a_train, a_test = train_test_split(\n",
    "    x,\n",
    "    y.values.ravel(),\n",
    "    x['RAC1P'],\n",
    "    test_size=0.2,\n",
    "    random_state=0,\n",
    ")\n",
    "\n",
    "# Fit a seeded random forest on the training split\n",
    "clf = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42)\n",
    "clf.fit(x_train, y_train)\n",
    "\n",
    "# Held-out accuracy: fraction of test rows whose predicted label equals the true label\n",
    "y_pred = clf.predict(x_test)\n",
    "accuracy = np.mean(y_pred == y_test)\n",
    "print(f'Accuracy: {accuracy * 100:.2f}%')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation of sufficiency and necessity using the correct conditional distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/cis/home/bbharti1/anaconda3/envs/cuda117/lib/python3.11/site-packages/sdv/single_table/base.py:82: UserWarning: We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Model p(x): describe the feature table with auto-detected metadata,\n",
    "# then fit a Gaussian copula synthesizer on the full feature frame\n",
    "x_metadata = SingleTableMetadata()\n",
    "x_metadata.detect_from_dataframe(x)\n",
    "\n",
    "synthesizer = GaussianCopulaSynthesizer(metadata=x_metadata)\n",
    "synthesizer.fit(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare a reproducible subset of confidently-positive test points\n",
    "eps = 0.4            # required margin above the mean predicted probability\n",
    "num_features = x.shape[1]\n",
    "test_size = 100      # number of test points to keep\n",
    "num_samples = 100    # repetition count handed to utils.repeat_rows\n",
    "subset_seed = 0      # seed for the subset selection\n",
    "\n",
    "# Column-1 predicted probabilities on the test split, computed once and reused\n",
    "test_probs = clf.predict_proba(x_test)[:, 1]\n",
    "f_0 = np.mean(test_probs)\n",
    "pos_idx = np.where(test_probs >= (f_0 + eps))[0]\n",
    "if len(pos_idx) < test_size:\n",
    "    raise ValueError(\n",
    "        f'Only {len(pos_idx)} test points satisfy f(x) >= f_0 + eps; '\n",
    "        f'need at least {test_size}. Lower eps or test_size.'\n",
    "    )\n",
    "\n",
    "# A dedicated seeded generator makes the selection identical after a kernel restart\n",
    "rng = random.Random(subset_seed)\n",
    "idx_to_test = pos_idx[rng.sample(range(len(pos_idx)), test_size)]\n",
    "x_test_subset = x_test.iloc[list(idx_to_test)]\n",
    "x_test_subset_N = utils.repeat_rows(x_test_subset, num_samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sufficiency and Necessity Computation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running experiment for tau = 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 120/120 [12:39<00:00,  6.33s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running experiment for tau = 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 210/210 [24:25<00:00,  6.98s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running experiment for tau = 9\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:50<00:00,  5.03s/it]\n"
     ]
    }
   ],
   "source": [
    "# Compute sufficiency and necessity\n",
    "n_rows = test_size * num_samples  # batch size, and the row count the reshape below requires\n",
    "\n",
    "\n",
    "def mean_conditional_prediction(fixed_cols):\n",
    "    \"\"\"Average classifier output with the columns in `fixed_cols` held at their reference values.\n",
    "\n",
    "    The columns of `x_test_subset_N` at positions `fixed_cols` are given to the\n",
    "    synthesizer as known columns, the remaining columns are sampled, and the\n",
    "    column-1 predicted probabilities are averaged over each block of\n",
    "    `num_samples` consecutive rows.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    fixed_cols : list of int\n",
    "        Positional column indices of `x_test_subset_N` to condition on.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray of shape (test_size,)\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the synthesizer does not return exactly `n_rows` rows.\n",
    "    \"\"\"\n",
    "    sampled = synthesizer.sample_remaining_columns(\n",
    "        known_columns=x_test_subset_N.iloc[:, fixed_cols],\n",
    "        batch_size=n_rows,\n",
    "        max_tries_per_batch=500,\n",
    "    )\n",
    "    if len(sampled) != n_rows:\n",
    "        raise ValueError(\n",
    "            f'Expected {n_rows} sampled rows but got {len(sampled)}; '\n",
    "            'conditional sampling did not return one row per reference row.'\n",
    "        )\n",
    "    # Present the features in the column order the classifier was fitted with\n",
    "    probs = clf.predict_proba(sampled[x_train.columns])[:, 1]\n",
    "    # NOTE(review): assumes the sampled rows keep the row order of the known columns,\n",
    "    # with each test point occupying num_samples consecutive rows -- confirm against\n",
    "    # utils.repeat_rows and the synthesizer's output ordering.\n",
    "    return probs.reshape(test_size, num_samples).mean(axis=1)\n",
    "\n",
    "\n",
    "taus = [3, 6, 9]\n",
    "\n",
    "# Quantities that do not depend on tau are computed once, outside the loop\n",
    "full_set = set(range(num_features))\n",
    "f_x = clf.predict_proba(x_test_subset)[:, 1]\n",
    "\n",
    "tau_results = {}\n",
    "for tau in taus:\n",
    "    print(f'Running experiment for tau = {tau}')\n",
    "    tau_sets = utils.generate_subsets(num_features, tau)\n",
    "\n",
    "    suff_X = []\n",
    "    necc_X = []\n",
    "    for S in tqdm(tau_sets):\n",
    "        S_c = full_set - set(S)\n",
    "        f_xS = mean_conditional_prediction(list(S))\n",
    "        f_xSc = mean_conditional_prediction(list(S_c))\n",
    "\n",
    "        # One column per subset, one row per test point:\n",
    "        #   suff: |f(x) - mean prediction with S fixed|\n",
    "        #   necc: |mean prediction with the complement of S fixed - f_0|\n",
    "        suff_X.append(np.abs(f_x - f_xS).reshape(-1, 1))\n",
    "        necc_X.append(np.abs(f_xSc - f_0).reshape(-1, 1))\n",
    "\n",
    "    suff_X = np.concatenate(suff_X, axis=1)\n",
    "    necc_X = np.concatenate(necc_X, axis=1)\n",
    "    tau_results[f\"tau = {tau}\"] = {\"suff\": suff_X, \"necc\": necc_X}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save results\n",
    "import pickle\n",
    "results_dir = \"results\"\n",
    "\n",
    "# Create the output directory if needed so open() does not raise\n",
    "# FileNotFoundError when it is missing\n",
    "os.makedirs(results_dir, exist_ok=True)\n",
    "\n",
    "with open(os.path.join(results_dir, 'tau_results.pkl'), 'wb') as f:\n",
    "    pickle.dump(tau_results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.11.9 ('cuda117')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "6bca97b14fc9cd3785537269d267f0abb85f4f98e75c76abc0bec20b4de6d918"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
