{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "from scipy import stats\n",
    "\n",
    "import torch.optim as optim\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from ipywidgets import widgets\n",
    "from IPython.display import display,clear_output\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset import GermanCreditDataset\n",
    "from experiments import Benchmarking\n",
    "from utils.logger_config import setup_logger\n",
    "from tqdm import tqdm\n",
    "from models.wrapper import PYTORCH_MODELS\n",
    "\n",
    "logger = setup_logger()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from experiments.counterfactual import *\n",
    "from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier, BaggingClassifier\n",
    "from lightgbm import LGBMClassifier\n",
    "from xgboost import XGBClassifier\n",
    "from sklearn.svm import SVC\n",
    "from models import PyTorchDNN, PyTorchLinearSVM, PyTorchRBFNet, PyTorchLogisticRegression\n",
    "from sklearn.metrics import accuracy_score, classification_report\n",
    "from sklearn.gaussian_process import GaussianProcessClassifier\n",
    "\n",
    "dataset = GermanCreditDataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of model inputs: every dataframe column except one\n",
    "# (presumably the target column -- TODO confirm against GermanCreditDataset).\n",
    "input_dim = dataset.get_dataframe().shape[1] - 1\n",
    "\n",
    "# Seed NumPy's global RNG as well as torch's, so stochastic steps that draw from\n",
    "# either generator repeat across a fresh Restart & Run All.\n",
    "seed = 0\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "Avalues_method = 'max'\n",
    "\n",
    "# Each enabled name is resolved below to a function called compute_<name>_counterfactuals.\n",
    "counterfactual_algorithms = [\n",
    "    'DiCE',\n",
    "    # 'DisCount',\n",
    "    # 'KNN',\n",
    "]\n",
    "\n",
    "# Uncomment entries to widen the benchmark; each model is paired with its backend tag.\n",
    "experiment = Benchmarking(\n",
    "    dataset=dataset,\n",
    "    models=[\n",
    "        # (BaggingClassifier(),'sklearn'),\n",
    "        # (GaussianProcessClassifier(),'sklearn'),\n",
    "        # (XGBClassifier(), 'sklearn'),\n",
    "        (LGBMClassifier(),'sklearn'),\n",
    "        # (PyTorchLogisticRegression(input_dim=input_dim), 'PYT'),\n",
    "        # (PyTorchDNN(input_dim=input_dim), 'PYT'),\n",
    "        # (PyTorchRBFNet(input_dim=input_dim, hidden_dim=input_dim), 'PYT'),\n",
    "        # (PyTorchLinearSVM(input_dim=input_dim), 'PYT'),\n",
    "        # (RandomForestClassifier(), 'sklearn'),\n",
    "        # (GradientBoostingClassifier(), 'sklearn'),\n",
    "        # (AdaBoostClassifier(), 'sklearn'),\n",
    "    ],\n",
    "    shapley_methods=[\n",
    "            # \"Train_Distri\",\n",
    "            # \"Train_OTMatch\",\n",
    "            # \"CF_UniformMatch\",\n",
    "            # \"CF_RandomMatch\",\n",
    "            \"CF_OTMatch\",\n",
    "            \"CF_ExactMatch\",\n",
    "    ],\n",
    "    distance_metrics=[\n",
    "        'optimal_transport',\n",
    "        'mean_difference',\n",
    "        'median_difference',\n",
    "        'max_mean_discrepancy',\n",
    "    ],\n",
    "    md_baseline=False,\n",
    ")\n",
    "\n",
    "experiment.train_and_evaluate_models(random_state=seed)\n",
    "experiment.models_performance()\n",
    "\n",
    "logger.info(\"\\n\\n------Compute Counterfactuals------\")\n",
    "sample_num = 50  # forwarded as `sample_num` to every counterfactual generator\n",
    "model_counterfactuals = {}  # model name -> {algorithm name -> generator result}\n",
    "for model, model_name in zip(experiment.models, experiment.model_names):\n",
    "    model_counterfactuals[model_name] = {}\n",
    "\n",
    "    for algorithm in counterfactual_algorithms:\n",
    "        if algorithm == 'DisCount' and model_name not in PYTORCH_MODELS:\n",
    "            logger.info(f'Skipping {algorithm} for {model_name} due to incompatability')\n",
    "            continue\n",
    "        logger.info(f'Computing {model_name} counterfactuals with {algorithm}')\n",
    "\n",
    "        # Resolve the generator by name before calling it, so that only a missing\n",
    "        # function is skipped; a KeyError raised inside the generator itself now\n",
    "        # propagates instead of being misreported as a missing function.\n",
    "        function_name = f\"compute_{algorithm}_counterfactuals\"\n",
    "        func = globals().get(function_name)\n",
    "        if func is None:\n",
    "            logger.info(f\"Function {function_name} is not defined; skipping {algorithm}.\")\n",
    "            continue\n",
    "\n",
    "        model_counterfactuals[model_name][algorithm] = func(\n",
    "            experiment.X_test,\n",
    "            model=model,\n",
    "            target_name=experiment.dataset.target_name,\n",
    "            sample_num=sample_num,\n",
    "            experiment=experiment,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the per-model, per-algorithm counterfactuals to the experiment's policy step.\n",
    "# The trailing semicolon keeps the call's return value out of the cell output.\n",
    "logger.info(\"\\n\\n------Compute Action Policies------\")\n",
    "experiment.compute_intervention_policies(model_counterfactuals=model_counterfactuals, Avalues_method=Avalues_method);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "# Take an independent copy of the experiment's models, then pick one out by name.\n",
    "model_name = 'LGBMClassifier'\n",
    "trained_models = deepcopy(experiment.models)\n",
    "\n",
    "# Position of the chosen model, looked up through the experiment's list of model names.\n",
    "model_index = experiment.model_names.index(model_name)\n",
    "model = trained_models[model_index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from models import ClassifierWrapper\n",
    "\n",
    "# Narrow the experiment to the single selected model and switch md_baseline on.\n",
    "experiment.md_baseline = True\n",
    "experiment.models = [ClassifierWrapper(classifier=model, backend=model.backend)]\n",
    "experiment.shapley_methods = [\"CF_OTMatch\", \"CF_ExactMatch\"]\n",
    "# 'optimal_transport' is deliberately left out of the metrics for this run.\n",
    "experiment.distance_metrics = ['mean_difference']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep the intervention counts 0, 5, ..., 200 with trials_num=100 and replace=False.\n",
    "logger.info(\"\\n\\n------Evaluating Distance Performance Under Interventions------\")\n",
    "experiment.evaluate_distance_performance_under_interventions(intervention_num_list=range(0, 201, 5), trials_num=100, replace=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import pickle \n",
    "# with open(f\"pickles/{dataset.name}_experiment_optimality.pickle\", \"rb\") as input_file:\n",
    "#     experiment = pickle.load(input_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from experiments import plotting\n",
    "\n",
    "# Draw the intervention-vs-distance figure for the current experiment state.\n",
    "# save_to_file=True presumably also writes it to disk -- confirm in experiments.plotting.\n",
    "plotting.intervention_vs_distance(\n",
    "    experiment,\n",
    "    save_to_file=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model / counterfactual algorithm analysed by the cells below. Both must have been\n",
    "# computed earlier in this notebook: LGBMClassifier is the only model enabled in the\n",
    "# Benchmarking model list, so selecting any other name leaves the later lookups in\n",
    "# experiment.model_counterfactuals / model_names / distance_results without an entry.\n",
    "model_name = 'LGBMClassifier'\n",
    "cf_method = 'DiCE'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the factual rows ('X_factual') and their counterfactuals ('X') stored for the\n",
    "# selected model / algorithm pair.\n",
    "X_factual, X_counterfactual = (\n",
    "    experiment.model_counterfactuals[model_name][cf_method][key]\n",
    "    for key in ('X_factual', 'X')\n",
    ")\n",
    "\n",
    "# NOTE(review): experiment.models is indexed with a position taken from\n",
    "# experiment.model_names; this assumes both lists still line up -- confirm, since\n",
    "# experiment.models is reassigned earlier in this notebook.\n",
    "model_index = experiment.model_names.index(model_name)\n",
    "model = experiment.models[model_index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from experiments.baseline import OptimalMeanDifference\n",
    "\n",
    "# Set up the baseline problem from the factual/counterfactual pair and the chosen model.\n",
    "problem = OptimalMeanDifference(\n",
    "    X_factual=X_factual,\n",
    "    X_counterfactual=X_counterfactual,\n",
    "    model=model,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "intervention_num_list = range(0, 201, 5)\n",
    "\n",
    "# Iterate over intervention_num_list itself (rather than a second hard-coded range)\n",
    "# so opt_list always holds exactly one entry per listed intervention count.\n",
    "opt_list = []\n",
    "for intervention_num in intervention_num_list:\n",
    "    # Keep solving while there is no previous result or the previous eta is still\n",
    "    # above 1e-4; after that every remaining entry is recorded as 0 without calling\n",
    "    # the solver again.\n",
    "    if not opt_list or opt_list[-1] > 1e-4:\n",
    "        eta = problem.solve_problem(C=intervention_num)[\"eta\"]\n",
    "    else:\n",
    "        eta = 0\n",
    "    opt_list.append(eta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from experiments.latex import TikzPlotGenerator\n",
    "\n",
    "# Distance results recorded for the selected model / counterfactual algorithm.\n",
    "data = experiment.distance_results[model_name][cf_method]\n",
    "plot_generator = TikzPlotGenerator(data)\n",
    "\n",
    "# Metrics not plotted here: 'optimal_transport', 'max_mean_discrepancy', 'median_difference'.\n",
    "# Methods not plotted here: \"Train_Distri\", \"Train_OTMatch\", \"CF_UniformMatch\", \"CF_RandomMatch\".\n",
    "# opt_list goes in through corrected_y_lists (presumably the y-values for the\n",
    "# \"optimality\" series -- confirm in TikzPlotGenerator).\n",
    "print(\n",
    "    plot_generator.generate_plot_code(\n",
    "        metric='mean_difference',\n",
    "        methods=[\"CF_OTMatch\", \"CF_ExactMatch\", \"optimality\"],\n",
    "        corrected_y_lists=[opt_list],\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for e in list(zip(intervention_num_list, opt_list)):\n",
    "#     print(e)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
