{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1127995",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "574a4a38",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "from datasets.datasets import SyntheticFACE, SyntheticMoons, Dataset, CaliforniaHousing, GermanCreditv2, GiveMeSomeCredit, AdultIncome\n",
    "from visualisation import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99fd57ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "from datasets.datasets import SyntheticFACE, SyntheticMoons, Dataset, CaliforniaHousing, GermanCreditv2, GiveMeSomeCredit, AdultIncome\n",
    "from visualisation import *\n",
    "\n",
    "from models.mlp_pytorch import PyTorchMLP\n",
    "from conformal.split_conformal import SplitConformalPrediction\n",
    "from conformal.localised_conformal_baselcp import BaseLCP\n",
    "from conformal.localised_conformal_tree import ConformalCONFEXTree\n",
    "\n",
    "from counterfactual_explanations.counterfactual_benchmarker import *\n",
    "from counterfactual_explanations.gradient_based.auxillary_models import *\n",
    "from counterfactual_explanations.gradient_based.cf_gradient_based import *\n",
    "from counterfactual_explanations.tree.cf_featuretweak import FeatureTweakGenerator\n",
    "from counterfactual_explanations.tree.cf_focus import FOCUSGenerator\n",
    "from counterfactual_explanations.gradient_based.losses import *\n",
    "\n",
    "from counterfactual_explanations.milp_based.cf_conformal import *\n",
    "from counterfactual_explanations.milp_based.cf_mindist import *\n",
    "from counterfactual_explanations.dim_reduction import *\n",
    "import argparse\n",
    "\n",
    "from datetime import datetime\n",
    "from pathlib import Path\n",
    "\n",
    "print(f\"Start {datetime.now()}\")\n",
    "\n",
    "# --- Experiment configuration (each setting is defined exactly once) ---\n",
    "dataset_cls = CaliforniaHousing\n",
    "# True: RandomForestSKLearn model with FOCUS / FeatureTweak generators; False: PyTorchMLP with Wachter / Schut / ECCCO generators.\n",
    "is_rf = False\n",
    "mlp_config = {\"epochs\": 100, \"batch_size\": 64}\n",
    "# mlp_config = {\"epochs\": 50, \"batch_size\": 256}\n",
    "\n",
    "# NOTE(review): the three fractions are presumably train / calibration / test splits -- confirm against the dataset class.\n",
    "dataset = dataset_cls(0.6, 0.2, 0.2)\n",
    "model_factories = []\n",
    "\n",
    "if is_rf:\n",
    "    factory = ModelFactory(RandomForestSKLearn, dataset.input_properties, config={}, config_multi={})\n",
    "    model_factories.append(factory)\n",
    "else:\n",
    "    factory = ModelFactory(PyTorchMLP, dataset.input_properties, config=mlp_config, config_multi={})\n",
    "    model_factories.append(factory)\n",
    "\n",
    "n_factuals_main = 100\n",
    "n_repeats = 2\n",
    "path = Path(\"results_path\")\n",
    "# Passed to CFBenchmarker; also used as get_counterfactuals(reset=not use_pretrained) below.\n",
    "use_pretrained = True\n",
    "\n",
    "metrics = [\n",
    "    FailuresMetric(), \n",
    "    DistanceMetric(), \n",
    "    ValidityMetric(),  \n",
    "    ImplausibilityMetric(included_prop=0.1), \n",
    "    ImplausibilityMetric(included_prop=1),\n",
    "    LOFMetric(n_neighbours=20, stratified=True),\n",
    "    SensitivityMetric(n_sensitivity=25, n_neighbours=4, budget=0.001), \n",
    "    StabilityMetric(n_neighbours=8, budget=0.001),\n",
    "]\n",
    "\n",
    "# Passed as config_multi to the ConformalCONFEXTree generator factory below.\n",
    "conformal_config = {\n",
    "    \"alpha\": [0.01, 0.05, 0.1], \"scorefn_name\": [\"linear_logits\" if is_rf else \"linear2\"], \"kernel_bandwidth\": [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45], \n",
    "}\n",
    "\n",
    "generators = [\n",
    "    GeneratorFactory([MinDistanceCF], config={}, config_multi={}),\n",
    "\n",
    "    GeneratorFactory([ConformalCF], config={\"conformal_class\": SplitConformalPrediction}, config_multi={\"conformal_config\": {\"alpha\": [0.01, 0.05, 0.1]}}),\n",
    "\n",
    "    GeneratorFactory([ConformalCF], config={\"conformal_class\": ConformalCONFEXTree}, config_multi={\n",
    "        \"conformal_config\": conformal_config\n",
    "    }),\n",
    "]\n",
    "\n",
    "if is_rf:\n",
    "    generators.extend([\n",
    "        GeneratorFactory([FOCUSGenerator], config={\"n_iter\": 200}, config_multi={}),\n",
    "        GeneratorFactory([FeatureTweakGenerator], config={\"epsilon\": 0.01}, config_multi={})\n",
    "    ])\n",
    "else:\n",
    "    f1 = GeneratorFactory([WachterGenerator], config={\"mad\": True}, config_multi={})\n",
    "    f2 = GeneratorFactory([SchutGenerator], config={}, config_multi={})\n",
    "    f3 = GeneratorFactory([ECCCOGenerator], config={}, config_multi={\"conformal_config\": {\"alpha\": [0.01, 0.05, 0.1]}})\n",
    "    generators.extend([f1, f2, f3])\n",
    "\n",
    "## Do not modify below\n",
    "print(\"Initializing CFBenchmarker...\")\n",
    "benchmarker = CFBenchmarker(dataset, n_factuals_main, n_repeats, metrics, model_factories, generators, path, use_pretrained=use_pretrained)\n",
    "\n",
    "print(\"Setting up models...\")\n",
    "benchmarker.setup_models()\n",
    "\n",
    "print(\"Evaluating models...\")\n",
    "benchmarker.evaluate_models()\n",
    "\n",
    "print(\"Setting factuals...\")\n",
    "benchmarker.set_factuals()\n",
    "\n",
    "print(\"Initializing generators...\")\n",
    "benchmarker.initialise_generators()\n",
    "\n",
    "print(\"Generating counterfactuals...\")\n",
    "benchmarker.get_counterfactuals(reset=not use_pretrained)\n",
    "\n",
    "print(\"Evaluating counterfactuals...\")\n",
    "df_out = benchmarker.evaluate_counterfactuals()\n",
    "\n",
    "print(\"Generate table\")\n",
    "# A bare expression in the middle of a cell is never rendered, so display the results table explicitly.\n",
    "display(benchmarker.generate_table(\"Distance_wNone_l1\", \"LOF_20S\", \"Implausibility_0.1\", \"Sensitivity25,4,0.001\", \"Stability8,0.001\", dp2=True, scaling=[(3, 0.1)])[0])\n",
    "\n",
    "print(\"Test conformal...\")\n",
    "# Keep the conformal test results available for inspection instead of discarding them.\n",
    "conformal_results = benchmarker.test_conformal()\n",
    "\n",
    "print(f\"Evaluation complete. See {path}\")\n",
    "print(f\"End {datetime.now()}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16813399",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metric keys requested from generate_table; later cells read some of them (e.g. \"LOF_20S\") as columns of t21.\n",
    "table_metrics = (\"Distance_wNone_l1\", \"LOF_20S\", \"Implausibility_0.1\", \"Sensitivity25,4,0.001\", \"Stability8,0.001\")\n",
    "t21 = benchmarker.generate_table(\n",
    "    *table_metrics,\n",
    "    dp2=False,\n",
    "    scaling=[(3, 0.1)],\n",
    "    include_extra=[\"Validity\", \"Failures\"],\n",
    ")[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e5f739e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the MinDistanceCF, Wachter, Schut and ECCCO rows from the results table.\n",
    "# Re-running this cell on its own raises KeyError (the rows are already gone); re-run the generate_table cell first.\n",
    "excluded_rows = [\n",
    "    'WachterGenerator{\"mad\":true}',\n",
    "    'MinDistanceCF{}',\n",
    "    'SchutGenerator{\"new\":true}',\n",
    "    'ECCCOGenerator{\"conformal_config\":{\"alpha\":0.01}}',\n",
    "    'ECCCOGenerator{\"conformal_config\":{\"alpha\":0.05}}',\n",
    "    'ECCCOGenerator{\"conformal_config\":{\"alpha\":0.1}}',\n",
    "]\n",
    "t21 = t21.drop(index=excluded_rows)\n",
    "t21"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32eff418",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract the numeric \"alpha\" and \"kernel_bandwidth\" values embedded in the row labels\n",
    "# (e.g. '...\"alpha\":0.05,\"kernel_bandwidth\":0.2...') into new columns; rows without them get NaN.\n",
    "# The number pattern also accepts a sign and an exponent, so a value such as 1e-05 is not cut to \"1\".\n",
    "number_pattern = r'(-?\\d+(?:\\.\\d*)?(?:[eE][-+]?\\d+)?)'\n",
    "row_labels = t21.index.to_series()\n",
    "t21['alpha'] = row_labels.str.extract(r'\"alpha\":' + number_pattern, expand=False)\n",
    "t21['kernel_bandwidth'] = row_labels.str.extract(r'\"kernel_bandwidth\":' + number_pattern, expand=False)\n",
    "\n",
    "t21"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3975e306",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "\n",
    "# key = \"LOF_20S\"\n",
    "key = \"Distance_wNone_l1\"\n",
    "y_label = \"Distance\"\n",
    "# Saved relative to the working directory instead of a machine-specific absolute path.\n",
    "figure_path = Path(\"results_f\") / \"1.png\"\n",
    "\n",
    "# Split the \"mean ± sd\" strings of the metric column into two numeric columns.\n",
    "parts = t21[key].str.split(' ± ')\n",
    "t21[f'{key}mean'] = parts.str[0].astype(float)\n",
    "t21[f'{key}sd'] = parts.str[1].astype(float)\n",
    "\n",
    "# The hyper-parameters were extracted as strings; make them numeric (unparseable -> NaN).\n",
    "t21['kernel_bandwidth'] = pd.to_numeric(t21['kernel_bandwidth'], errors='coerce')\n",
    "t21['alpha'] = pd.to_numeric(t21['alpha'], errors='coerce')\n",
    "\n",
    "# Rows with both values are ConformalCONFEXTree runs (one curve per alpha);\n",
    "# rows without a bandwidth include the split-conformal baselines.\n",
    "df_filtered = t21.dropna(subset=['kernel_bandwidth', 'alpha'])\n",
    "df_none_bandwidth = t21[t21['kernel_bandwidth'].isnull()]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "alpha_colours = {}\n",
    "for alpha, group in df_filtered.groupby('alpha'):\n",
    "    group = group.sort_values(by='kernel_bandwidth')\n",
    "    container = ax.errorbar(\n",
    "        group['kernel_bandwidth'],\n",
    "        group[f'{key}mean'],\n",
    "        yerr=group[f'{key}sd'],\n",
    "        fmt='o-',\n",
    "        label=f'alpha={alpha}',\n",
    "        capsize=2, elinewidth=1, markeredgewidth=0.5,\n",
    "    )\n",
    "    # Take the colour from the data line itself; ax.lines[-1] may be an error-bar cap line.\n",
    "    alpha_colours[alpha] = container.lines[0].get_color()\n",
    "\n",
    "# Split-conformal baselines: dashed horizontal line in the colour of the matching alpha.\n",
    "for label, row in df_none_bandwidth.iterrows():\n",
    "    if \"Split\" in label:\n",
    "        ax.axhline(y=row[f'{key}mean'], linestyle='--', color=alpha_colours.get(row['alpha'], 'gray'))\n",
    "\n",
    "ax.set_xlabel('Kernel Bandwidth')\n",
    "ax.set_ylabel(y_label)\n",
    "ax.legend()\n",
    "ax.grid(True)\n",
    "figure_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "fig.savefig(figure_path, dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aff7697",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "\n",
    "key = \"LOF_20S\"\n",
    "# key = \"Distance_wNone_l1\"\n",
    "y_label = \"Plausibility\"\n",
    "# Saved relative to the working directory instead of a machine-specific absolute path.\n",
    "figure_path = Path(\"results_f\") / \"2.png\"\n",
    "\n",
    "# Split the \"mean ± sd\" strings of the metric column into two numeric columns.\n",
    "parts = t21[key].str.split(' ± ')\n",
    "t21[f'{key}mean'] = parts.str[0].astype(float)\n",
    "t21[f'{key}sd'] = parts.str[1].astype(float)\n",
    "\n",
    "# The hyper-parameters were extracted as strings; make them numeric (unparseable -> NaN).\n",
    "t21['kernel_bandwidth'] = pd.to_numeric(t21['kernel_bandwidth'], errors='coerce')\n",
    "t21['alpha'] = pd.to_numeric(t21['alpha'], errors='coerce')\n",
    "\n",
    "# Rows with both values are ConformalCONFEXTree runs (one curve per alpha);\n",
    "# rows without a bandwidth include the split-conformal baselines.\n",
    "df_filtered = t21.dropna(subset=['kernel_bandwidth', 'alpha'])\n",
    "df_none_bandwidth = t21[t21['kernel_bandwidth'].isnull()]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "alpha_colours = {}\n",
    "for alpha, group in df_filtered.groupby('alpha'):\n",
    "    group = group.sort_values(by='kernel_bandwidth')\n",
    "    container = ax.errorbar(\n",
    "        group['kernel_bandwidth'],\n",
    "        group[f'{key}mean'],\n",
    "        yerr=group[f'{key}sd'],\n",
    "        fmt='o-',\n",
    "        label=f'alpha={alpha}',\n",
    "        capsize=2, elinewidth=1, markeredgewidth=0.5,\n",
    "    )\n",
    "    # Take the colour from the data line itself; ax.lines[-1] may be an error-bar cap line.\n",
    "    alpha_colours[alpha] = container.lines[0].get_color()\n",
    "\n",
    "# Split-conformal baselines: dashed horizontal line in the colour of the matching alpha.\n",
    "for label, row in df_none_bandwidth.iterrows():\n",
    "    if \"Split\" in label:\n",
    "        ax.axhline(y=row[f'{key}mean'], linestyle='--', color=alpha_colours.get(row['alpha'], 'gray'))\n",
    "\n",
    "ax.set_xlabel('Kernel Bandwidth')\n",
    "ax.set_ylabel(y_label)\n",
    "ax.legend()\n",
    "ax.grid(True)\n",
    "figure_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "fig.savefig(figure_path, dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39a6f1bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO (optional BaseLCP baseline): this needs a `ccf2` conformal_config_multi, which is not defined\n",
    "# anywhere in this notebook -- define it before re-enabling:\n",
    "# benchmarker.set_additional_conformal(conformal_classes=[BaseLCP], conformal_config={\"kernel_name\": \"box_linf\"}, conformal_config_multi=ccf2)\n",
    "\n",
    "print(\"Test conformal...\")\n",
    "c2 = benchmarker.test_conformal()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9832fbba",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Look up this model's entry in the test_conformal() results. Keys look like\n",
    "# '<results dir>/0/CaliforniaHousing/models/PyTorchMLP_{\"epochs\":100,\"batch_size\":64}'. Only the part after the\n",
    "# results directory is matched, and it is built from dataset_cls / mlp_config, so the lookup does not break when\n",
    "# the results directory (`path`) changes. NOTE(review): assumes c2 iterates over string-like keys (e.g. a dict).\n",
    "model_suffix = f'/0/{dataset_cls.__name__}/models/PyTorchMLP_' + json.dumps(mlp_config, separators=(',', ':'))\n",
    "model_keys = [k for k in c2 if str(k).endswith(model_suffix)]\n",
    "assert len(model_keys) == 1, f'expected exactly one test_conformal() entry ending in {model_suffix!r}, found {model_keys}'\n",
    "model_key = model_keys[0]\n",
    "\n",
    "cov_gap = c2[model_key]['covgap']['mean']\n",
    "cov_gap_errors = c2[model_key]['covgap']['sd']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60cfb6aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract the numeric \"alpha\" and \"kernel_bandwidth\" values embedded in the row labels into new columns\n",
    "# of both the mean table (cov_gap) and the sd table (cov_gap_errors); rows without them get NaN.\n",
    "# The number pattern also accepts a sign and an exponent, so a value such as 1e-05 is not cut to \"1\".\n",
    "number_pattern = r'(-?\\d+(?:\\.\\d*)?(?:[eE][-+]?\\d+)?)'\n",
    "for table in (cov_gap, cov_gap_errors):\n",
    "    row_labels = table.index.to_series()\n",
    "    table['alpha'] = row_labels.str.extract(r'\"alpha\":' + number_pattern, expand=False)\n",
    "    table['kernel_bandwidth'] = row_labels.str.extract(r'\"kernel_bandwidth\":' + number_pattern, expand=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6cdbba9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "\n",
    "# Saved relative to the working directory instead of a machine-specific absolute path.\n",
    "figure_path = Path(\"results_f\") / \"3.png\"\n",
    "\n",
    "# Keep rows that have an alpha; .copy() so the column assignments below neither modify cov_gap\n",
    "# nor trigger SettingWithCopyWarning.\n",
    "df_filtered = cov_gap.dropna(subset=['alpha']).copy()\n",
    "\n",
    "# Make the plotted columns numeric (unparseable -> NaN); 'Counterfactual Sim' is the coverage-gap column plotted here.\n",
    "for column in ('kernel_bandwidth', 'alpha', 'Counterfactual Sim'):\n",
    "    df_filtered[column] = pd.to_numeric(df_filtered[column], errors='coerce')\n",
    "\n",
    "# Rows without a kernel bandwidth are drawn as horizontal reference lines; the rest become one curve per alpha.\n",
    "df_none_bandwidth = df_filtered[df_filtered['kernel_bandwidth'].isnull()]\n",
    "df_filtered = df_filtered.dropna(subset=['kernel_bandwidth'])\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "alpha_colours = {}\n",
    "for alpha, group in df_filtered.groupby('alpha'):\n",
    "    group = group.sort_values(by='kernel_bandwidth')\n",
    "    line, = ax.plot(group['kernel_bandwidth'], group['Counterfactual Sim'], marker='o', linestyle='-', label=f'alpha={alpha}')\n",
    "    alpha_colours[alpha] = line.get_color()\n",
    "\n",
    "# .get() falls back to gray instead of raising KeyError when a baseline's alpha has no matching curve.\n",
    "for _, row in df_none_bandwidth.iterrows():\n",
    "    ax.axhline(y=row['Counterfactual Sim'], linestyle='--', color=alpha_colours.get(row['alpha'], 'gray'))\n",
    "\n",
    "ax.axhline(y=0, color='red', linestyle='--', label='Target')\n",
    "\n",
    "ax.set_xlabel('Kernel Bandwidth')\n",
    "ax.set_ylabel('Coverage gap')\n",
    "ax.legend()\n",
    "ax.grid(True)\n",
    "figure_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "fig.savefig(figure_path, dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d8d5584",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "\n",
    "# Saved relative to the working directory instead of a machine-specific absolute path.\n",
    "figure_path = Path(\"results_f\") / \"4.png\"\n",
    "\n",
    "# Keep rows that have an alpha; .copy() so the column assignments below neither modify cov_gap\n",
    "# nor trigger SettingWithCopyWarning.\n",
    "df_filtered = cov_gap.dropna(subset=['alpha']).copy()\n",
    "\n",
    "# Make the plotted columns numeric (unparseable -> NaN); 'Counterfactual Sim' is the coverage-gap column plotted here.\n",
    "for column in ('kernel_bandwidth', 'alpha', 'Counterfactual Sim'):\n",
    "    df_filtered[column] = pd.to_numeric(df_filtered[column], errors='coerce')\n",
    "\n",
    "# Rows without a kernel bandwidth are drawn as horizontal reference lines; the rest become one curve per alpha.\n",
    "df_none_bandwidth = df_filtered[df_filtered['kernel_bandwidth'].isnull()]\n",
    "df_filtered = df_filtered.dropna(subset=['kernel_bandwidth'])\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "alpha_colours = {}\n",
    "for alpha, group in df_filtered.groupby('alpha'):\n",
    "    group = group.sort_values(by='kernel_bandwidth')\n",
    "    container = ax.errorbar(\n",
    "        group['kernel_bandwidth'],\n",
    "        group['Counterfactual Sim'],\n",
    "        # Error bars come from the matching rows of the sd table.\n",
    "        yerr=pd.to_numeric(cov_gap_errors.loc[group.index, 'Counterfactual Sim'], errors='coerce'),\n",
    "        fmt='o-',\n",
    "        label=f'alpha={alpha}',\n",
    "        capsize=2, elinewidth=0.7, markeredgewidth=1,\n",
    "    )\n",
    "    # Take the colour from the data line itself; ax.lines[-1] may be an error-bar cap line.\n",
    "    alpha_colours[alpha] = container.lines[0].get_color()\n",
    "\n",
    "# .get() falls back to gray instead of raising KeyError when a baseline's alpha has no matching curve.\n",
    "for _, row in df_none_bandwidth.iterrows():\n",
    "    ax.axhline(y=row['Counterfactual Sim'], linestyle='--', color=alpha_colours.get(row['alpha'], 'gray'))\n",
    "\n",
    "ax.axhline(y=0, color='red', linestyle='--', label='Target')\n",
    "\n",
    "ax.set_xlabel('Kernel Bandwidth')\n",
    "ax.set_ylabel('Coverage gap')\n",
    "ax.legend()\n",
    "ax.grid(True)\n",
    "figure_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "fig.savefig(figure_path, dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "confexplus",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
