{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import copy\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "torch.manual_seed(0)\n",
    "\n",
    "import networkx as nx\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.neighbors import KernelDensity\n",
    "\n",
    "from ConDist import CsvDataset, ConDist_MO, Condist_for_bn\n",
    "from reinforcement import environment, Policy, reinforce, evaluate\n",
    "from metrics import causal_edge_score, causal_cfe_from_parent, sparsity_num, sparsity_cat, L1, cess_for_bn\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "with open('./bn_data/scm.pickle', 'rb') as f:\n",
    "    SCM = pickle.load(f)\n",
    "\n",
    "def func_gen(w):\n",
    "    def func(*args):\n",
    "        return np.array(w).dot(np.array([1]+list(args)))\n",
    "    return func\n",
    "\n",
    "for k in SCM.keys():\n",
    "    SCM[k]['func']  = func_gen(w=SCM[k]['weight'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.compose import ColumnTransformer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf_data_train = pd.read_csv('./bn_data/sangiovese_data_train.csv')\n",
    "clf_data_test = pd.read_csv('./bn_data/sangiovese_data_test.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "target = clf_data_train[\"y\"]\n",
    "# Split data into train and test\n",
    "x_train= clf_data_train.drop(\"y\", axis=1)\n",
    "\n",
    "categorical = []\n",
    "numerical = x_train.columns.difference(categorical)\n",
    "\n",
    "# We create the preprocessing pipelines for both numeric and categorical data.\n",
    "numeric_transformer = Pipeline(steps=[\n",
    "    ('scaler', StandardScaler())])\n",
    "\n",
    "categorical_transformer = Pipeline(steps=[\n",
    "    ('onehot', OneHotEncoder(handle_unknown='error'))])\n",
    "\n",
    "transformations = ColumnTransformer(\n",
    "    transformers=[\n",
    "        ('num', numeric_transformer, numerical),\n",
    "        ('cat', categorical_transformer, categorical)])\n",
    "\n",
    "# Append classifier to preprocessing pipeline.\n",
    "# Now we have a full prediction pipeline.\n",
    "clf = Pipeline(steps=[('preprocessor', transformations),\n",
    "                      ('classifier', MLPClassifier(random_state=0))])\n",
    "model = clf.fit(x_train, clf_data_train['y'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8236819360414867"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(clf.predict(clf_data_test.drop('y', axis=1)) == clf_data_test['y'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_train = clf.predict(clf_data_train.drop('y', axis=1)).astype('int')\n",
    "pred_test = clf.predict(clf_data_test.drop('y', axis=1)).astype('int')\n",
    "cf_data_train = clf_data_train.drop('y', axis=1)[pred_train==0]\n",
    "cf_data_test = clf_data_test.drop('y', axis=1)[pred_test==0]\n",
    "scm = SCM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For this dataset, the conditional distribution is given by the bayesian network.\n",
    "# The conditional distribution should be estimated at normalized scale\n",
    "cat_feats = []\n",
    "train_num_mean = clf_data_train.drop('y', axis=1).mean()\n",
    "train_num_std = clf_data_train.drop('y', axis=1).std()\n",
    "train_data_norm = clf_data_train.drop('y', axis=1)\n",
    "train_data_norm = (train_data_norm - train_num_mean)/ train_num_std\n",
    "dist_nets = {}\n",
    "for out_feat in scm.keys():\n",
    "    in_feats = scm[out_feat]['input']\n",
    "    if in_feats:\n",
    "        dist_nets[out_feat] = {}\n",
    "        dist_nets[out_feat]['input'] = in_feats\n",
    "        dist_nets[out_feat]['net'] = Condist_for_bn(out_feat, scm, train_num_mean, train_num_std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "DONE_THRESHOLD = 0.5\n",
    "REG_BETA = 0.7\n",
    "CON_DIST_BETA = 0.7\n",
    "ACTION_ENTROPY_BETA = 0.0001\n",
    "N_ALLOWED_STD = 0.2\n",
    "MAX_STEP = 13\n",
    "CAT_CAHNGE_PENALTY = 0.3\n",
    "ADD_BIAS = False\n",
    "cat_feats = []\n",
    "cat_feats_names = []\n",
    "n_classes = []\n",
    "disallowed_feats = []\n",
    "directions = {}\n",
    "policy = Policy(n_feats=13, cat_feats=cat_feats, n_classes=n_classes, disallowed_feats=disallowed_feats, hidden_size=128).to(device)\n",
    "optimizer = optim.Adam(policy.parameters(), lr=1e-4)\n",
    "env = environment(clf, clf_data_train.drop('y',axis=1).mean().to_numpy(), clf_data_train.drop('y',axis=1).std().to_numpy(),\\\n",
    "     train_data_norm.min().to_numpy(), train_data_norm.max().to_numpy(), cat_feats, list(cf_data_train.columns), \\\n",
    "     scm, DONE_THRESHOLD, REG_BETA, cat_change_penalty=CAT_CAHNGE_PENALTY, max_step=MAX_STEP, disallowd_feats=disallowed_feats, directions=directions, order_matter=False, add_bias_to_endo=ADD_BIAS)\n",
    "# scores = reinforce(cf_data_train, env, policy, optimizer, checkpoint_path='./ckpt_bn', n_episodes=30000, print_every=1000, con_dist=dist_nets, n_allowed_std=N_ALLOWED_STD, CON_DIST_BETA=CON_DIST_BETA, ACTION_ENTROPY_BETA = ACTION_ENTROPY_BETA, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('./ckpt_bn/model_12000.pt', 'rb') as f:\n",
    "    policy = torch.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "r = evaluate(cf_data_test, env, policy, device=device, print_file='./result.txt')\n",
    "cfs = pd.DataFrame(r['CF'], columns=cf_data_test.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "mask = clf.predict(cfs)\n",
    "mask = [i for i in range(len(mask)) if mask[i]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "cess_cf = cess_for_bn(cf_data_test.iloc[mask].to_numpy(), cfs.iloc[mask].to_numpy(), SCM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "validity: 1.0\n",
      "causal_edge_score(median):  0.2000142843394448\n",
      "causal_edge_score(average):  0.2998812484146612\n",
      "sparsity_num:  0.1332871332871333\n",
      "L1:  1.1318849676888327\n"
     ]
    }
   ],
   "source": [
    "print('validity:', np.mean(model.predict(cfs)))\n",
    "print('causal_edge_score(median): ', np.mean(np.median(cess_cf, axis=1)))\n",
    "print('causal_edge_score(average): ', np.mean(cess_cf))\n",
    "print('sparsity_num: ', sparsity_num(cfs.iloc[mask], cf_data_test.iloc[mask], cat_feats=[]))\n",
    "print('L1: ', L1(cfs.iloc[mask], cf_data_test.iloc[mask], cat_feats=[], stds=clf_data_train.std()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sparsity_num(causal perspective):  0.8044583044583045\n",
      "L1(causal perspective):  0.26513492693685703\n"
     ]
    }
   ],
   "source": [
    "causal_cfs_from_parent = causal_cfe_from_parent(cfs, cf_data_test, scm=SCM, cat_feats=[], add_bias=False)\n",
    "exo_node = [feat for feat in SCM.keys() if SCM[feat]['input'] == []]\n",
    "exo_feats = cf_data_test.columns[exo_node]\n",
    "# Modify the exogenous back to the original for computing L1/ sparsity from causal perspective \n",
    "causal_cfs_from_parent[exo_feats] = cf_data_test[exo_feats].to_numpy()\n",
    "print('sparsity_num(causal perspective): ', sparsity_num(cfs.iloc[mask], causal_cfs_from_parent.iloc[mask], cat_feats=[]))\n",
    "print('L1(causal perspective): ', L1(cfs.iloc[mask], causal_cfs_from_parent.iloc[mask], cat_feats=[], stds=clf_data_train.std()))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.7 ('base': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "13f752eec1c22bf877ed13574aae0c5036489a901cc325afaf3d1c548bba661a"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
