{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "from scipy import stats\n",
    "\n",
    "import torch.optim as optim\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from ipywidgets import widgets\n",
    "from IPython.display import display,clear_output\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset import CompasDataset\n",
    "from experiments import Benchmarking\n",
    "from utils.logger_config import setup_logger\n",
    "from tqdm import tqdm\n",
    "from models.wrapper import PYTORCH_MODELS\n",
    "\n",
    "logger = setup_logger()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset import dataset_loader\n",
    "from experiments.counterfactual import *\n",
    "from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier\n",
    "from lightgbm import LGBMClassifier\n",
    "from xgboost import XGBClassifier\n",
    "from sklearn.svm import SVC\n",
    "from models import PyTorchDNN, PyTorchLinearSVM, PyTorchRBFNet,PyTorchLogisticRegression,PyTorchLinearSVM\n",
    "from sklearn.metrics import accuracy_score, classification_report\n",
    "from sklearn.gaussian_process import GaussianProcessClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the 'compas' data through the shared loader, then pass it to the project dataset class.\n",
    "name = 'compas'\n",
    "dataset_ares = dataset_loader(\n",
    "    name,\n",
    "    data_path='data/',\n",
    "    dropped_features=[],\n",
    "    n_bins=None,\n",
    ")\n",
    "dataset = CompasDataset(dataset_ares=dataset_ares)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration: seed NumPy and PyTorch before any model is constructed or trained.\n",
    "# input_dim is the dataframe's column count minus one (presumably the target column; confirm).\n",
    "input_dim = dataset.get_dataframe().shape[1] - 1\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "Avalues_method = 'max'\n",
    "\n",
    "# Each enabled name is resolved below to a function called compute_<name>_counterfactuals.\n",
    "counterfactual_algorithms = [\n",
    "    # 'DiCE',\n",
    "    'DisCount',\n",
    "    # 'GlobeCE',\n",
    "    # 'AReS',\n",
    "    # 'KNN',\n",
    "]\n",
    "\n",
    "experiment = Benchmarking(\n",
    "    dataset=dataset,\n",
    "    models=[\n",
    "        # (BaggingClassifier(),'sklearn'),\n",
    "        # (GaussianProcessClassifier(),'sklearn'),\n",
    "        # (XGBClassifier(), 'sklearn'),\n",
    "        # (LGBMClassifier(),'sklearn'),\n",
    "        # (PyTorchLogisticRegression(input_dim=input_dim), 'PYT'),\n",
    "        # (PyTorchDNN(input_dim=input_dim), 'PYT'),\n",
    "        (PyTorchRBFNet(input_dim=input_dim, hidden_dim=input_dim), 'PYT'),\n",
    "        # (PyTorchLinearSVM(input_dim=input_dim), 'PYT'),\n",
    "        # (RandomForestClassifier(), 'sklearn'),\n",
    "        # (GradientBoostingClassifier(), 'sklearn'),\n",
    "        # (AdaBoostClassifier(), 'sklearn'),\n",
    "    ],\n",
    "    shapley_methods=[\n",
    "        \"Train_Distri\",\n",
    "        \"Train_OTMatch\",\n",
    "        \"CF_UniformMatch\",\n",
    "        \"CF_RandomMatch\",\n",
    "        \"CF_OTMatch\",\n",
    "    ],\n",
    "    distance_metrics=[\n",
    "        'optimal_transport',\n",
    "        'mean_difference',\n",
    "        'median_difference',\n",
    "        'max_mean_discrepancy',\n",
    "    ],\n",
    "    md_baseline=False,\n",
    ")\n",
    "\n",
    "experiment.train_and_evaluate_models(random_state=seed)\n",
    "experiment.models_performance()\n",
    "\n",
    "\n",
    "logger.info(\"\\n\\n------Compute Counterfactuals------\")\n",
    "sample_num = 50\n",
    "# Nested mapping filled below: model name -> algorithm name -> counterfactual result.\n",
    "model_counterfactuals = {}\n",
    "for model, model_name in zip(experiment.models, experiment.model_names):\n",
    "    model_counterfactuals[model_name] = {}\n",
    "\n",
    "    for algorithm in counterfactual_algorithms:\n",
    "        if algorithm == 'DisCount' and model_name not in PYTORCH_MODELS:\n",
    "            logger.info(f'Skipping {algorithm} for {model_name} due to incompatability')\n",
    "            continue\n",
    "        logger.info(f'Computing {model_name} counterfactuals with {algorithm}')\n",
    "        function_name = f\"compute_{algorithm}_counterfactuals\"\n",
    "        # Resolve the generator by name without a try/except around the call itself:\n",
    "        # a KeyError raised inside the generator must surface as a real error instead of\n",
    "        # being reported as a missing function.\n",
    "        func = globals().get(function_name)\n",
    "        if func is None:\n",
    "            logger.warning(f\"Function {function_name} is not defined.\")\n",
    "            continue\n",
    "        model_counterfactuals[model_name][algorithm] = func(\n",
    "            experiment.X_test,\n",
    "            model=model,\n",
    "            target_name=experiment.dataset.target_name,\n",
    "            sample_num=sample_num,\n",
    "            experiment=experiment,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger.info(\"\\n\\n------Compute Shapley Values------\")\n",
    "# The trailing semicolon keeps the call's return value out of the cell output.\n",
    "experiment.compute_intervention_policies(model_counterfactuals=model_counterfactuals, Avalues_method=Avalues_method);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger.info(\"\\n\\n------Evaluating Distance Performance Under Interventions------\")\n",
    "# Intervention counts swept: 0, 5, 10, ..., 100.\n",
    "experiment.evaluate_distance_performance_under_interventions(\n",
    "    intervention_num_list=range(0, 101, 5),\n",
    "    trials_num=100,\n",
    "    replace=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import pickle \n",
    "# with open(f\"pickles/{dataset.name}_experiment.pickle\", \"rb\") as input_file:\n",
    "#     experiment = pickle.load(input_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from experiments import plotting\n",
    "\n",
    "# Plot the experiment's results; save_to_file=False asks the helper not to write a file.\n",
    "plotting.intervention_vs_distance(\n",
    "    experiment,\n",
    "    save_to_file=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from experiments.latex import TikzPlotGenerator\n",
    "\n",
    "# Pick the distance results of one (model, counterfactual method) pair.\n",
    "model_name = 'PyTorchRBFNet'\n",
    "cf_method = 'DisCount'\n",
    "data = experiment.distance_results[model_name][cf_method]\n",
    "\n",
    "plot_generator = TikzPlotGenerator(data)\n",
    "# Alternative metrics: 'optimal_transport', 'max_mean_discrepancy', 'mean_difference'.\n",
    "print(\n",
    "    plot_generator.generate_plot_code(\n",
    "        metric='median_difference',\n",
    "        methods=[\n",
    "            \"Train_Distri\",\n",
    "            \"Train_OTMatch\",\n",
    "            \"CF_UniformMatch\",\n",
    "            \"CF_RandomMatch\",\n",
    "            \"CF_OTMatch\",\n",
    "        ],\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
