{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee453d40",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "sys.path.append(os.path.abspath(\"..\"))\n",
    "\n",
    "from datasets.datasets import SyntheticFACE, SyntheticFACE2, SyntheticMoons, Dataset, CaliforniaHousing, GermanCredit, GermanCreditv2, GiveMeSomeCredit, AdultIncome\n",
    "from visualisation import *\n",
    "\n",
    "from models.mlp_pytorch import PyTorchMLP\n",
    "\n",
    "from conformal.split_conformal import SplitConformalPrediction\n",
    "from conformal.localised_conformal_baselcp import BaseLCP\n",
    "from conformal.localised_conformal_tree import ConformalCONFEXTree\n",
    "\n",
    "from counterfactual_explanations.counterfactual_benchmarker import *\n",
    "from counterfactual_explanations.gradient_based.auxillary_models import *\n",
    "from counterfactual_explanations.gradient_based.cf_gradient_based import *\n",
    "from counterfactual_explanations.gradient_based.losses import *\n",
    "\n",
    "from counterfactual_explanations.milp_based.cf_conformal import *\n",
    "from counterfactual_explanations.milp_based.cf_mindist import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1292775",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-import edited project modules (added to sys.path from \"..\" above) automatically,\n",
    "# so code changes are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00a647ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration: dataset class, split sizes, and the point / target\n",
    "# label that the counterfactual generators below are asked to explain.\n",
    "# Bound to `dataset_cls` instead of `Dataset` so the `Dataset` class imported from\n",
    "# datasets.datasets is not shadowed in the notebook namespace.\n",
    "dataset_cls = SyntheticFACE2\n",
    "# Positional arguments to the dataset constructor; presumably the train /\n",
    "# calibration / test fractions -- confirm against the dataset's __init__.\n",
    "SPLIT_FRACTIONS = (0.6, 0.2, 0.2)\n",
    "# TODO(review): no random seed is set here; if the split or model training is\n",
    "# stochastic, seed NumPy / torch in this cell so Restart & Run All is reproducible.\n",
    "\n",
    "factual = np.array([1, 2.5])  # input point passed to generate_counterfactual\n",
    "y_target = 1  # target label passed to generate_counterfactual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "134cef74",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the train and test splits are plotted; X_calib / y_calib are used for\n",
    "# calibration and plotting further down.\n",
    "dataset = dataset_cls(*SPLIT_FRACTIONS)\n",
    "X_train, y_train, X_calib, y_calib, X_test, y_test = dataset.get_X_y_split()\n",
    "plot_split_dataset(X_train, X_test, y_train, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a86e8297",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The \"facedata2\" path is relative to the notebook's working directory;\n",
    "# presumably load_or_train reuses saved weights there or trains on X_train / y_train.\n",
    "# NOTE(review): the trailing positional flag's meaning isn't visible here -- prefer a keyword.\n",
    "model = PyTorchMLP(\n",
    "    config={},\n",
    "    input_properties=dataset.input_properties,\n",
    ")\n",
    "model.load_or_train(\n",
    "    Path(\"facedata2\"),\n",
    "    X_train,\n",
    "    y_train,\n",
    "    True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b99e94f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The calibration split is passed in both X slots and both y slots (presumably the\n",
    "# train/test positions, mirroring plot_split_dataset) -- confirm the intended signature.\n",
    "plot_decision_boundary(\n",
    "    model,\n",
    "    X_calib, X_calib,\n",
    "    y_calib, y_calib,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b10fc3b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimum-distance counterfactual.\n",
    "# NOTE(review): db_distance is presumably a margin from the decision boundary -- confirm.\n",
    "cf_gen = MinDistanceCF(\n",
    "    model,\n",
    "    dataset.input_properties,\n",
    "    config={\"db_distance\": 0.05},\n",
    ")\n",
    "counterfactual = cf_gen.generate_counterfactual(factual, y_target)\n",
    "plot_counterfactual(\n",
    "    factual, counterfactual, model,\n",
    "    X_calib, X_calib, y_calib, y_calib,\n",
    "    faded_background=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9589805",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Conformal counterfactual using split-conformal prediction (SplitConformalPrediction, alpha 0.03).\n",
    "conformalCF = ConformalCF(\n",
    "    model,\n",
    "    dataset.input_properties,\n",
    "    config={\n",
    "        \"conformal_class\": SplitConformalPrediction,\n",
    "        \"conformal_config\": {\"alpha\": 0.03},\n",
    "    },\n",
    ")\n",
    "conformalCF.setup(X_train, y_train, X_calib, y_calib)\n",
    "counterfactual_conformal = conformalCF.generate_counterfactual(factual, y_target)\n",
    "plot_counterfactual(\n",
    "    factual, counterfactual_conformal, model,\n",
    "    X_calib, X_calib, y_calib, y_calib,\n",
    "    conformal=conformalCF.conformal,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20a1b6f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for the BaseLCP localised conformal predictor; reused by conformalCF_b.\n",
    "ccf = {\n",
    "    \"alpha\": 0.03,\n",
    "    \"scorefn_name\": \"linear\",\n",
    "    \"kernel_name\": \"box_l1\",\n",
    "    \"kernel_bandwidth\": 0.35,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8b814f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calibrate the BaseLCP localised conformal predictor on the calibration split and\n",
    "# plot its predictions. Given its own name so this object is not confused with the\n",
    "# ConformalCONFEXTree predictor calibrated in a later cell.\n",
    "conformal_baselcp = BaseLCP(model, dataset.input_properties, ccf)\n",
    "conformal_baselcp.calibrate(X_calib, y_calib, None)\n",
    "plot_conformal_prediction(model, conformal_baselcp, X_calib, X_calib, y_calib, y_calib)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44cec580",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Conformal counterfactual backed by the BaseLCP localised predictor configured by ccf.\n",
    "conformalCF_b = ConformalCF(\n",
    "    model,\n",
    "    dataset.input_properties,\n",
    "    config={\n",
    "        \"conformal_class\": BaseLCP,\n",
    "        \"conformal_config\": ccf,\n",
    "    },\n",
    ")\n",
    "conformalCF_b.setup(X_train, y_train, X_calib, y_calib)\n",
    "counterfactual_conformal_b = conformalCF_b.generate_counterfactual(factual, y_target)\n",
    "plot_counterfactual(\n",
    "    factual, counterfactual_conformal_b, model,\n",
    "    X_calib, X_calib, y_calib, y_calib,\n",
    "    conformal=conformalCF_b.conformal,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07c59fd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for the tree-based localised conformal predictor (ConformalCONFEXTree);\n",
    "# ccf2 is reused by conformalCF_c further down.\n",
    "ccf2 = dict(\n",
    "    alpha=0.05,\n",
    "    scorefn_name=\"linear2\",\n",
    "    kernel_name=\"box_l1\",\n",
    "    kernel_bandwidth=0.3,\n",
    "    inf_quantile=False,\n",
    "    global_quantile=False,\n",
    "    split_midpoint=True,\n",
    ")\n",
    "\n",
    "conformal_cfxt = ConformalCONFEXTree(model, dataset.input_properties, ccf2)\n",
    "conformal_cfxt.calibrate(X_calib, y_calib, None)\n",
    "plot_conformal_prediction(\n",
    "    model, conformal_cfxt,\n",
    "    X_calib, X_calib, y_calib, y_calib,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5ea4656",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wachter counterfactual for comparison with the conformal variants.\n",
    "# NOTE(review): \"mad\": True presumably enables MAD-scaled distances -- confirm.\n",
    "wachter = WachterGenerator(\n",
    "    model,\n",
    "    dataset.input_properties,\n",
    "    config={\"mad\": True},\n",
    ")\n",
    "wachter.setup(X_train, y_train, X_calib, y_calib)\n",
    "cfwachter = wachter.generate_counterfactual(factual, y_target)\n",
    "plot_counterfactual(\n",
    "    factual, cfwachter, model,\n",
    "    X_calib, X_calib, y_calib, y_calib,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0e05f2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Conformal counterfactual backed by the tree-based localised predictor configured by ccf2.\n",
    "conformalCF_c = ConformalCF(\n",
    "    model,\n",
    "    dataset.input_properties,\n",
    "    config={\n",
    "        \"conformal_class\": ConformalCONFEXTree,\n",
    "        \"conformal_config\": ccf2,\n",
    "    },\n",
    ")\n",
    "conformalCF_c.setup(X_train, y_train, X_calib, y_calib)\n",
    "counterfactual_conformal_c = conformalCF_c.generate_counterfactual(factual, y_target)\n",
    "plot_counterfactual(\n",
    "    factual, counterfactual_conformal_c, model,\n",
    "    X_calib, X_calib, y_calib, y_calib,\n",
    "    conformal=conformalCF_c.conformal,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "confexplus",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
