{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "43d9e6ef",
   "metadata": {},
   "source": [
    "# Overview\n",
    "This notebook provides code to reproduce FairFIS interpretability results for a tree-based surrogate of an MLP model, using the Adult dataset as an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf24eead",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import accuracy_score\n",
    "import math\n",
    "from sklearn.tree import DecisionTreeRegressor\n",
    "from scipy.special import expit\n",
    "from scipy.stats import pearsonr\n",
    "from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor\n",
    "from FairFIS import fis_tree, fis_forest, fis_boosting\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.ensemble import GradientBoostingClassifier,GradientBoostingRegressor, RandomForestRegressor\n",
    "from FairFIS import util\n",
    "import random\n",
    "from scipy.special import expit\n",
    "from sklearn.model_selection import train_test_split\n",
    "from scipy.special import logit, expit\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense\n",
    "from keras.layers import Dropout\n",
    "from keras.constraints import max_norm\n",
    "from sklearn.metrics import r2_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1a13f33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Adult dataset (path is relative to the notebook's working directory)\n",
    "dat = pd.read_csv(filepath_or_buffer=\"adult.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5d2c7ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column names of the two protected attributes and of the prediction target\n",
    "target_sensitive_attribute, sensitive_attribute2 = \"sex\", \"race\"\n",
    "target = \"income-per-year\"\n",
    "\n",
    "# Columns removed from the frame before it is used as the feature matrix\n",
    "drop_features = [target_sensitive_attribute, sensitive_attribute2, target]\n",
    "\n",
    "# Number of rows in the full dataset\n",
    "nrow = len(dat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13b3616f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature names: every column except the protected attributes and the target.\n",
    "# list.remove raises ValueError if one of them is missing from the data.\n",
    "column_names = list(dat.columns)\n",
    "for excluded in (target_sensitive_attribute, sensitive_attribute2, target):\n",
    "    column_names.remove(excluded)\n",
    "\n",
    "# Number of input features fed to the MLP\n",
    "ncol = len(column_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95058e5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _to_arrays(frame):\n",
    "    \"\"\"Return (features, binary target, protected attribute) NumPy arrays for one split.\"\"\"\n",
    "    protected = frame[target_sensitive_attribute].to_numpy()\n",
    "    labels = frame[target].to_numpy()\n",
    "    # Any target value other than 1 is mapped to class 0\n",
    "    labels = np.where(labels != 1, 0, labels)\n",
    "    features = frame.drop(drop_features, axis=1).to_numpy()\n",
    "    return features, labels, protected\n",
    "\n",
    "\n",
    "# 70/30 shuffled train/test split with a fixed seed\n",
    "train_df, test_df = train_test_split(dat, test_size=0.3, shuffle=True, random_state=0)\n",
    "train, y_train, a_train = _to_arrays(train_df)\n",
    "test, y_test, a_test = _to_arrays(test_df)\n",
    "\n",
    "# Build the MLP: 28 -> 20 -> 10 ReLU units with dropout, single sigmoid output\n",
    "model = Sequential()\n",
    "model.add(Dense(28, input_dim=ncol, activation='relu', kernel_initializer=\"uniform\"))\n",
    "model.add(Dropout(0.2))\n",
    "model.add(Dense(20, activation='relu', kernel_constraint=max_norm(3), kernel_initializer=\"uniform\"))\n",
    "model.add(Dropout(0.2))\n",
    "model.add(Dense(10, activation='relu', kernel_initializer=\"uniform\"))\n",
    "model.add(Dense(1, activation='sigmoid', kernel_initializer=\"uniform\"))\n",
    "\n",
    "# Compile MLP\n",
    "model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
    "\n",
    "# Fit the MLP model\n",
    "# NOTE(review): no Keras/TensorFlow seed is set in the visible cells, so the MLP\n",
    "# (and every score derived from it below) can differ between runs.\n",
    "model.fit(train, y_train, epochs=100, batch_size=10)\n",
    "scores = model.predict(train)\n",
    "\n",
    "# Hard MLP labels in {-1, 1}. Thresholding the probabilities directly, instead of\n",
    "# np.sign(logit(scores)), means a score of exactly 0.5 cannot create a third class 0;\n",
    "# ravel() hands the surrogate a 1-D target instead of an (n, 1) column.\n",
    "train_pred = np.where(scores.ravel() > 0.5, 1.0, -1.0)\n",
    "\n",
    "# Fit the surrogate tree on the MLP's predictions (not on y_train).\n",
    "# random_state makes the tree reproducible for a given set of MLP labels.\n",
    "clf = DecisionTreeClassifier(random_state=0)\n",
    "clf.fit(train, train_pred)\n",
    "\n",
    "# Compute FairFIS scores for the surrogate\n",
    "# NOTE(review): the positional 0 is presumably the protected-group value of a_train - confirm against FairFIS.\n",
    "f_forest = fis_tree(clf, train, y_train, a_train, 0, triangle=False)\n",
    "f_forest.calculate_fairness_importance_score()\n",
    "\n",
    "# Per-feature scores; dp/eqop presumably stand for demographic parity / equality of opportunity\n",
    "dp_scores = f_forest._fairness_importance_score_dp_root\n",
    "eo_scores = f_forest._fairness_importance_score_eqop_root\n",
    "acc_scores = f_forest.fitted_clf.feature_importances_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd2f4a4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the per-feature scores computed above, one column per score type\n",
    "data = dict(\n",
    "    Feature=column_names,\n",
    "    DP=dp_scores,\n",
    "    ACC=acc_scores,\n",
    "    EO=eo_scores,\n",
    ")\n",
    "\n",
    "# One row per feature\n",
    "df = pd.DataFrame(data)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
