{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ElliCE: Efficient and Provably Robust Algorithmic Recourse\n",
    "\n",
    "This notebook demonstrates the features of the **ElliCE** library, as described in the [README](README.md).\n",
    "\n",
    "**ElliCE** generates provably robust counterfactual explanations ensuring validity across the Rashomon set of nearly-optimal models.\n",
    "\n",
    "## Table of Contents\n",
    "1. [Installation & Setup](#setup)\n",
    "2. [Quick Start](#quick-start)\n",
    "3. [Advanced Actionability Constraints](#constraints)\n",
    "   - Immutable Features\n",
    "   - Range Constraints\n",
    "   - One-Way Changes\n",
    "   - Categorical Features\n",
    "4. [Generators](#generators)\n",
    "   - Continuous (with Sparsity)\n",
    "   - Data-Supported\n",
    "5. [Custom Backend (PyTorch)](#custom-backend)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Installation & Setup <a id=\"setup\"></a>\n",
    "\n",
    "Ensure `ellice` is installed. If you are running this from the repo, you can install it in editable mode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aa05aab9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install ellice\n",
    "# Or if running from repo source:\n",
    "# !pip install -e ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23019394",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "\n",
    "# Ensure display is defined if running in non-IPython environment (fallback)\n",
    "try:\n",
    "    from IPython.display import display, HTML\n",
    "except ImportError:\n",
    "    def display(*args):\n",
    "        for arg in args:\n",
    "            print(arg)\n",
    "    def HTML(text):\n",
    "        return text\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.datasets import load_breast_cancer\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "import ellice\n",
    "from ellice.configs import GenerationConfig, AlgorithmConfig\n",
    "\n",
    "# Global Configuration\n",
    "# In practice, robustness_epsilon could default to 10% of the train loss, or be determined using a set of proxy models (the hyperparameter tuning procedure is described in the paper).\n",
    "# Here we add 0.01 to the train loss, i.e. loss <= train loss + 0.01; using 10% of the train loss would instead mean loss <= train loss * 1.1.\n",
    "robustness_epsilon = 0.01\n",
    "regularization_coefficient = 0.005\n",
    "\n",
    "# Reproducibility\n",
    "def seed_everything(seed: int):\n",
    "    \"\"\"Seed Python's `random`, NumPy's global RNG and PyTorch with the same value.\"\"\"\n",
    "    import random\n",
    "\n",
    "    for seeder in (random.seed, np.random.seed, torch.manual_seed):\n",
    "        seeder(seed)\n",
    "\n",
    "seed_everything(42)\n",
    "\n",
    "# Helper function to display query vs CF with highlighted changes\n",
    "def display_query_vs_cf(query: pd.Series, cf: pd.Series, feature_names: list, threshold: float = 1e-4, \n",
    "                        explainer=None, target_class=None, robustness_epsilon=robustness_epsilon, regularization_coefficient=regularization_coefficient):\n",
    "    \"\"\"Display a query and its counterfactual side by side, highlighting changed features.\n",
    "\n",
    "    An HTML table is rendered when IPython is importable; otherwise the same\n",
    "    information is printed as plain text. When `explainer` is given, the model's\n",
    "    prediction for the counterfactual is printed, and when `target_class` is also\n",
    "    given the worst-case (robust) probability of that class is reported too.\n",
    "\n",
    "    Args:\n",
    "        query: Original query instance.\n",
    "        cf: Counterfactual instance.\n",
    "        feature_names: Names of the features to compare; must index both series.\n",
    "        threshold: Absolute change above which a feature counts as changed.\n",
    "        explainer: Optional ElliCE Explainer; its `model`, `data` and `device`\n",
    "            attributes are used to compute the prediction information.\n",
    "        target_class: Optional class whose probability is reported.\n",
    "        robustness_epsilon: Epsilon passed to the generator used for the robust probability.\n",
    "        regularization_coefficient: Regularization coefficient passed to the same generator.\n",
    "    \"\"\"\n",
    "    diff = cf[feature_names] - query[feature_names]\n",
    "    changed = diff[diff.abs() > threshold]\n",
    "\n",
    "    # Only a missing IPython selects the plain-text rendering; any other error\n",
    "    # (unknown feature names, model failures, ...) propagates to the caller.\n",
    "    try:\n",
    "        from IPython.display import HTML as ipython_html\n",
    "    except ImportError:\n",
    "        ipython_html = None\n",
    "\n",
    "    if ipython_html is not None:\n",
    "        rows = [\n",
    "            \"<table border='1' style='border-collapse: collapse;'>\",\n",
    "            \"<tr><th>Feature</th><th>Original</th><th>Counterfactual</th><th>Change</th></tr>\",\n",
    "        ]\n",
    "        for feat in feature_names:\n",
    "            change = diff[feat]\n",
    "            # Changed rows get a light-blue background, unchanged rows stay white.\n",
    "            color = \"#cce5ff\" if abs(change) > threshold else \"#ffffff\"\n",
    "            rows.append(\n",
    "                f\"<tr style='background-color: {color}'>\"\n",
    "                f\"<td>{feat}</td>\"\n",
    "                f\"<td>{query[feat]:.4f}</td>\"\n",
    "                f\"<td>{cf[feat]:.4f}</td>\"\n",
    "                f\"<td>{change:.4f}</td>\"\n",
    "                \"</tr>\"\n",
    "            )\n",
    "        rows.append(\"</table>\")\n",
    "        display(ipython_html(\"\".join(rows)))\n",
    "    else:\n",
    "        print(\"\\n=== Original Query ===\")\n",
    "        print(query[feature_names])\n",
    "        print(\"\\n=== Counterfactual ===\")\n",
    "        print(cf[feature_names])\n",
    "        print(f\"\\n=== Changed Features ({len(changed)}) ===\")\n",
    "        print(changed)\n",
    "\n",
    "    if explainer is None:\n",
    "        return\n",
    "\n",
    "    cf_features = pd.DataFrame([cf[feature_names]], columns=feature_names)\n",
    "    model_probs = explainer.model.predict_proba(cf_features.values)\n",
    "    predicted_class = 1 if model_probs[0, 1] > 0.5 else 0\n",
    "    # Report the target class when one is given, otherwise the predicted class.\n",
    "    reported_class = predicted_class if target_class is None else target_class\n",
    "    predicted_prob = model_probs[0, reported_class]\n",
    "\n",
    "    robust_prob = None\n",
    "    robust_error = None\n",
    "    if target_class is not None:\n",
    "        try:\n",
    "            from ellice.generators.continuous import ContinuousGenerator\n",
    "\n",
    "            temp_gen = ContinuousGenerator(\n",
    "                model=explainer.model,\n",
    "                data=explainer.data,\n",
    "                eps=robustness_epsilon,\n",
    "                reg_coef=regularization_coefficient,\n",
    "                device=str(explainer.device),\n",
    "            )\n",
    "            robust_prob = temp_gen.get_worst_case_prob(cf_features, target_class=target_class)[0]\n",
    "        except Exception as err:\n",
    "            # Best effort: remember the failure so it is reported below rather than\n",
    "            # presenting the plain model probability as the worst-case value.\n",
    "            robust_error = err\n",
    "\n",
    "    print(\"\\nPrediction Information:\")\n",
    "    print(f\"  Predicted Class: {predicted_class}\")\n",
    "    print(f\"  Predicted Probability of Target Class: {predicted_prob:.4f}\")\n",
    "    if target_class is not None:\n",
    "        if robust_error is None:\n",
    "            print(f\"  Robust Probability (Worst Case) of Target Class: {robust_prob:.4f}\")\n",
    "        else:\n",
    "            print(f\"  Robust Probability (Worst Case) of Target Class: unavailable ({robust_error})\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4a519cb",
   "metadata": {},
   "source": [
    "## 2. Quick Start <a id=\"quick-start\"></a>\n",
    "\n",
    "We'll use the Breast Cancer dataset and a simple Logistic Regression model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2bdff74d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model Accuracy: 0.9561\n"
     ]
    }
   ],
   "source": [
    "# 1. Load the Breast Cancer dataset into a feature frame and a target series\n",
    "data_raw = load_breast_cancer()\n",
    "X = pd.DataFrame(data_raw.data, columns=data_raw.feature_names)\n",
    "y = pd.Series(data_raw.target, name=\"target\")\n",
    "\n",
    "# Hold out 20% of the rows as a test set\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
    "\n",
    "# Ideally the features would be normalized at this point, but for simplicity we skip that step.\n",
    "# As a consequence, we use a larger regularization_coefficient to keep the robust probability calculation stable.\n",
    "\n",
    "# 2. Train the model\n",
    "clf = LogisticRegression(max_iter=5000, solver='liblinear')\n",
    "clf.fit(X_train, y_train)\n",
    "print(f\"Model Accuracy: {clf.score(X_test, y_test):.4f}\")\n",
    "\n",
    "# 3. Initialize ElliCE\n",
    "# ElliCE's Data object takes one dataframe that also holds the target column\n",
    "full_df = X_train.assign(target=y_train)\n",
    "data = ellice.Data(dataframe=full_df, target_column='target')\n",
    "\n",
    "# device='auto' selects CUDA/MPS when available; use 'cpu' if you encounter CUDA errors.\n",
    "exp = ellice.Explainer(model=clf, data=data, backend='sklearn', device='auto')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7bc1a4a",
   "metadata": {},
   "source": [
    "### PyTorch Backend Example\n",
    "We can also use a PyTorch neural network model with ElliCE.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7e77a37b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 20/100, Loss: 0.5206, Test Accuracy: 0.8158\n",
      "Epoch 40/100, Loss: 0.2844, Test Accuracy: 0.9298\n",
      "Epoch 60/100, Loss: 0.2578, Test Accuracy: 0.9474\n",
      "Epoch 80/100, Loss: 0.2452, Test Accuracy: 0.9561\n",
      "Epoch 100/100, Loss: 0.2363, Test Accuracy: 0.9649\n",
      "Final PyTorch Model Accuracy: 0.9649\n",
      "\n",
      "PyTorch Model - Original Prediction: 0\n",
      "PyTorch Model - Target Prediction: 1\n",
      "Progress Bar Enabled\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating CF:   9%|▉         | 89/1000 [00:00<00:00, 1021.71it/s, Prob=0.730, RobLogit=-0.000, BestRobLogit=0.000]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "PyTorch Counterfactual Found!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table border='1' style='border-collapse: collapse;'><tr><th>Feature</th><th>Original</th><th>Counterfactual</th><th>Change</th></tr><tr style='background-color: #cce5ff'><td>mean radius</td><td>12.4700</td><td>13.0353</td><td>0.5653</td></tr><tr style='background-color: #cce5ff'><td>mean texture</td><td>18.6000</td><td>22.6056</td><td>4.0056</td></tr><tr style='background-color: #cce5ff'><td>mean perimeter</td><td>81.0900</td><td>87.2303</td><td>6.1403</td></tr><tr style='background-color: #cce5ff'><td>mean area</td><td>481.9000</td><td>490.4293</td><td>8.5293</td></tr><tr style='background-color: #cce5ff'><td>mean smoothness</td><td>0.0997</td><td>0.1634</td><td>0.0637</td></tr><tr style='background-color: #cce5ff'><td>mean compactness</td><td>0.1058</td><td>0.1632</td><td>0.0574</td></tr><tr style='background-color: #cce5ff'><td>mean concavity</td><td>0.0800</td><td>0.0000</td><td>-0.0800</td></tr><tr style='background-color: #cce5ff'><td>mean concave points</td><td>0.0382</td><td>0.0000</td><td>-0.0382</td></tr><tr style='background-color: #cce5ff'><td>mean symmetry</td><td>0.1925</td><td>0.3040</td><td>0.1115</td></tr><tr style='background-color: #cce5ff'><td>mean fractal dimension</td><td>0.0637</td><td>0.0974</td><td>0.0337</td></tr><tr style='background-color: #cce5ff'><td>radius error</td><td>0.3961</td><td>0.9718</td><td>0.5757</td></tr><tr style='background-color: #cce5ff'><td>texture error</td><td>1.0440</td><td>1.2328</td><td>0.1888</td></tr><tr style='background-color: #cce5ff'><td>perimeter error</td><td>2.4970</td><td>3.1350</td><td>0.6380</td></tr><tr style='background-color: #cce5ff'><td>area error</td><td>30.2900</td><td>21.8057</td><td>-8.4843</td></tr><tr style='background-color: #cce5ff'><td>smoothness error</td><td>0.0070</td><td>0.0017</td><td>-0.0052</td></tr><tr style='background-color: #cce5ff'><td>compactness error</td><td>0.0191</td><td>0.0023</td><td>-0.0169</td></tr><tr style='background-color: #cce5ff'><td>concavity 
error</td><td>0.0270</td><td>0.0000</td><td>-0.0270</td></tr><tr style='background-color: #cce5ff'><td>concave points error</td><td>0.0104</td><td>0.0528</td><td>0.0424</td></tr><tr style='background-color: #cce5ff'><td>symmetry error</td><td>0.0178</td><td>0.0615</td><td>0.0436</td></tr><tr style='background-color: #cce5ff'><td>fractal dimension error</td><td>0.0036</td><td>0.0298</td><td>0.0263</td></tr><tr style='background-color: #cce5ff'><td>worst radius</td><td>14.9700</td><td>16.1698</td><td>1.1998</td></tr><tr style='background-color: #cce5ff'><td>worst texture</td><td>24.6400</td><td>32.1057</td><td>7.4657</td></tr><tr style='background-color: #cce5ff'><td>worst perimeter</td><td>96.0500</td><td>102.3064</td><td>6.2564</td></tr><tr style='background-color: #cce5ff'><td>worst area</td><td>677.9000</td><td>669.2981</td><td>-8.6019</td></tr><tr style='background-color: #cce5ff'><td>worst smoothness</td><td>0.1426</td><td>0.0712</td><td>-0.0714</td></tr><tr style='background-color: #cce5ff'><td>worst compactness</td><td>0.2378</td><td>0.0273</td><td>-0.2105</td></tr><tr style='background-color: #cce5ff'><td>worst concavity</td><td>0.2671</td><td>0.1761</td><td>-0.0910</td></tr><tr style='background-color: #cce5ff'><td>worst concave points</td><td>0.1015</td><td>0.2910</td><td>0.1895</td></tr><tr style='background-color: #cce5ff'><td>worst symmetry</td><td>0.3014</td><td>0.1565</td><td>-0.1449</td></tr><tr style='background-color: #cce5ff'><td>worst fractal dimension</td><td>0.0875</td><td>0.1730</td><td>0.0855</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Prediction Information:\n",
      "  Predicted Class: 1\n",
      "  Predicted Probability of Target Class: 0.7300\n",
      "  Robust Probability (Worst Case) of Target Class: 0.5001\n"
     ]
    }
   ],
   "source": [
    "# 1. Define PyTorch Model\n",
    "class SimpleNN(nn.Module):\n",
    "    \"\"\"Two-layer perceptron (Linear -> ReLU -> Linear) that emits one logit per input row.\"\"\"\n",
    "\n",
    "    def __init__(self, input_dim):\n",
    "        super().__init__()\n",
    "        hidden_dim = 32\n",
    "        self.layer1 = nn.Linear(input_dim, hidden_dim)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.layer2 = nn.Linear(hidden_dim, 1)  # single logit for binary classification\n",
    "\n",
    "    def forward(self, x):\n",
    "        hidden = self.relu(self.layer1(x))\n",
    "        return self.layer2(hidden)\n",
    "\n",
    "# 2. Train PyTorch Model\n",
    "torch_model = SimpleNN(input_dim=X_train.shape[1])\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "torch_model = torch_model.to(device)\n",
    "\n",
    "# Convert data to tensors; labels get a trailing dimension to match the (n, 1) logits\n",
    "X_train_t = torch.FloatTensor(X_train.values).to(device)\n",
    "y_train_t = torch.FloatTensor(y_train.values).unsqueeze(1).to(device)\n",
    "X_test_t = torch.FloatTensor(X_test.values).to(device)\n",
    "y_test_t = torch.FloatTensor(y_test.values).unsqueeze(1).to(device)\n",
    "\n",
    "# Full-batch training loop\n",
    "criterion = nn.BCEWithLogitsLoss()\n",
    "optimizer = optim.Adam(torch_model.parameters(), lr=0.001)\n",
    "epochs = 100\n",
    "\n",
    "torch_model.train()\n",
    "for epoch in range(epochs):\n",
    "    optimizer.zero_grad()\n",
    "    outputs = torch_model(X_train_t)\n",
    "    loss = criterion(outputs, y_train_t)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # Report test accuracy every 20 epochs\n",
    "    if (epoch + 1) % 20 == 0:\n",
    "        torch_model.eval()\n",
    "        with torch.no_grad():\n",
    "            test_outputs = torch_model(X_test_t)\n",
    "            test_preds = (torch.sigmoid(test_outputs) > 0.5).float()\n",
    "            accuracy = (test_preds == y_test_t).float().mean().item()\n",
    "        print(f\"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}, Test Accuracy: {accuracy:.4f}\")\n",
    "        torch_model.train()\n",
    "\n",
    "# Evaluate the trained model explicitly: the value left over from the last logging\n",
    "# step is stale (or undefined) whenever `epochs` is not a multiple of 20.\n",
    "torch_model.eval()\n",
    "with torch.no_grad():\n",
    "    test_outputs = torch_model(X_test_t)\n",
    "    test_preds = (torch.sigmoid(test_outputs) > 0.5).float()\n",
    "    accuracy = (test_preds == y_test_t).float().mean().item()\n",
    "print(f\"Final PyTorch Model Accuracy: {accuracy:.4f}\")\n",
    "\n",
    "# 3. Use with ElliCE\n",
    "exp_torch_quick = ellice.Explainer(\n",
    "    model=torch_model,\n",
    "    data=data,\n",
    "    backend='pytorch',\n",
    "    device='auto'\n",
    ")\n",
    "\n",
    "# Pick a query and target the class opposite to the model's current prediction\n",
    "query_torch = X_test.iloc[0]\n",
    "with torch.no_grad():  # inference only, no autograd graph needed\n",
    "    query_logit = torch_model(torch.FloatTensor(query_torch.values).unsqueeze(0).to(device))\n",
    "original_pred_torch = (torch.sigmoid(query_logit) > 0.5).item()\n",
    "target_class_torch = 1 - int(original_pred_torch)\n",
    "\n",
    "print(f\"\\nPyTorch Model - Original Prediction: {int(original_pred_torch)}\")\n",
    "print(f\"PyTorch Model - Target Prediction: {target_class_torch}\")\n",
    "\n",
    "cf_torch_quick = exp_torch_quick.generate_counterfactuals(\n",
    "    query_instances=query_torch,\n",
    "    method='continuous',\n",
    "    target_class=target_class_torch,\n",
    "    robustness_epsilon=robustness_epsilon,\n",
    "    regularization_coefficient=regularization_coefficient\n",
    ")\n",
    "\n",
    "if not cf_torch_quick.empty:\n",
    "    print(\"\\nPyTorch Counterfactual Found!\")\n",
    "    display_query_vs_cf(query_torch, cf_torch_quick.iloc[0], data.feature_names,\n",
    "                       explainer=exp_torch_quick, target_class=target_class_torch,\n",
    "                       robustness_epsilon=robustness_epsilon, regularization_coefficient=regularization_coefficient)\n",
    "else:\n",
    "    print(\"No counterfactual found.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9778841",
   "metadata": {},
   "source": [
    "### Generate a Robust Counterfactual\n",
    "We pick a query instance and generate a counterfactual that flips the prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "689ba667",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/bt4811/anaconda3/envs/ellice/lib/python3.12/site-packages/sklearn/utils/validation.py:2749: UserWarning: X does not have valid feature names, but LogisticRegression was fitted with feature names\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original Prediction: 1 (Benign)\n",
      "Target Prediction:   0 (Malignant)\n",
      "Progress Bar Enabled\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating CF:   0%|          | 3/1000 [00:00<00:01, 649.21it/s, Prob=0.850, RobLogit=0.093, BestRobLogit=0.093]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Counterfactual Found!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "/Users/bt4811/anaconda3/envs/ellice/lib/python3.12/site-packages/sklearn/utils/validation.py:2749: UserWarning: X does not have valid feature names, but LogisticRegression was fitted with feature names\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table border='1' style='border-collapse: collapse;'><tr><th>Feature</th><th>Original</th><th>Counterfactual</th><th>Change</th></tr><tr style='background-color: #cce5ff'><td>mean radius</td><td>12.4700</td><td>12.1697</td><td>-0.3003</td></tr><tr style='background-color: #cce5ff'><td>mean texture</td><td>18.6000</td><td>18.3005</td><td>-0.2995</td></tr><tr style='background-color: #cce5ff'><td>mean perimeter</td><td>81.0900</td><td>81.3127</td><td>0.2227</td></tr><tr style='background-color: #cce5ff'><td>mean area</td><td>481.9000</td><td>482.2005</td><td>0.3005</td></tr><tr style='background-color: #cce5ff'><td>mean smoothness</td><td>0.0997</td><td>0.1634</td><td>0.0637</td></tr><tr style='background-color: #cce5ff'><td>mean compactness</td><td>0.1058</td><td>0.3114</td><td>0.2056</td></tr><tr style='background-color: #cce5ff'><td>mean concavity</td><td>0.0800</td><td>0.3778</td><td>0.2978</td></tr><tr style='background-color: #cce5ff'><td>mean concave points</td><td>0.0382</td><td>0.2012</td><td>0.1630</td></tr><tr style='background-color: #cce5ff'><td>mean symmetry</td><td>0.1925</td><td>0.3040</td><td>0.1115</td></tr><tr style='background-color: #cce5ff'><td>mean fractal dimension</td><td>0.0637</td><td>0.0974</td><td>0.0337</td></tr><tr style='background-color: #cce5ff'><td>radius error</td><td>0.3961</td><td>0.3825</td><td>-0.0136</td></tr><tr style='background-color: #cce5ff'><td>texture error</td><td>1.0440</td><td>0.7442</td><td>-0.2998</td></tr><tr style='background-color: #cce5ff'><td>perimeter error</td><td>2.4970</td><td>2.7863</td><td>0.2893</td></tr><tr style='background-color: #cce5ff'><td>area error</td><td>30.2900</td><td>30.5904</td><td>0.3004</td></tr><tr style='background-color: #cce5ff'><td>smoothness error</td><td>0.0070</td><td>0.0281</td><td>0.0211</td></tr><tr style='background-color: #cce5ff'><td>compactness error</td><td>0.0191</td><td>0.0023</td><td>-0.0169</td></tr><tr style='background-color: #cce5ff'><td>concavity 
error</td><td>0.0270</td><td>0.0544</td><td>0.0274</td></tr><tr style='background-color: #cce5ff'><td>concave points error</td><td>0.0104</td><td>0.0427</td><td>0.0323</td></tr><tr style='background-color: #cce5ff'><td>symmetry error</td><td>0.0178</td><td>0.0615</td><td>0.0436</td></tr><tr style='background-color: #cce5ff'><td>fractal dimension error</td><td>0.0036</td><td>0.0009</td><td>-0.0027</td></tr><tr style='background-color: #cce5ff'><td>worst radius</td><td>14.9700</td><td>14.6699</td><td>-0.3001</td></tr><tr style='background-color: #cce5ff'><td>worst texture</td><td>24.6400</td><td>24.9401</td><td>0.3001</td></tr><tr style='background-color: #cce5ff'><td>worst perimeter</td><td>96.0500</td><td>96.3499</td><td>0.2999</td></tr><tr style='background-color: #cce5ff'><td>worst area</td><td>677.9000</td><td>678.2000</td><td>0.3000</td></tr><tr style='background-color: #cce5ff'><td>worst smoothness</td><td>0.1426</td><td>0.2184</td><td>0.0758</td></tr><tr style='background-color: #cce5ff'><td>worst compactness</td><td>0.2378</td><td>0.5377</td><td>0.2999</td></tr><tr style='background-color: #cce5ff'><td>worst concavity</td><td>0.2671</td><td>0.5673</td><td>0.3002</td></tr><tr style='background-color: #cce5ff'><td>worst concave points</td><td>0.1015</td><td>0.2910</td><td>0.1895</td></tr><tr style='background-color: #cce5ff'><td>worst symmetry</td><td>0.3014</td><td>0.5997</td><td>0.2983</td></tr><tr style='background-color: #cce5ff'><td>worst fractal dimension</td><td>0.0875</td><td>0.1730</td><td>0.0855</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Prediction Information:\n",
      "  Predicted Class: 0\n",
      "  Predicted Probability of Target Class: 0.8504\n",
      "  Robust Probability (Worst Case) of Target Class: 0.5233\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/bt4811/anaconda3/envs/ellice/lib/python3.12/site-packages/sklearn/utils/validation.py:2749: UserWarning: X does not have valid feature names, but LogisticRegression was fitted with feature names\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "query = X_test.iloc[0]\n",
    "# Predict from a one-row DataFrame so the feature names seen during fit are kept;\n",
    "# a bare list of values triggers sklearn's 'X does not have valid feature names' warning.\n",
    "original_pred = int(clf.predict(query.to_frame().T)[0])\n",
    "target_class = 1 - original_pred\n",
    "\n",
    "print(f\"Original Prediction: {original_pred} ({'Malignant' if original_pred==0 else 'Benign'})\")\n",
    "print(f\"Target Prediction:   {target_class} ({'Malignant' if target_class==0 else 'Benign'})\")\n",
    "\n",
    "# Generate CF; failures are reported with a traceback instead of aborting the notebook run\n",
    "try:\n",
    "    cf = exp.generate_counterfactuals(\n",
    "        query_instances=query,\n",
    "        method='continuous',\n",
    "        target_class=target_class,\n",
    "        robustness_epsilon=robustness_epsilon,\n",
    "        regularization_coefficient=regularization_coefficient,\n",
    "        features_to_vary='all',\n",
    "        return_probs=True\n",
    "    )\n",
    "\n",
    "    # Display results\n",
    "    if not cf.empty:\n",
    "        print(\"\\nCounterfactual Found!\")\n",
    "        display_query_vs_cf(query, cf.iloc[0], data.feature_names,\n",
    "                          explainer=exp, target_class=target_class,\n",
    "                          robustness_epsilon=robustness_epsilon, regularization_coefficient=regularization_coefficient)\n",
    "    else:\n",
    "        print(\"No counterfactual found.\")\n",
    "except Exception as e:\n",
    "    print(f\"Error: {e}\")\n",
    "    import traceback\n",
    "    traceback.print_exc()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "853a64f0",
   "metadata": {},
   "source": [
    "## 3. Advanced Actionability Constraints <a id=\"constraints\"></a>\n",
    "\n",
    "ElliCE supports various constraints to make counterfactuals realistic."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e2f3901",
   "metadata": {},
   "source": [
    "### Immutable Features\n",
    "Prevent features like 'mean radius' from changing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d7b96cf5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Progress Bar Enabled\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating CF:   1%|          | 6/1000 [00:00<00:01, 856.07it/s, Prob=0.904, RobLogit=0.337, BestRobLogit=0.337]   "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Immutable Features Counterfactual:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table border='1' style='border-collapse: collapse;'><tr><th>Feature</th><th>Original</th><th>Counterfactual</th><th>Change</th></tr><tr style='background-color: #ffffff'><td>mean radius</td><td>12.4700</td><td>12.4700</td><td>0.0000</td></tr><tr style='background-color: #cce5ff'><td>mean texture</td><td>18.6000</td><td>18.0045</td><td>-0.5955</td></tr><tr style='background-color: #cce5ff'><td>mean perimeter</td><td>81.0900</td><td>81.5254</td><td>0.4354</td></tr><tr style='background-color: #cce5ff'><td>mean area</td><td>481.9000</td><td>482.5030</td><td>0.6030</td></tr><tr style='background-color: #cce5ff'><td>mean smoothness</td><td>0.0997</td><td>0.1634</td><td>0.0637</td></tr><tr style='background-color: #cce5ff'><td>mean compactness</td><td>0.1058</td><td>0.3114</td><td>0.2056</td></tr><tr style='background-color: #cce5ff'><td>mean concavity</td><td>0.0800</td><td>0.4268</td><td>0.3468</td></tr><tr style='background-color: #cce5ff'><td>mean concave points</td><td>0.0382</td><td>0.2012</td><td>0.1630</td></tr><tr style='background-color: #cce5ff'><td>mean symmetry</td><td>0.1925</td><td>0.3040</td><td>0.1115</td></tr><tr style='background-color: #cce5ff'><td>mean fractal dimension</td><td>0.0637</td><td>0.0974</td><td>0.0337</td></tr><tr style='background-color: #cce5ff'><td>radius error</td><td>0.3961</td><td>0.4121</td><td>0.0160</td></tr><tr style='background-color: #ffffff'><td>texture error</td><td>1.0440</td><td>1.0440</td><td>0.0000</td></tr><tr style='background-color: #cce5ff'><td>perimeter error</td><td>2.4970</td><td>2.9859</td><td>0.4889</td></tr><tr style='background-color: #cce5ff'><td>area error</td><td>30.2900</td><td>30.8934</td><td>0.6034</td></tr><tr style='background-color: #cce5ff'><td>smoothness error</td><td>0.0070</td><td>0.0191</td><td>0.0122</td></tr><tr style='background-color: #cce5ff'><td>compactness error</td><td>0.0191</td><td>0.0086</td><td>-0.0105</td></tr><tr style='background-color: #ffffff'><td>concavity 
error</td><td>0.0270</td><td>0.0270</td><td>-0.0000</td></tr><tr style='background-color: #cce5ff'><td>concave points error</td><td>0.0104</td><td>0.0528</td><td>0.0424</td></tr><tr style='background-color: #ffffff'><td>symmetry error</td><td>0.0178</td><td>0.0178</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>fractal dimension error</td><td>0.0036</td><td>0.0036</td><td>0.0000</td></tr><tr style='background-color: #cce5ff'><td>worst radius</td><td>14.9700</td><td>14.3732</td><td>-0.5968</td></tr><tr style='background-color: #cce5ff'><td>worst texture</td><td>24.6400</td><td>25.2398</td><td>0.5998</td></tr><tr style='background-color: #cce5ff'><td>worst perimeter</td><td>96.0500</td><td>96.6531</td><td>0.6031</td></tr><tr style='background-color: #cce5ff'><td>worst area</td><td>677.9000</td><td>678.4972</td><td>0.5972</td></tr><tr style='background-color: #cce5ff'><td>worst smoothness</td><td>0.1426</td><td>0.2184</td><td>0.0758</td></tr><tr style='background-color: #cce5ff'><td>worst compactness</td><td>0.2378</td><td>0.8330</td><td>0.5952</td></tr><tr style='background-color: #cce5ff'><td>worst concavity</td><td>0.2671</td><td>0.8654</td><td>0.5983</td></tr><tr style='background-color: #cce5ff'><td>worst concave points</td><td>0.1015</td><td>0.2910</td><td>0.1895</td></tr><tr style='background-color: #cce5ff'><td>worst symmetry</td><td>0.3014</td><td>0.6638</td><td>0.3624</td></tr><tr style='background-color: #cce5ff'><td>worst fractal dimension</td><td>0.0875</td><td>0.1730</td><td>0.0855</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Prediction Information:\n",
      "  Predicted Class: 0\n",
      "  Predicted Probability of Target Class: 0.9044\n",
      "  Robust Probability (Worst Case) of Target Class: 0.5833\n",
      "\n",
      "Verification - ['mean radius', 'texture error', 'concavity error', 'symmetry error', 'fractal dimension error']:\n",
      "  Original: mean radius                12.470000\n",
      "texture error               1.044000\n",
      "concavity error             0.027010\n",
      "symmetry error              0.017820\n",
      "fractal dimension error     0.003586\n",
      "Name: 204, dtype: float64\n",
      "  CF:       mean radius                12.470000\n",
      "texture error               1.044000\n",
      "concavity error             0.027010\n",
      "symmetry error              0.017820\n",
      "fractal dimension error     0.003586\n",
      "Name: 0, dtype: float64\n",
      "  Changed? mean radius                False\n",
      "texture error              False\n",
      "concavity error            False\n",
      "symmetry error             False\n",
      "fractal dimension error    False\n",
      "dtype: bool\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/bt4811/anaconda3/envs/ellice/lib/python3.12/site-packages/sklearn/utils/validation.py:2749: UserWarning: X does not have valid feature names, but LogisticRegression was fitted with feature names\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Features that must keep their original value; every other column may vary.\n",
    "feature_to_freeze = ['mean radius', 'texture error', 'concavity error', 'symmetry error', 'fractal dimension error']\n",
    "features_to_vary = [col for col in X.columns if col not in feature_to_freeze]\n",
    "\n",
    "cf_immutable = exp.generate_counterfactuals(\n",
    "    query_instances=query,\n",
    "    method='continuous',\n",
    "    target_class=target_class,\n",
    "    features_to_vary=features_to_vary,\n",
    "    robustness_epsilon=robustness_epsilon,\n",
    "    regularization_coefficient=regularization_coefficient\n",
    ")\n",
    "\n",
    "if not cf_immutable.empty:\n",
    "    print(\"\\nImmutable Features Counterfactual:\")\n",
    "    display_query_vs_cf(query, cf_immutable.iloc[0], data.feature_names,\n",
    "                      explainer=exp, target_class=target_class,\n",
    "                      robustness_epsilon=robustness_epsilon, regularization_coefficient=regularization_coefficient)\n",
    "\n",
    "    # Verify the immutable features. Both Series are selected with the same list of\n",
    "    # names, so they are shown as one aligned table; interpolating a Series into an\n",
    "    # f-string spills its multi-line repr after the label and misaligns the first row.\n",
    "    original_val = query[feature_to_freeze]\n",
    "    cf_val = cf_immutable.iloc[0][feature_to_freeze]\n",
    "    print(f\"\\nVerification - {feature_to_freeze}:\")\n",
    "    display(pd.DataFrame({\n",
    "        'Original': original_val,\n",
    "        'CF': cf_val,\n",
    "        'Changed?': (original_val - cf_val).abs() > 1e-5,\n",
    "    }))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6adc27d5",
   "metadata": {},
   "source": [
    "### Range Constraints & One-Way Changes\n",
    "Restrict `mean texture` to a specific range and force `mean area` to only increase."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "804b5aff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Progress Bar Enabled\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating CF:   0%|          | 3/1000 [00:00<00:01, 782.62it/s, Prob=0.850, RobLogit=0.093, BestRobLogit=0.093]   "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Constrained Counterfactual (Range & One-Way):\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table border='1' style='border-collapse: collapse;'><tr><th>Feature</th><th>Original</th><th>Counterfactual</th><th>Change</th></tr><tr style='background-color: #cce5ff'><td>mean radius</td><td>12.4700</td><td>12.1697</td><td>-0.3003</td></tr><tr style='background-color: #cce5ff'><td>mean texture</td><td>18.6000</td><td>18.3005</td><td>-0.2995</td></tr><tr style='background-color: #cce5ff'><td>mean perimeter</td><td>81.0900</td><td>81.3127</td><td>0.2227</td></tr><tr style='background-color: #cce5ff'><td>mean area</td><td>481.9000</td><td>482.2005</td><td>0.3005</td></tr><tr style='background-color: #cce5ff'><td>mean smoothness</td><td>0.0997</td><td>0.1634</td><td>0.0637</td></tr><tr style='background-color: #cce5ff'><td>mean compactness</td><td>0.1058</td><td>0.3114</td><td>0.2056</td></tr><tr style='background-color: #cce5ff'><td>mean concavity</td><td>0.0800</td><td>0.3778</td><td>0.2978</td></tr><tr style='background-color: #cce5ff'><td>mean concave points</td><td>0.0382</td><td>0.2012</td><td>0.1630</td></tr><tr style='background-color: #cce5ff'><td>mean symmetry</td><td>0.1925</td><td>0.3040</td><td>0.1115</td></tr><tr style='background-color: #cce5ff'><td>mean fractal dimension</td><td>0.0637</td><td>0.0974</td><td>0.0337</td></tr><tr style='background-color: #cce5ff'><td>radius error</td><td>0.3961</td><td>0.3825</td><td>-0.0136</td></tr><tr style='background-color: #cce5ff'><td>texture error</td><td>1.0440</td><td>0.7442</td><td>-0.2998</td></tr><tr style='background-color: #cce5ff'><td>perimeter error</td><td>2.4970</td><td>2.7863</td><td>0.2893</td></tr><tr style='background-color: #cce5ff'><td>area error</td><td>30.2900</td><td>30.5904</td><td>0.3004</td></tr><tr style='background-color: #cce5ff'><td>smoothness error</td><td>0.0070</td><td>0.0281</td><td>0.0211</td></tr><tr style='background-color: #cce5ff'><td>compactness error</td><td>0.0191</td><td>0.0023</td><td>-0.0169</td></tr><tr style='background-color: #cce5ff'><td>concavity 
error</td><td>0.0270</td><td>0.0544</td><td>0.0274</td></tr><tr style='background-color: #cce5ff'><td>concave points error</td><td>0.0104</td><td>0.0427</td><td>0.0323</td></tr><tr style='background-color: #cce5ff'><td>symmetry error</td><td>0.0178</td><td>0.0615</td><td>0.0436</td></tr><tr style='background-color: #cce5ff'><td>fractal dimension error</td><td>0.0036</td><td>0.0009</td><td>-0.0027</td></tr><tr style='background-color: #cce5ff'><td>worst radius</td><td>14.9700</td><td>14.6699</td><td>-0.3001</td></tr><tr style='background-color: #cce5ff'><td>worst texture</td><td>24.6400</td><td>24.9401</td><td>0.3001</td></tr><tr style='background-color: #cce5ff'><td>worst perimeter</td><td>96.0500</td><td>96.3499</td><td>0.2999</td></tr><tr style='background-color: #cce5ff'><td>worst area</td><td>677.9000</td><td>678.2000</td><td>0.3000</td></tr><tr style='background-color: #cce5ff'><td>worst smoothness</td><td>0.1426</td><td>0.2184</td><td>0.0758</td></tr><tr style='background-color: #cce5ff'><td>worst compactness</td><td>0.2378</td><td>0.5377</td><td>0.2999</td></tr><tr style='background-color: #cce5ff'><td>worst concavity</td><td>0.2671</td><td>0.5673</td><td>0.3002</td></tr><tr style='background-color: #cce5ff'><td>worst concave points</td><td>0.1015</td><td>0.2910</td><td>0.1895</td></tr><tr style='background-color: #cce5ff'><td>worst symmetry</td><td>0.3014</td><td>0.5997</td><td>0.2983</td></tr><tr style='background-color: #cce5ff'><td>worst fractal dimension</td><td>0.0875</td><td>0.1730</td><td>0.0855</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Prediction Information:\n",
      "  Predicted Class: 0\n",
      "  Predicted Probability of Target Class: 0.8504\n",
      "  Robust Probability (Worst Case) of Target Class: 0.5233\n",
      "\n",
      "Constraint Verification:\n",
      "  Mean Texture: 18.3005 (Allowed: [10.0, 25.0])\n",
      "  Mean Area: 482.2005 (Original: 481.9000)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/bt4811/anaconda3/envs/ellice/lib/python3.12/site-packages/sklearn/utils/validation.py:2749: UserWarning: X does not have valid feature names, but LogisticRegression was fitted with feature names\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Actionability constraints: a permitted interval for `mean texture`\n",
    "# and an increase-only direction for `mean area`.\n",
    "ranges = {'mean texture': [10.0, 25.0]}\n",
    "one_way = {'mean area': 'increase'}\n",
    "\n",
    "# Collect the generator options once, then unpack them into the call.\n",
    "constraint_kwargs = {\n",
    "    'method': 'continuous',\n",
    "    'target_class': target_class,\n",
    "    'permitted_range': ranges,\n",
    "    'one_way_change': one_way,\n",
    "    'robustness_epsilon': robustness_epsilon,\n",
    "    'regularization_coefficient': regularization_coefficient,\n",
    "}\n",
    "cf_constrained = exp.generate_counterfactuals(query_instances=query, **constraint_kwargs)\n",
    "\n",
    "if not cf_constrained.empty:\n",
    "    print(\"\\nConstrained Counterfactual (Range & One-Way):\")\n",
    "    display_query_vs_cf(\n",
    "        query, cf_constrained.iloc[0], data.feature_names,\n",
    "        explainer=exp,\n",
    "        target_class=target_class,\n",
    "        robustness_epsilon=robustness_epsilon,\n",
    "        regularization_coefficient=regularization_coefficient,\n",
    "    )\n",
    "\n",
    "    # Print the constrained features next to their limits so they can be compared by eye.\n",
    "    res = cf_constrained.iloc[0]\n",
    "    texture_cf = res['mean texture']\n",
    "    area_cf = res['mean area']\n",
    "    print(\"\\nConstraint Verification:\")\n",
    "    print(f\"  Mean Texture: {texture_cf:.4f} (Allowed: {ranges['mean texture']})\")\n",
    "    print(f\"  Mean Area: {area_cf:.4f} (Original: {query['mean area']:.4f})\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a5ddefa",
   "metadata": {},
   "source": [
    "### Categorical Features (One-Hot Encoding)\n",
    "To demonstrate one-hot group handling, we add a categorical feature by binning `mean smoothness` into three equal-width bins (`Low`, `Medium`, `High`), one-hot encoding it, and retraining the model on the encoded data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "80b8cdfc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "One-hot groups: [['smoothness_Low', 'smoothness_Medium', 'smoothness_High']]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/bt4811/anaconda3/envs/ellice/lib/python3.12/site-packages/sklearn/utils/validation.py:2749: UserWarning: X does not have valid feature names, but LogisticRegression was fitted with feature names\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Categorical Model Accuracy: 0.9737\n",
      "Progress Bar Enabled\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating CF:   1%|          | 7/1000 [00:00<00:01, 769.92it/s, Prob=0.942, RobLogit=0.210, BestRobLogit=0.210]  \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Categorical Features Counterfactual:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/bt4811/anaconda3/envs/ellice/lib/python3.12/site-packages/sklearn/utils/validation.py:2749: UserWarning: X does not have valid feature names, but LogisticRegression was fitted with feature names\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table border='1' style='border-collapse: collapse;'><tr><th>Feature</th><th>Original</th><th>Counterfactual</th><th>Change</th></tr><tr style='background-color: #cce5ff'><td>mean radius</td><td>12.4700</td><td>11.7849</td><td>-0.6851</td></tr><tr style='background-color: #cce5ff'><td>mean texture</td><td>18.6000</td><td>17.9027</td><td>-0.6973</td></tr><tr style='background-color: #cce5ff'><td>mean perimeter</td><td>81.0900</td><td>81.5300</td><td>0.4400</td></tr><tr style='background-color: #cce5ff'><td>mean area</td><td>481.9000</td><td>482.3299</td><td>0.4299</td></tr><tr style='background-color: #cce5ff'><td>mean compactness</td><td>0.1058</td><td>0.3049</td><td>0.1991</td></tr><tr style='background-color: #cce5ff'><td>mean concavity</td><td>0.0800</td><td>0.4268</td><td>0.3468</td></tr><tr style='background-color: #cce5ff'><td>mean concave points</td><td>0.0382</td><td>0.2012</td><td>0.1630</td></tr><tr style='background-color: #cce5ff'><td>mean symmetry</td><td>0.1925</td><td>0.3040</td><td>0.1115</td></tr><tr style='background-color: #cce5ff'><td>mean fractal dimension</td><td>0.0637</td><td>0.0852</td><td>0.0215</td></tr><tr style='background-color: #cce5ff'><td>radius error</td><td>0.3961</td><td>0.4844</td><td>0.0883</td></tr><tr style='background-color: #cce5ff'><td>texture error</td><td>1.0440</td><td>0.3602</td><td>-0.6838</td></tr><tr style='background-color: #cce5ff'><td>perimeter error</td><td>2.4970</td><td>2.5338</td><td>0.0368</td></tr><tr style='background-color: #cce5ff'><td>area error</td><td>30.2900</td><td>30.9938</td><td>0.7038</td></tr><tr style='background-color: #cce5ff'><td>smoothness error</td><td>0.0070</td><td>0.0311</td><td>0.0242</td></tr><tr style='background-color: #cce5ff'><td>compactness error</td><td>0.0191</td><td>0.0023</td><td>-0.0169</td></tr><tr style='background-color: #cce5ff'><td>concavity error</td><td>0.0270</td><td>0.0957</td><td>0.0687</td></tr><tr style='background-color: #cce5ff'><td>concave points 
error</td><td>0.0104</td><td>0.0528</td><td>0.0424</td></tr><tr style='background-color: #cce5ff'><td>symmetry error</td><td>0.0178</td><td>0.0615</td><td>0.0436</td></tr><tr style='background-color: #cce5ff'><td>fractal dimension error</td><td>0.0036</td><td>0.0009</td><td>-0.0027</td></tr><tr style='background-color: #cce5ff'><td>worst radius</td><td>14.9700</td><td>14.3635</td><td>-0.6065</td></tr><tr style='background-color: #cce5ff'><td>worst texture</td><td>24.6400</td><td>25.3438</td><td>0.7038</td></tr><tr style='background-color: #cce5ff'><td>worst perimeter</td><td>96.0500</td><td>96.7545</td><td>0.7045</td></tr><tr style='background-color: #cce5ff'><td>worst area</td><td>677.9000</td><td>678.5945</td><td>0.6945</td></tr><tr style='background-color: #cce5ff'><td>worst smoothness</td><td>0.1426</td><td>0.2184</td><td>0.0758</td></tr><tr style='background-color: #cce5ff'><td>worst compactness</td><td>0.2378</td><td>0.9281</td><td>0.6903</td></tr><tr style='background-color: #cce5ff'><td>worst concavity</td><td>0.2671</td><td>0.9700</td><td>0.7029</td></tr><tr style='background-color: #cce5ff'><td>worst concave points</td><td>0.1015</td><td>0.2910</td><td>0.1895</td></tr><tr style='background-color: #cce5ff'><td>worst symmetry</td><td>0.3014</td><td>0.6638</td><td>0.3624</td></tr><tr style='background-color: #cce5ff'><td>worst fractal dimension</td><td>0.0875</td><td>0.1730</td><td>0.0855</td></tr><tr style='background-color: #ffffff'><td>smoothness_Low</td><td>0.0000</td><td>0.0000</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>smoothness_Medium</td><td>1.0000</td><td>1.0000</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>smoothness_High</td><td>0.0000</td><td>0.0000</td><td>0.0000</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/bt4811/anaconda3/envs/ellice/lib/python3.12/site-packages/sklearn/utils/validation.py:2749: UserWarning: X does not have valid feature names, but LogisticRegression was fitted with feature names\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Prediction Information:\n",
      "  Predicted Class: 0\n",
      "  Predicted Probability of Target Class: 0.9417\n",
      "  Robust Probability (Worst Case) of Target Class: 0.5523\n",
      "\n",
      "Categorical Feature Values (Should sum to 1):\n",
      "smoothness_Low       0.0\n",
      "smoothness_Medium    1.0\n",
      "smoothness_High      0.0\n",
      "Name: 0, dtype: float64\n",
      "Sum: 1.0\n"
     ]
    }
   ],
   "source": [
    "# Create modified dataset with categorical feature:\n",
    "# bin `mean smoothness` into three labelled bins, drop the numeric column, one-hot encode.\n",
    "X_cat = X.copy()\n",
    "X_cat['smoothness_cat'] = pd.cut(X_cat['mean smoothness'], bins=3, labels=['Low', 'Medium', 'High'])\n",
    "X_cat = X_cat.drop(columns=['mean smoothness'])\n",
    "X_cat_encoded = pd.get_dummies(X_cat, columns=['smoothness_cat'], prefix='smoothness', dtype=float)\n",
    "\n",
    "# Identify one-hot columns (the trailing underscore keeps 'smoothness error' out of the group)\n",
    "one_hot_cols = [c for c in X_cat_encoded.columns if c.startswith('smoothness_')]\n",
    "one_hot_groups = [one_hot_cols]\n",
    "print(\"One-hot groups:\", one_hot_groups)\n",
    "\n",
    "# Retrain model on new data\n",
    "X_train_c, X_test_c, y_train_c, y_test_c = train_test_split(X_cat_encoded, y, test_size=0.2, random_state=42)\n",
    "clf_c = LogisticRegression(max_iter=5000, solver='lbfgs').fit(X_train_c, y_train_c)\n",
    "print(f\"Categorical Model Accuracy: {clf_c.score(X_test_c, y_test_c):.4f}\")\n",
    "\n",
    "# New Explainer\n",
    "full_df_c = X_train_c.copy()\n",
    "full_df_c['target'] = y_train_c\n",
    "data_c = ellice.Data(dataframe=full_df_c, target_column='target')\n",
    "exp_c = ellice.Explainer(clf_c, data_c, backend='sklearn')\n",
    "\n",
    "# Generate CF with categorical handling\n",
    "query_c = X_test_c.iloc[0]\n",
    "# Compute the target class once. Predicting on a one-row DataFrame (not a list wrapping\n",
    "# the Series) keeps the column names the classifier was fitted with, which avoids\n",
    "# sklearn's \"X does not have valid feature names\" warning for this call.\n",
    "target_class_c = 1 - clf_c.predict(X_test_c.iloc[[0]])[0]\n",
    "cf_cat = exp_c.generate_counterfactuals(\n",
    "    query_instances=query_c,\n",
    "    method='continuous',\n",
    "    target_class=target_class_c,\n",
    "    robustness_epsilon=robustness_epsilon,\n",
    "    regularization_coefficient=regularization_coefficient,\n",
    "    one_hot_groups=one_hot_groups\n",
    ")\n",
    "\n",
    "if not cf_cat.empty:\n",
    "    print(\"\\nCategorical Features Counterfactual:\")\n",
    "    display_query_vs_cf(query_c, cf_cat.iloc[0], data_c.feature_names,\n",
    "                       explainer=exp_c, target_class=target_class_c,\n",
    "                       robustness_epsilon=robustness_epsilon, regularization_coefficient=regularization_coefficient)\n",
    "\n",
    "    # The one-hot group of the counterfactual is printed together with its sum.\n",
    "    print(\"\\nCategorical Feature Values (Should sum to 1):\")\n",
    "    print(cf_cat.iloc[0][one_hot_cols])\n",
    "    print(\"Sum:\", cf_cat.iloc[0][one_hot_cols].sum())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a54dbe82",
   "metadata": {},
   "source": [
    "## 4. Generators <a id=\"generators\"></a>\n",
    "\n",
    "### Continuous Generator with Sparsity\n",
    "Try to minimize the number of features changed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b9d3d645",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Sparse Optimization...\n",
      "Valid CF found with 2 active features (or groups).\n",
      "\n",
      "Sparse Counterfactual (Minimal Features Changed):\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table border='1' style='border-collapse: collapse;'><tr><th>Feature</th><th>Original</th><th>Counterfactual</th><th>Change</th></tr><tr style='background-color: #cce5ff'><td>mean radius</td><td>12.4700</td><td>11.3655</td><td>-1.1045</td></tr><tr style='background-color: #ffffff'><td>mean texture</td><td>18.6000</td><td>18.6000</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>mean perimeter</td><td>81.0900</td><td>81.0900</td><td>-0.0000</td></tr><tr style='background-color: #ffffff'><td>mean area</td><td>481.9000</td><td>481.9000</td><td>-0.0000</td></tr><tr style='background-color: #ffffff'><td>mean smoothness</td><td>0.0997</td><td>0.0997</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>mean compactness</td><td>0.1058</td><td>0.1058</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>mean concavity</td><td>0.0800</td><td>0.0800</td><td>-0.0000</td></tr><tr style='background-color: #ffffff'><td>mean concave points</td><td>0.0382</td><td>0.0382</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>mean symmetry</td><td>0.1925</td><td>0.1925</td><td>-0.0000</td></tr><tr style='background-color: #ffffff'><td>mean fractal dimension</td><td>0.0637</td><td>0.0637</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>radius error</td><td>0.3961</td><td>0.3961</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>texture error</td><td>1.0440</td><td>1.0440</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>perimeter error</td><td>2.4970</td><td>2.4970</td><td>-0.0000</td></tr><tr style='background-color: #ffffff'><td>area error</td><td>30.2900</td><td>30.2900</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>smoothness error</td><td>0.0070</td><td>0.0070</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>compactness error</td><td>0.0191</td><td>0.0191</td><td>-0.0000</td></tr><tr style='background-color: #ffffff'><td>concavity 
error</td><td>0.0270</td><td>0.0270</td><td>-0.0000</td></tr><tr style='background-color: #ffffff'><td>concave points error</td><td>0.0104</td><td>0.0104</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>symmetry error</td><td>0.0178</td><td>0.0178</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>fractal dimension error</td><td>0.0036</td><td>0.0036</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>worst radius</td><td>14.9700</td><td>14.9700</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>worst texture</td><td>24.6400</td><td>24.6400</td><td>-0.0000</td></tr><tr style='background-color: #ffffff'><td>worst perimeter</td><td>96.0500</td><td>96.0500</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>worst area</td><td>677.9000</td><td>677.9000</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>worst smoothness</td><td>0.1426</td><td>0.1426</td><td>-0.0000</td></tr><tr style='background-color: #ffffff'><td>worst compactness</td><td>0.2378</td><td>0.2378</td><td>0.0000</td></tr><tr style='background-color: #cce5ff'><td>worst concavity</td><td>0.2671</td><td>1.2520</td><td>0.9849</td></tr><tr style='background-color: #ffffff'><td>worst concave points</td><td>0.1015</td><td>0.1015</td><td>-0.0000</td></tr><tr style='background-color: #ffffff'><td>worst symmetry</td><td>0.3014</td><td>0.3014</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>worst fractal dimension</td><td>0.0875</td><td>0.0875</td><td>-0.0000</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Prediction Information:\n",
      "  Predicted Class: 0\n",
      "  Predicted Probability of Target Class: 0.9110\n",
      "  Robust Probability (Worst Case) of Target Class: 0.5189\n",
      "\n",
      "Number of features changed: 2\n",
      "Changed features: ['mean radius', 'worst concavity']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/bt4811/anaconda3/envs/ellice/lib/python3.12/site-packages/sklearn/utils/validation.py:2749: UserWarning: X does not have valid feature names, but LogisticRegression was fitted with feature names\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Same continuous generator as before, with the sparsity option switched on.\n",
    "sparse_kwargs = {\n",
    "    'method': 'continuous',\n",
    "    'target_class': target_class,\n",
    "    'sparsity': True,\n",
    "    'robustness_epsilon': robustness_epsilon,\n",
    "    'regularization_coefficient': regularization_coefficient,\n",
    "}\n",
    "cf_sparse = exp.generate_counterfactuals(query_instances=query, **sparse_kwargs)\n",
    "\n",
    "if not cf_sparse.empty:\n",
    "    print(\"\\nSparse Counterfactual (Minimal Features Changed):\")\n",
    "    display_query_vs_cf(\n",
    "        query, cf_sparse.iloc[0], data.feature_names,\n",
    "        explainer=exp,\n",
    "        target_class=target_class,\n",
    "        robustness_epsilon=robustness_epsilon,\n",
    "        regularization_coefficient=regularization_coefficient,\n",
    "    )\n",
    "\n",
    "    # A feature counts as changed when it moved by more than 1e-4 in absolute value.\n",
    "    diff = cf_sparse.iloc[0][data.feature_names] - query\n",
    "    moved = diff.abs() > 1e-4\n",
    "    changed = diff[moved]\n",
    "    print(f\"\\nNumber of features changed: {len(changed)}\")\n",
    "    print(\"Changed features:\", changed.index.tolist())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With frozen features: repeat the sparse search while keeping `mean radius` immutable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Sparse Optimization...\n",
      "Valid CF found with 6 active features (or groups).\n",
      "\n",
      "Sparse Counterfactual (Minimal Features Changed):\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table border='1' style='border-collapse: collapse;'><tr><th>Feature</th><th>Original</th><th>Counterfactual</th><th>Change</th></tr><tr style='background-color: #ffffff'><td>mean radius</td><td>12.4700</td><td>12.4700</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>mean texture</td><td>18.6000</td><td>18.6000</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>mean perimeter</td><td>81.0900</td><td>81.0900</td><td>-0.0000</td></tr><tr style='background-color: #ffffff'><td>mean area</td><td>481.9000</td><td>481.9000</td><td>-0.0000</td></tr><tr style='background-color: #ffffff'><td>mean smoothness</td><td>0.0997</td><td>0.0997</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>mean compactness</td><td>0.1058</td><td>0.1058</td><td>0.0000</td></tr><tr style='background-color: #cce5ff'><td>mean concavity</td><td>0.0800</td><td>0.4268</td><td>0.3468</td></tr><tr style='background-color: #ffffff'><td>mean concave points</td><td>0.0382</td><td>0.0382</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>mean symmetry</td><td>0.1925</td><td>0.1925</td><td>-0.0000</td></tr><tr style='background-color: #ffffff'><td>mean fractal dimension</td><td>0.0637</td><td>0.0637</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>radius error</td><td>0.3961</td><td>0.3961</td><td>0.0000</td></tr><tr style='background-color: #cce5ff'><td>texture error</td><td>1.0440</td><td>0.3602</td><td>-0.6838</td></tr><tr style='background-color: #ffffff'><td>perimeter error</td><td>2.4970</td><td>2.4970</td><td>-0.0000</td></tr><tr style='background-color: #ffffff'><td>area error</td><td>30.2900</td><td>30.2900</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>smoothness error</td><td>0.0070</td><td>0.0070</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>compactness error</td><td>0.0191</td><td>0.0191</td><td>-0.0000</td></tr><tr style='background-color: #ffffff'><td>concavity 
error</td><td>0.0270</td><td>0.0270</td><td>-0.0000</td></tr><tr style='background-color: #ffffff'><td>concave points error</td><td>0.0104</td><td>0.0104</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>symmetry error</td><td>0.0178</td><td>0.0178</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>fractal dimension error</td><td>0.0036</td><td>0.0036</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>worst radius</td><td>14.9700</td><td>14.9700</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>worst texture</td><td>24.6400</td><td>24.6400</td><td>-0.0000</td></tr><tr style='background-color: #ffffff'><td>worst perimeter</td><td>96.0500</td><td>96.0500</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>worst area</td><td>677.9000</td><td>677.9000</td><td>0.0000</td></tr><tr style='background-color: #ffffff'><td>worst smoothness</td><td>0.1426</td><td>0.1426</td><td>-0.0000</td></tr><tr style='background-color: #cce5ff'><td>worst compactness</td><td>0.2378</td><td>0.9379</td><td>0.7001</td></tr><tr style='background-color: #cce5ff'><td>worst concavity</td><td>0.2671</td><td>1.1689</td><td>0.9018</td></tr><tr style='background-color: #cce5ff'><td>worst concave points</td><td>0.1015</td><td>0.2910</td><td>0.1895</td></tr><tr style='background-color: #cce5ff'><td>worst symmetry</td><td>0.3014</td><td>0.6638</td><td>0.3624</td></tr><tr style='background-color: #ffffff'><td>worst fractal dimension</td><td>0.0875</td><td>0.0875</td><td>-0.0000</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Prediction Information:\n",
      "  Predicted Class: 0\n",
      "  Predicted Probability of Target Class: 0.9073\n",
      "  Robust Probability (Worst Case) of Target Class: 0.5073\n",
      "\n",
      "Number of features changed: 6\n",
      "Changed features: ['mean concavity', 'texture error', 'worst compactness', 'worst concavity', 'worst concave points', 'worst symmetry']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/bt4811/anaconda3/envs/ellice/lib/python3.12/site-packages/sklearn/utils/validation.py:2749: UserWarning: X does not have valid feature names, but LogisticRegression was fitted with feature names\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Keep `mean radius` fixed; all remaining columns are passed as features_to_vary.\n",
    "feature_to_freeze = ['mean radius']\n",
    "features_to_vary = [name for name in X.columns if name not in feature_to_freeze]\n",
    "\n",
    "# Sparse search restricted to the features that are allowed to vary.\n",
    "sparse_frozen_kwargs = {\n",
    "    'method': 'continuous',\n",
    "    'target_class': target_class,\n",
    "    'sparsity': True,\n",
    "    'robustness_epsilon': robustness_epsilon,\n",
    "    'regularization_coefficient': regularization_coefficient,\n",
    "    'features_to_vary': features_to_vary,\n",
    "}\n",
    "cf_sparse = exp.generate_counterfactuals(query_instances=query, **sparse_frozen_kwargs)\n",
    "\n",
    "if not cf_sparse.empty:\n",
    "    print(\"\\nSparse Counterfactual (Minimal Features Changed):\")\n",
    "    display_query_vs_cf(\n",
    "        query, cf_sparse.iloc[0], data.feature_names,\n",
    "        explainer=exp,\n",
    "        target_class=target_class,\n",
    "        robustness_epsilon=robustness_epsilon,\n",
    "        regularization_coefficient=regularization_coefficient,\n",
    "    )\n",
    "\n",
    "    # A feature counts as changed when it moved by more than 1e-4 in absolute value.\n",
    "    diff = cf_sparse.iloc[0][data.feature_names] - query\n",
    "    moved = diff.abs() > 1e-4\n",
    "    changed = diff[moved]\n",
    "    print(f\"\\nNumber of features changed: {len(changed)}\")\n",
    "    print(\"Changed features:\", changed.index.tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9faa4b59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# cf_sparse.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ece63eed",
   "metadata": {},
   "source": [
    "### Data-Supported Generator\n",
    "Finds a counterfactual from the actual training data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "dee3f517",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Data-Supported Counterfactual:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table border='1' style='border-collapse: collapse;'><tr><th>Feature</th><th>Original</th><th>Counterfactual</th><th>Change</th></tr><tr style='background-color: #cce5ff'><td>mean radius</td><td>12.4700</td><td>13.0000</td><td>0.5300</td></tr><tr style='background-color: #cce5ff'><td>mean texture</td><td>18.6000</td><td>21.8200</td><td>3.2200</td></tr><tr style='background-color: #cce5ff'><td>mean perimeter</td><td>81.0900</td><td>87.5000</td><td>6.4100</td></tr><tr style='background-color: #cce5ff'><td>mean area</td><td>481.9000</td><td>519.8000</td><td>37.9000</td></tr><tr style='background-color: #cce5ff'><td>mean smoothness</td><td>0.0997</td><td>0.1273</td><td>0.0276</td></tr><tr style='background-color: #cce5ff'><td>mean compactness</td><td>0.1058</td><td>0.1932</td><td>0.0874</td></tr><tr style='background-color: #cce5ff'><td>mean concavity</td><td>0.0800</td><td>0.1859</td><td>0.1059</td></tr><tr style='background-color: #cce5ff'><td>mean concave points</td><td>0.0382</td><td>0.0935</td><td>0.0553</td></tr><tr style='background-color: #cce5ff'><td>mean symmetry</td><td>0.1925</td><td>0.2350</td><td>0.0425</td></tr><tr style='background-color: #cce5ff'><td>mean fractal dimension</td><td>0.0637</td><td>0.0739</td><td>0.0102</td></tr><tr style='background-color: #cce5ff'><td>radius error</td><td>0.3961</td><td>0.3063</td><td>-0.0898</td></tr><tr style='background-color: #cce5ff'><td>texture error</td><td>1.0440</td><td>1.0020</td><td>-0.0420</td></tr><tr style='background-color: #cce5ff'><td>perimeter error</td><td>2.4970</td><td>2.4060</td><td>-0.0910</td></tr><tr style='background-color: #cce5ff'><td>area error</td><td>30.2900</td><td>24.3200</td><td>-5.9700</td></tr><tr style='background-color: #cce5ff'><td>smoothness error</td><td>0.0070</td><td>0.0057</td><td>-0.0012</td></tr><tr style='background-color: #cce5ff'><td>compactness error</td><td>0.0191</td><td>0.0350</td><td>0.0159</td></tr><tr style='background-color: #cce5ff'><td>concavity 
error</td><td>0.0270</td><td>0.0355</td><td>0.0085</td></tr><tr style='background-color: #cce5ff'><td>concave points error</td><td>0.0104</td><td>0.0123</td><td>0.0019</td></tr><tr style='background-color: #cce5ff'><td>symmetry error</td><td>0.0178</td><td>0.0214</td><td>0.0036</td></tr><tr style='background-color: #cce5ff'><td>fractal dimension error</td><td>0.0036</td><td>0.0037</td><td>0.0002</td></tr><tr style='background-color: #cce5ff'><td>worst radius</td><td>14.9700</td><td>15.4900</td><td>0.5200</td></tr><tr style='background-color: #cce5ff'><td>worst texture</td><td>24.6400</td><td>30.7300</td><td>6.0900</td></tr><tr style='background-color: #cce5ff'><td>worst perimeter</td><td>96.0500</td><td>106.2000</td><td>10.1500</td></tr><tr style='background-color: #cce5ff'><td>worst area</td><td>677.9000</td><td>739.3000</td><td>61.4000</td></tr><tr style='background-color: #cce5ff'><td>worst smoothness</td><td>0.1426</td><td>0.1703</td><td>0.0277</td></tr><tr style='background-color: #cce5ff'><td>worst compactness</td><td>0.2378</td><td>0.5401</td><td>0.3023</td></tr><tr style='background-color: #cce5ff'><td>worst concavity</td><td>0.2671</td><td>0.5390</td><td>0.2719</td></tr><tr style='background-color: #cce5ff'><td>worst concave points</td><td>0.1015</td><td>0.2060</td><td>0.1045</td></tr><tr style='background-color: #cce5ff'><td>worst symmetry</td><td>0.3014</td><td>0.4378</td><td>0.1364</td></tr><tr style='background-color: #cce5ff'><td>worst fractal dimension</td><td>0.0875</td><td>0.1072</td><td>0.0197</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Prediction Information:\n",
      "  Predicted Class: 0\n",
      "  Predicted Probability of Target Class: 0.9076\n",
      "  Robust Probability (Worst Case) of Target Class: 0.6839\n",
      "\n",
      "Is this a real data point? True\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/bt4811/anaconda3/envs/ellice/lib/python3.12/site-packages/sklearn/utils/validation.py:2749: UserWarning: X does not have valid feature names, but LogisticRegression was fitted with feature names\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "cf_data = exp.generate_counterfactuals(\n",
    "    query_instances=query,\n",
    "    method='data_supported',\n",
    "    search_mode='kdtree',  # Fast search\n",
    "    target_class=target_class\n",
    ")\n",
    "\n",
    "if not cf_data.empty:\n",
    "    print(\"\\nData-Supported Counterfactual:\")\n",
    "    display_query_vs_cf(query, cf_data.iloc[0], data.feature_names,\n",
    "                      explainer=exp, target_class=target_class)\n",
    "    \n",
    "    # Verify it's a real point (check index or exact match)\n",
    "    is_real = (X_train == cf_data.iloc[0][data.feature_names]).all(axis=1).any()\n",
    "    print(f\"\\nIs this a real data point? {is_real}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88963855",
   "metadata": {},
   "source": [
    "## 5. Custom Backend (PyTorch) <a id=\"custom-backend\"></a>\n",
    "\n",
    "Defining a custom model wrapper for a PyTorch model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "fed8e27f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training PyTorch model...\n",
      "Epoch 20/100, Loss: 1.2975, Test Accuracy: 0.6491\n",
      "Epoch 40/100, Loss: 0.5844, Test Accuracy: 0.8158\n",
      "Epoch 60/100, Loss: 0.3606, Test Accuracy: 0.9035\n",
      "Epoch 80/100, Loss: 0.3439, Test Accuracy: 0.9123\n",
      "Epoch 100/100, Loss: 0.3224, Test Accuracy: 0.9123\n",
      "Training complete! Final Accuracy: 0.9123\n",
      "Wrapper initialized successfully!\n",
      "Progress Bar Enabled\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating CF:  10%|▉         | 97/1000 [00:00<00:00, 915.76it/s, Prob=0.703, RobLogit=0.014, BestRobLogit=0.014]   "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Custom PyTorch Backend Counterfactual:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table border='1' style='border-collapse: collapse;'><tr><th>Feature</th><th>Original</th><th>Counterfactual</th><th>Change</th></tr><tr style='background-color: #cce5ff'><td>mean radius</td><td>12.4700</td><td>12.2708</td><td>-0.1992</td></tr><tr style='background-color: #cce5ff'><td>mean texture</td><td>18.6000</td><td>10.0003</td><td>-8.5997</td></tr><tr style='background-color: #cce5ff'><td>mean perimeter</td><td>81.0900</td><td>72.2290</td><td>-8.8610</td></tr><tr style='background-color: #cce5ff'><td>mean area</td><td>481.9000</td><td>490.9725</td><td>9.0725</td></tr><tr style='background-color: #cce5ff'><td>mean smoothness</td><td>0.0997</td><td>0.0526</td><td>-0.0470</td></tr><tr style='background-color: #cce5ff'><td>mean compactness</td><td>0.1058</td><td>0.0194</td><td>-0.0864</td></tr><tr style='background-color: #cce5ff'><td>mean concavity</td><td>0.0800</td><td>0.4268</td><td>0.3468</td></tr><tr style='background-color: #cce5ff'><td>mean concave points</td><td>0.0382</td><td>0.2012</td><td>0.1630</td></tr><tr style='background-color: #cce5ff'><td>mean symmetry</td><td>0.1925</td><td>0.1167</td><td>-0.0758</td></tr><tr style='background-color: #cce5ff'><td>mean fractal dimension</td><td>0.0637</td><td>0.0974</td><td>0.0337</td></tr><tr style='background-color: #cce5ff'><td>radius error</td><td>0.3961</td><td>2.8730</td><td>2.4769</td></tr><tr style='background-color: #cce5ff'><td>texture error</td><td>1.0440</td><td>4.8850</td><td>3.8410</td></tr><tr style='background-color: #cce5ff'><td>perimeter error</td><td>2.4970</td><td>0.7570</td><td>-1.7400</td></tr><tr style='background-color: #cce5ff'><td>area error</td><td>30.2900</td><td>23.9754</td><td>-6.3146</td></tr><tr style='background-color: #cce5ff'><td>smoothness error</td><td>0.0070</td><td>0.0017</td><td>-0.0052</td></tr><tr style='background-color: #cce5ff'><td>compactness error</td><td>0.0191</td><td>0.1354</td><td>0.1163</td></tr><tr style='background-color: #cce5ff'><td>concavity 
error</td><td>0.0270</td><td>0.0000</td><td>-0.0270</td></tr><tr style='background-color: #cce5ff'><td>concave points error</td><td>0.0104</td><td>0.0528</td><td>0.0424</td></tr><tr style='background-color: #cce5ff'><td>symmetry error</td><td>0.0178</td><td>0.0615</td><td>0.0436</td></tr><tr style='background-color: #cce5ff'><td>fractal dimension error</td><td>0.0036</td><td>0.0298</td><td>0.0263</td></tr><tr style='background-color: #cce5ff'><td>worst radius</td><td>14.9700</td><td>12.4191</td><td>-2.5509</td></tr><tr style='background-color: #cce5ff'><td>worst texture</td><td>24.6400</td><td>20.0218</td><td>-4.6182</td></tr><tr style='background-color: #cce5ff'><td>worst perimeter</td><td>96.0500</td><td>86.9703</td><td>-9.0797</td></tr><tr style='background-color: #cce5ff'><td>worst area</td><td>677.9000</td><td>686.7802</td><td>8.8802</td></tr><tr style='background-color: #cce5ff'><td>worst smoothness</td><td>0.1426</td><td>0.2184</td><td>0.0758</td></tr><tr style='background-color: #cce5ff'><td>worst compactness</td><td>0.2378</td><td>0.0273</td><td>-0.2105</td></tr><tr style='background-color: #cce5ff'><td>worst concavity</td><td>0.2671</td><td>0.0000</td><td>-0.2671</td></tr><tr style='background-color: #cce5ff'><td>worst concave points</td><td>0.1015</td><td>0.0000</td><td>-0.1015</td></tr><tr style='background-color: #cce5ff'><td>worst symmetry</td><td>0.3014</td><td>0.1565</td><td>-0.1449</td></tr><tr style='background-color: #cce5ff'><td>worst fractal dimension</td><td>0.0875</td><td>0.0550</td><td>-0.0325</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Prediction Information:\n",
      "  Predicted Class: 0\n",
      "  Predicted Probability of Target Class: 0.7029\n",
      "  Robust Probability (Worst Case) of Target Class: 0.5035\n"
     ]
    }
   ],
   "source": [
    "from ellice.models.wrappers import ModelWrapper\n",
    "from typing import Tuple\n",
    "\n",
    "# 1. Define PyTorch Model\n",
    "class SimpleNN(nn.Module):\n",
    "    def __init__(self, input_dim):\n",
    "        super().__init__()\n",
    "        self.layer1 = nn.Linear(input_dim, 16)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.layer2 = nn.Linear(16, 1) # Binary output (logits)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.layer1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.layer2(x)\n",
    "        return x\n",
    "\n",
    "# 2. Create Wrapper\n",
    "class MyModelWrapper(ModelWrapper):\n",
    "    def __init__(self, model):\n",
    "        super().__init__(model, backend='custom')\n",
    "        self.model.eval()\n",
    "        \n",
    "    def get_torch_model(self) -> nn.Module:\n",
    "        return self.model\n",
    "        \n",
    "    def split_model(self) -> Tuple[nn.Module, torch.Tensor]:\n",
    "        # Split into penultimate features (layer1+relu) and last layer (layer2)\n",
    "        penult = nn.Sequential(self.model.layer1, self.model.relu)\n",
    "        \n",
    "        # Get last layer params [weights, bias]\n",
    "        last = self.model.layer2\n",
    "        theta = torch.cat([last.weight.detach().view(-1), last.bias.detach()])\n",
    "        return penult, theta\n",
    "        \n",
    "    def predict_proba(self, X: np.ndarray) -> np.ndarray:\n",
    "        device = next(self.model.parameters()).device\n",
    "        X_t = torch.from_numpy(X).float().to(device)\n",
    "        with torch.no_grad():\n",
    "            logits = self.model(X_t)\n",
    "            probs_1 = torch.sigmoid(logits)\n",
    "            probs_0 = 1 - probs_1\n",
    "            return torch.cat([probs_0, probs_1], dim=1).cpu().numpy()\n",
    "\n",
    "# 3. Setup and Train\n",
    "torch_model = SimpleNN(input_dim=X_train.shape[1])\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "torch_model = torch_model.to(device)\n",
    "\n",
    "# Convert data to tensors\n",
    "X_train_t = torch.FloatTensor(X_train.values).to(device)\n",
    "y_train_t = torch.FloatTensor(y_train.values).unsqueeze(1).to(device)\n",
    "X_test_t = torch.FloatTensor(X_test.values).to(device)\n",
    "y_test_t = torch.FloatTensor(y_test.values).unsqueeze(1).to(device)\n",
    "\n",
    "# Training loop\n",
    "criterion = nn.BCEWithLogitsLoss()\n",
    "optimizer = optim.Adam(torch_model.parameters(), lr=0.001)\n",
    "epochs = 100\n",
    "\n",
    "print(\"Training PyTorch model...\")\n",
    "torch_model.train()\n",
    "for epoch in range(epochs):\n",
    "    optimizer.zero_grad()\n",
    "    outputs = torch_model(X_train_t)\n",
    "    loss = criterion(outputs, y_train_t)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    \n",
    "    if (epoch + 1) % 20 == 0:\n",
    "        with torch.no_grad():\n",
    "            torch_model.eval()\n",
    "            test_outputs = torch_model(X_test_t)\n",
    "            test_preds = (torch.sigmoid(test_outputs) > 0.5).float()\n",
    "            accuracy = (test_preds == y_test_t).float().mean().item()\n",
    "            print(f\"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}, Test Accuracy: {accuracy:.4f}\")\n",
    "            torch_model.train()\n",
    "\n",
    "torch_model.eval()\n",
    "print(f\"Training complete! Final Accuracy: {accuracy:.4f}\")\n",
    "\n",
    "# 4. Use with ElliCE\n",
    "exp_torch = ellice.Explainer(\n",
    "    model=torch_model,\n",
    "    data=data,\n",
    "    backend='custom',\n",
    "    backend_model_class=MyModelWrapper,\n",
    "    device='auto'\n",
    ")\n",
    "\n",
    "print(\"Wrapper initialized successfully!\")\n",
    "\n",
    "# Generate CF using the custom PyTorch backend\n",
    "cf_torch = exp_torch.generate_counterfactuals(\n",
    "    query_instances=query,\n",
    "    method='continuous',\n",
    "    target_class=target_class,\n",
    "    robustness_epsilon=robustness_epsilon,\n",
    "    regularization_coefficient=regularization_coefficient\n",
    ")\n",
    "\n",
    "if not cf_torch.empty:\n",
    "    print(\"\\nCustom PyTorch Backend Counterfactual:\")\n",
    "    display_query_vs_cf(query, cf_torch.iloc[0], data.feature_names,\n",
    "                       explainer=exp_torch, target_class=target_class,\n",
    "                       robustness_epsilon=robustness_epsilon, regularization_coefficient=regularization_coefficient)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ellice",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
