{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "plt.style.use('ggplot')\n",
    "import copy\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "torch.manual_seed(0)\n",
    "\n",
    "import networkx as nx\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.neighbors import KernelDensity\n",
    "\n",
    "from reinforcement import environment, Policy, reinforce, evaluate\n",
    "from metrics import causal_edge_score, causal_cfe_from_parent, sparsity_num, sparsity_cat, L1\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The SCM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Structural causal model as a dict keyed by node index. Each node stores a list\n",
    "# of parent node indices under 'input' and a callable (or None) under 'func';\n",
    "# nodes without parents keep an empty list and None.\n",
    "# NOTE(review): indices presumably follow the feature column order (0 = age,\n",
    "# 2 = education) - confirm against reinforcement.environment.\n",
    "SCM = dict((node, {'input': [], 'func': None}) for node in range(12))\n",
    "# Node 0 has node 2 as its only parent; its callable returns the argument unchanged.\n",
    "SCM[0] = dict(input=[2], func=lambda a: a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All tensors and the policy are placed on the CPU.\n",
    "# GPU auto-selection is kept for reference but is currently disabled:\n",
    "#   device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "device = 'cpu'"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.compose import ColumnTransformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the train and test splits of the Adult data; the label lives in column 'y'.\n",
    "clf_data_train, clf_data_test = (\n",
    "    pd.read_csv(csv_path)\n",
    "    for csv_path in ('./adult_data/adult_data_train.csv', './adult_data/adult_data_test.csv')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Separate the label column from the feature columns of the training split.\n",
    "target = clf_data_train['y']\n",
    "x_train = clf_data_train.drop(columns='y')\n",
    "\n",
    "# Columns treated as categorical; every other feature column is treated as numeric.\n",
    "categorical = ['workclass', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']\n",
    "numerical = x_train.columns.difference(categorical)\n",
    "\n",
    "# Numeric columns are standardised, categorical columns are one-hot encoded.\n",
    "# handle_unknown='error' makes the encoder raise on a category it did not see\n",
    "# while fitting instead of encoding it silently.\n",
    "numeric_transformer = Pipeline(steps=[('scaler', StandardScaler())])\n",
    "categorical_transformer = Pipeline(steps=[('onehot', OneHotEncoder(handle_unknown='error'))])\n",
    "\n",
    "transformations = ColumnTransformer(transformers=[\n",
    "    ('num', numeric_transformer, numerical),\n",
    "    ('cat', categorical_transformer, categorical),\n",
    "])\n",
    "\n",
    "# Full prediction pipeline: preprocessing followed by an MLP classifier.\n",
    "# Pipeline.fit returns the pipeline itself, so `model` and `clf` name the same object.\n",
    "clf = Pipeline(steps=[\n",
    "    ('preprocessor', transformations),\n",
    "    ('classifier', MLPClassifier(random_state=0)),\n",
    "])\n",
    "model = clf.fit(x_train, target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8257793499889454"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Accuracy of the fitted pipeline on the test split.\n",
    "np.mean(\n",
    "    clf.predict(clf_data_test.drop(columns='y')) == clf_data_test['y']\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predicted labels for both splits, cast to integers.\n",
    "pred_train = clf.predict(clf_data_train.drop(columns='y')).astype(int)\n",
    "pred_test = clf.predict(clf_data_test.drop(columns='y')).astype(int)\n",
    "\n",
    "# Keep only the feature rows that the classifier assigns to class 0.\n",
    "cf_data_train = clf_data_train.drop(columns='y').loc[pred_train == 0]\n",
    "cf_data_test = clf_data_test.drop(columns='y').loc[pred_test == 0]\n",
    "\n",
    "# Lower-case alias of the SCM dict.\n",
    "scm = SCM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positional indices of the categorical columns within the feature table.\n",
    "cat_feats = list(map(cf_data_test.columns.get_loc, categorical))\n",
    "\n",
    "# Per-column mean and standard deviation of the numeric training columns\n",
    "# (pandas' default sample standard deviation, ddof=1).\n",
    "train_num_mean = clf_data_train[numerical].mean()\n",
    "train_num_std = clf_data_train[numerical].std()\n",
    "\n",
    "# Copy of the training features with the numeric columns z-scored;\n",
    "# the categorical columns are left untouched.\n",
    "train_data_norm = clf_data_train.drop(columns='y')\n",
    "train_data_norm[numerical] = (\n",
    "    train_data_norm[numerical].sub(train_num_mean).div(train_num_std).to_numpy()\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup and training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Episode 1000\tAverage Score: 0.60\n",
      "Episode 2000\tAverage Score: 0.74\n",
      "Episode 3000\tAverage Score: 0.62\n",
      "Episode 4000\tAverage Score: 0.87\n",
      "Episode 5000\tAverage Score: 0.96\n",
      "Episode 6000\tAverage Score: 0.85\n",
      "Episode 7000\tAverage Score: 1.09\n",
      "Episode 8000\tAverage Score: 1.06\n",
      "Episode 9000\tAverage Score: 1.27\n",
      "Episode 10000\tAverage Score: 1.46\n",
      "Episode 11000\tAverage Score: 1.58\n",
      "Episode 12000\tAverage Score: 1.74\n",
      "Episode 13000\tAverage Score: 2.07\n",
      "Episode 14000\tAverage Score: 2.28\n",
      "Episode 15000\tAverage Score: 2.87\n",
      "Episode 16000\tAverage Score: 3.12\n",
      "Episode 17000\tAverage Score: 3.60\n",
      "Episode 18000\tAverage Score: 4.00\n",
      "Episode 19000\tAverage Score: 4.47\n",
      "Episode 20000\tAverage Score: 4.68\n",
      "Episode 21000\tAverage Score: 4.88\n",
      "Episode 22000\tAverage Score: 4.92\n",
      "Episode 23000\tAverage Score: 4.89\n",
      "Episode 24000\tAverage Score: 4.96\n",
      "Episode 25000\tAverage Score: 4.96\n",
      "Episode 26000\tAverage Score: 5.01\n",
      "Episode 27000\tAverage Score: 4.99\n",
      "Episode 28000\tAverage Score: 5.02\n",
      "Episode 29000\tAverage Score: 5.01\n",
      "Episode 30000\tAverage Score: 5.01\n",
      "Episode 31000\tAverage Score: 5.06\n",
      "Episode 32000\tAverage Score: 5.08\n",
      "Episode 33000\tAverage Score: 5.06\n",
      "Episode 34000\tAverage Score: 5.07\n",
      "Episode 35000\tAverage Score: 5.10\n",
      "Episode 36000\tAverage Score: 5.10\n",
      "Episode 37000\tAverage Score: 5.12\n",
      "Episode 38000\tAverage Score: 5.10\n",
      "Episode 39000\tAverage Score: 5.13\n",
      "Episode 40000\tAverage Score: 5.09\n"
     ]
    }
   ],
   "source": [
    "# Hyper-parameters. DONE_THRESHOLD, REG_BETA and MAX_STEP are passed to the\n",
    "# environment below; the others are only referenced by the disabled training call.\n",
    "DONE_THRESHOLD = 0.5\n",
    "REG_BETA = 0.5\n",
    "CON_DIST_BETA = 0\n",
    "ACTION_ENTROPY_BETA = 0.001\n",
    "N_ALLOWED_STD = 1.5\n",
    "MAX_STEP = 8\n",
    "\n",
    "cat_feats_names = categorical\n",
    "# Number of distinct values found in the training data for each categorical column.\n",
    "n_classes = [np.unique(train_data_norm[col]).size for col in cat_feats_names]\n",
    "# Column positions of the features passed to Policy and environment as disallowed.\n",
    "disallowed_feats = list(map(\n",
    "    cf_data_train.columns.get_loc,\n",
    "    ['marital-status', 'race', 'sex', 'native-country'],\n",
    "))\n",
    "directions = {0: 1, 2: 1}  # age and education can't decrease\n",
    "\n",
    "policy = Policy(\n",
    "    n_feats=11,\n",
    "    cat_feats=cat_feats,\n",
    "    n_classes=n_classes,\n",
    "    disallowed_feats=disallowed_feats,\n",
    "    hidden_size=128,\n",
    ").to(device)\n",
    "optimizer = optim.Adam(policy.parameters(), lr=1e-4)\n",
    "\n",
    "# Positional arguments after the classifier: per-column mean and std of the raw\n",
    "# training features, per-column min and max of the normalised training features,\n",
    "# categorical column positions, column names, the SCM and three hyper-parameters.\n",
    "# NOTE(review): the keyword is spelled 'disallowd_feats' in this call - confirm it\n",
    "# matches the signature of reinforcement.environment.\n",
    "env = environment(\n",
    "    clf,\n",
    "    clf_data_train.drop(columns='y').mean().to_numpy(),\n",
    "    clf_data_train.drop(columns='y').std().to_numpy(),\n",
    "    train_data_norm.min().to_numpy(),\n",
    "    train_data_norm.max().to_numpy(),\n",
    "    cat_feats,\n",
    "    list(cf_data_train.columns),\n",
    "    scm,\n",
    "    DONE_THRESHOLD,\n",
    "    REG_BETA,\n",
    "    MAX_STEP,\n",
    "    disallowd_feats=disallowed_feats,\n",
    "    directions=directions,\n",
    "    order_matter=False,\n",
    ")\n",
    "\n",
    "# Training call, currently disabled; the following cell loads a saved checkpoint instead.\n",
    "# NOTE(review): checkpoint_path here ('./ckpt_adult1') differs from the directory\n",
    "# loaded there ('./ckpt_adult') - confirm which one is intended.\n",
    "# scores = reinforce(cf_data_train, env, policy, optimizer, checkpoint_path='./ckpt_adult1',\n",
    "#                    n_episodes=40000, print_every=1000, con_dist=None, n_allowed_std=N_ALLOWED_STD,\n",
    "#                    CON_DIST_BETA=CON_DIST_BETA, ACTION_ENTROPY_BETA=ACTION_ENTROPY_BETA, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.save(policy, './adult_policy_best.pt')\n",
    "# Restore the trained policy object from the saved checkpoint.\n",
    "# map_location remaps tensors that were saved on another device (e.g. a GPU)\n",
    "# onto `device`, so the checkpoint also loads on a CPU-only machine.\n",
    "# NOTE(review): torch.load unpickles arbitrary Python objects - only load\n",
    "# checkpoints from a trusted source. Recent PyTorch releases default to\n",
    "# weights_only=True, which rejects a fully pickled module; pass\n",
    "# weights_only=False there if this checkpoint is trusted.\n",
    "with open('./ckpt_adult/best_model.pt', 'rb') as f:\n",
    "    policy = torch.load(f, map_location=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the loaded policy on the test rows predicted as class 0. The result is\n",
    "# indexed by 'CF' here and by 'reg' in the metrics cell further down.\n",
    "r = evaluate(cf_data_test, env, policy, device=device, print_file='./result.txt')\n",
    "\n",
    "# Tabulate the entries stored under 'CF' using the original feature column names.\n",
    "cfs = pd.DataFrame(data=r['CF'], columns=cf_data_test.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>education</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capitalloss</th>\n",
       "      <th>hoursperweek</th>\n",
       "      <th>native-country</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>59.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>98.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>38.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>26.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>4356.0</td>\n",
       "      <td>45.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>23.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>4356.0</td>\n",
       "      <td>30.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>41.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>99.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3623</th>\n",
       "      <td>42.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>99.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3624</th>\n",
       "      <td>90.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>4356.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3625</th>\n",
       "      <td>56.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>66.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3626</th>\n",
       "      <td>56.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>99.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3627</th>\n",
       "      <td>33.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>99.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>3628 rows × 11 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       age  workclass  education  marital-status  occupation  relationship  \\\n",
       "0     59.0        1.0       15.0             1.0         1.0           1.0   \n",
       "1     38.0        2.0       15.0             2.0         4.0           0.0   \n",
       "2     26.0        5.0       15.0             0.0         1.0           2.0   \n",
       "3     23.0        2.0       15.0             0.0         4.0           5.0   \n",
       "4     41.0        2.0       15.0             2.0         9.0           0.0   \n",
       "...    ...        ...        ...             ...         ...           ...   \n",
       "3623  42.0        1.0       15.0             2.0         9.0           4.0   \n",
       "3624  90.0        1.0       15.0             0.0         6.0           1.0   \n",
       "3625  56.0        2.0       15.0             1.0        10.0           1.0   \n",
       "3626  56.0        3.0       15.0             2.0         1.0           0.0   \n",
       "3627  33.0        6.0       15.0             0.0         3.0           5.0   \n",
       "\n",
       "      race  sex  capitalloss  hoursperweek  native-country  \n",
       "0      0.0  0.0          0.0          98.0             0.0  \n",
       "1      0.0  1.0          0.0          40.0             0.0  \n",
       "2      0.0  1.0       4356.0          45.0             0.0  \n",
       "3      2.0  0.0       4356.0          30.0             0.0  \n",
       "4      0.0  0.0          0.0          99.0             0.0  \n",
       "...    ...  ...          ...           ...             ...  \n",
       "3623   0.0  1.0          0.0          99.0             0.0  \n",
       "3624   0.0  1.0       4356.0          40.0             0.0  \n",
       "3625   0.0  0.0          0.0          66.0             0.0  \n",
       "3626   0.0  0.0          0.0          99.0             0.0  \n",
       "3627   0.0  0.0          0.0          99.0             0.0  \n",
       "\n",
       "[3628 rows x 11 columns]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cfs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row positions of the counterfactuals for which the classifier predicts a\n",
    "# non-zero (truthy) label.\n",
    "mask = [row for row, label in enumerate(clf.predict(cfs)) if label]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "validity: 0.9944873208379272\n",
      "sparsity_num:  0.3668930155210643\n",
      "sparsity_cat:  0.8535793474817865\n",
      "L1:  1.8719410238516443\n",
      "sparsity_num(causal perspective):  0.5922949002217295\n",
      "sparsity_cat(causal perspective):  0.8535793474817865\n",
      "L1(causal perspective):  [1.77154609]\n"
     ]
    }
   ],
   "source": [
    "# Validity: mean predicted label over all generated counterfactuals.\n",
    "print('validity:', np.mean(model.predict(cfs)))\n",
    "\n",
    "# Distance metrics between counterfactuals and originals, restricted to the rows in `mask`.\n",
    "for metric_label, metric_fn, metric_kwargs in (\n",
    "    ('sparsity_num: ', sparsity_num, {}),\n",
    "    ('sparsity_cat: ', sparsity_cat, {}),\n",
    "    ('L1: ', L1, {'stds': clf_data_train.std()}),\n",
    "):\n",
    "    print(metric_label, metric_fn(cfs.iloc[mask], cf_data_test.iloc[mask], cat_feats=categorical, **metric_kwargs))\n",
    "\n",
    "# The same quantities computed from r['reg'], again restricted to the rows in `mask`.\n",
    "reg_valid = np.array(r['reg'])[mask]\n",
    "num_idx = [0, 2, 8, 9]  # column positions of age, education, capitalloss, hoursperweek\n",
    "print('sparsity_num(causal perspective): ', np.mean(reg_valid[:, num_idx] == 0))\n",
    "print('sparsity_cat(causal perspective): ', np.mean(reg_valid[:, cat_feats] == 0))\n",
    "print('L1(causal perspective): ', np.mean(np.abs(reg_valid[:, num_idx])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9986141906873615\n"
     ]
    }
   ],
   "source": [
    "# constraints\n",
    "rule1 = (cfs['age'].iloc[mask].to_numpy() - cf_data_test['age'].iloc[mask].to_numpy()) >= 0\n",
    "rule2 = (cfs['education'].iloc[mask].to_numpy() - cf_data_test['education'].iloc[mask].to_numpy()) >= 0\n",
    "rule3 = ((cfs['education'].iloc[mask].to_numpy() - cf_data_test['education'].iloc[mask].to_numpy()) == 0) + (cfs['age'].iloc[mask].to_numpy() > cf_data_test['age'].iloc[mask].to_numpy())*(cfs['education'].iloc[mask].to_numpy() > cf_data_test['education'].iloc[mask].to_numpy())\n",
    "print(np.mean(rule1*rule2*rule3>0))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.7 ('base': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "13f752eec1c22bf877ed13574aae0c5036489a901cc325afaf3d1c548bba661a"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
