{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "from scipy import stats\n",
    "\n",
    "import torch.optim as optim\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from ipywidgets import widgets\n",
    "from IPython.display import display,clear_output\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset import HelocDataset\n",
    "from experiments import Benchmarking\n",
    "from utils.logger_config import setup_logger\n",
    "from tqdm import tqdm\n",
    "from models.wrapper import PYTORCH_MODELS\n",
    "\n",
    "logger = setup_logger()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset import dataset_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the 'heloc' dataset through the project's dataset_loader helper.\n",
    "name = \"heloc\"\n",
    "dataset_ares = dataset_loader(\n",
    "    name,\n",
    "    data_path=\"data/\",\n",
    "    dropped_features=[],\n",
    "    n_bins=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from experiments.counterfactual import *\n",
    "from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier, BaggingClassifier\n",
    "from lightgbm import LGBMClassifier\n",
    "from xgboost import XGBClassifier\n",
    "from sklearn.svm import SVC\n",
    "from models import PyTorchDNN, PyTorchLinearSVM, PyTorchRBFNet, PyTorchLogisticRegression\n",
    "from sklearn.metrics import accuracy_score, classification_report\n",
    "from sklearn.gaussian_process import GaussianProcessClassifier\n",
    "from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis\n",
    "from sklearn.naive_bayes import GaussianNB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the loaded data in the HelocDataset object that the benchmark consumes.\n",
    "dataset = HelocDataset(\n",
    "    dataset_ares=dataset_ares,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Value forwarded to experiment.compute_intervention_policies in a later cell.\n",
    "Avalues_method = \"avg\"\n",
    "\n",
    "# Number of dataframe columns minus one (presumably the target column - confirm).\n",
    "input_dim = dataset.get_dataframe().shape[1] - 1\n",
    "\n",
    "# Seed NumPy and PyTorch before any model is built or trained.\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# Counterfactual generators to run. Each entry <name> is resolved to a function\n",
    "# called compute_<name>_counterfactuals; uncomment entries to enable them.\n",
    "counterfactual_algorithms = [\n",
    "    # \"DiCE\",\n",
    "    # \"DisCount\",\n",
    "    # \"GlobeCE\",\n",
    "    \"AReS\",\n",
    "    # \"KNN\",\n",
    "]\n",
    "\n",
    "experiment = Benchmarking(\n",
    "    dataset=dataset,\n",
    "    # (estimator, backend tag) pairs; uncomment entries to benchmark more models.\n",
    "    models=[\n",
    "        # (QuadraticDiscriminantAnalysis(), \"sklearn\"),\n",
    "        # (GaussianNB(), \"sklearn\"),\n",
    "        # (BaggingClassifier(), \"sklearn\"),\n",
    "        # (GaussianProcessClassifier(), \"sklearn\"),\n",
    "        # (PyTorchLogisticRegression(input_dim=input_dim), \"PYT\"),\n",
    "        # (PyTorchDNN(input_dim=input_dim), \"PYT\"),\n",
    "        # (PyTorchRBFNet(input_dim=input_dim, hidden_dim=input_dim), \"PYT\"),\n",
    "        # (PyTorchLinearSVM(input_dim=input_dim), \"PYT\"),\n",
    "        (RandomForestClassifier(), \"sklearn\"),\n",
    "        # (GradientBoostingClassifier(), \"sklearn\"),\n",
    "        # (AdaBoostClassifier(), \"sklearn\"),\n",
    "    ],\n",
    "    shapley_methods=[\n",
    "        \"Train_Distri\",\n",
    "        \"Train_OTMatch\",\n",
    "        \"CF_UniformMatch\",\n",
    "        \"CF_RandomMatch\",\n",
    "        \"CF_OTMatch\",\n",
    "    ],\n",
    "    distance_metrics=[\n",
    "        \"optimal_transport\",\n",
    "        \"mean_difference\",\n",
    "        \"median_difference\",\n",
    "        \"max_mean_discrepancy\",\n",
    "    ],\n",
    "    md_baseline=False,\n",
    ")\n",
    "\n",
    "# Fit the configured models with the shared seed, then report their performance.\n",
    "experiment.train_and_evaluate_models(random_state=seed)\n",
    "experiment.models_performance()\n",
    "\n",
    "logger.info(\"\\n\\n------Compute Counterfactuals------\")\n",
    "# Number of samples requested from each counterfactual routine.\n",
    "sample_num = 1000\n",
    "# Results are stored as model_counterfactuals[model_name][algorithm].\n",
    "model_counterfactuals = {}\n",
    "for model, model_name in zip(experiment.models, experiment.model_names):\n",
    "    model_counterfactuals[model_name] = {}\n",
    "\n",
    "    for algorithm in counterfactual_algorithms:\n",
    "        # DisCount is only run for models listed in PYTORCH_MODELS.\n",
    "        if algorithm == 'DisCount' and model_name not in PYTORCH_MODELS:\n",
    "            logger.info(f'Skipping {algorithm} for {model_name} due to incompatability')\n",
    "            continue\n",
    "        logger.info(f'Computing {model_name} counterfactuals with {algorithm}')\n",
    "\n",
    "        # Resolve the routine by name in the notebook namespace *before* calling it.\n",
    "        # Previously the call itself sat inside `try/except KeyError`, so a KeyError\n",
    "        # raised inside a counterfactual routine was swallowed and misreported as a\n",
    "        # missing function. Now only a genuinely missing function is skipped and\n",
    "        # errors from the routine propagate with their real traceback.\n",
    "        function_name = f\"compute_{algorithm}_counterfactuals\"\n",
    "        func = globals().get(function_name)\n",
    "        if func is None:\n",
    "            print(f\"Function {function_name} is not defined.\")\n",
    "            continue\n",
    "\n",
    "        model_counterfactuals[model_name][algorithm] = func(\n",
    "            experiment.X_test,\n",
    "            model=model,\n",
    "            target_name=experiment.dataset.target_name,\n",
    "            sample_num=sample_num,\n",
    "            experiment=experiment,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ignore DeprecationWarning from this point on.\n",
    "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n",
    "\n",
    "logger.info(\"\\n\\n------Compute Shapley Values------\")\n",
    "# The trailing semicolon keeps the notebook from echoing the call's return value.\n",
    "experiment.compute_intervention_policies(model_counterfactuals=model_counterfactuals, Avalues_method=Avalues_method);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger.info(\"\\n\\n------Evaluating Distance Performance Under Interventions------\")\n",
    "# Intervention counts swept: 0, 20, 40, ..., 800; trials_num is set to 100.\n",
    "experiment.evaluate_distance_performance_under_interventions(\n",
    "    replace=False,\n",
    "    trials_num=100,\n",
    "    intervention_num_list=range(0, 801, 20),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import pickle \n",
    "# with open(f\"pickles/{dataset.name}_experiment_xgboost.pickle\", \"rb\") as input_file:\n",
    "#     experiment = pickle.load(input_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from experiments import plotting\n",
    "\n",
    "# Draw the intervention-vs-distance figure for the experiment; save_to_file is off.\n",
    "plotting.intervention_vs_distance(\n",
    "    experiment,\n",
    "    save_to_file=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import pickle\n",
    "# with open(f\"pickles/{dataset.name}_distance_results.pickle\", \"wb\") as output_file:\n",
    "#     pickle.dump(experiment.distance_results, output_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from experiments.latex import TikzPlotGenerator\n",
    "\n",
    "# Export results for a model / counterfactual method that this run actually produced.\n",
    "# The earlier hard-coded 'QuadraticDiscriminantAnalysis' / 'GlobeCE' pair is commented\n",
    "# out in the benchmarking cell, so looking it up failed on a fresh Restart & Run All.\n",
    "# Assumes distance_results is keyed by the entries of experiment.model_names and by\n",
    "# the counterfactual algorithm names, as the original lookup implies - confirm.\n",
    "model_name = experiment.model_names[0]\n",
    "cf_method = counterfactual_algorithms[0]\n",
    "data = experiment.distance_results[model_name][cf_method]\n",
    "\n",
    "plot_generator = TikzPlotGenerator(data)\n",
    "# print() is intentional: the generated plot code is meant to be copied as plain text.\n",
    "print(plot_generator.generate_plot_code(\n",
    "    # metric='optimal_transport',\n",
    "    # metric='max_mean_discrepancy',\n",
    "    # metric='mean_difference',\n",
    "    metric='median_difference',\n",
    "    methods=[\n",
    "        \"Train_Distri\",\n",
    "        \"Train_OTMatch\",\n",
    "        \"CF_UniformMatch\",\n",
    "        \"CF_RandomMatch\",\n",
    "        \"CF_OTMatch\",\n",
    "    ]\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "from dataset import GermanCreditDataset\n",
    "\n",
    "# NOTE: this cell rebinds `dataset` and `experiment`, replacing the objects built\n",
    "# in the earlier cells of this notebook.\n",
    "dataset = GermanCreditDataset()\n",
    "\n",
    "# NOTE(security): pickle.load can execute arbitrary code stored in the file;\n",
    "# only load pickles that were produced by a trusted run.\n",
    "with open(f\"pickles/{dataset.name}_experiment_optimality.pickle\", \"rb\") as input_file:\n",
    "    experiment = pickle.load(input_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the 'optimality' entry stored for LGBMClassifier with DiCE in the loaded experiment.\n",
    "experiment.distance_results[\"LGBMClassifier\"][\"DiCE\"][\"optimality\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
