{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "75ea33e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==========================================\n",
    "# Imports\n",
    "# ==========================================\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OrdinalEncoder, LabelEncoder\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.metrics import classification_report, accuracy_score\n",
    "from typing import Tuple, List, Any\n",
    "from anova_module import FullSupportAnova\n",
    "from anova_module import batch_shapley_values\n",
    "import shap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c78cdc09",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://archive.ics.uci.edu/ml/machine-learning-databases/car/car.data...\n",
      "Data loaded successfully. Shape: (1728, 7)\n",
      "Training model...\n",
      "\n",
      "========================================\n",
      "RESULTS (Accuracy: 0.9672)\n",
      "========================================\n",
      "\n",
      "--- Classification Report ---\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "         acc       0.92      0.93      0.93       115\n",
      "        good       0.90      0.90      0.90        21\n",
      "       unacc       0.98      0.98      0.98       363\n",
      "       vgood       1.00      0.95      0.97        20\n",
      "\n",
      "    accuracy                           0.97       519\n",
      "   macro avg       0.95      0.94      0.95       519\n",
      "weighted avg       0.97      0.97      0.97       519\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# ==========================================\n",
    "# Data Loading & Preprocessing\n",
    "# ==========================================\n",
    "\n",
    "def load_and_preprocess_data() -> Tuple[np.ndarray, np.ndarray, OrdinalEncoder, LabelEncoder]:\n",
    "    \"\"\"\n",
    "    Downloads the Car Evaluation dataset from UCI, performs ordinal encoding\n",
    "    on categorical features, and returns processed NumPy arrays.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    X_encoded : np.ndarray\n",
    "        The feature matrix with categorical variables encoded as integers.\n",
    "    y_encoded : np.ndarray\n",
    "        The target vector encoded as integers.\n",
    "    feature_encoder : OrdinalEncoder\n",
    "        The fitted encoder for the features (useful for inverse transformation).\n",
    "    target_encoder : LabelEncoder\n",
    "        The fitted encoder for the target labels.\n",
    "    \"\"\"\n",
    "    # Direct URL to the raw dataset on the UCI repository\n",
    "    url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/car/car.data\"\n",
    "    \n",
    "    # Define column names (the .data file lacks a header)\n",
    "    column_names = ['buying', 'maint', 'doors', 'persons', 'lug_boot', 'safety', 'class']\n",
    "    \n",
    "    print(f\"Downloading data from {url}...\")\n",
    "    df = pd.read_csv(url, names=column_names)\n",
    "    \n",
    "    print(f\"Data loaded successfully. Shape: {df.shape}\")\n",
    "    \n",
    "    # Separate Features (X) and Target (y)\n",
    "    X_raw = df.iloc[:, :-1].values\n",
    "    y_raw = df.iloc[:, -1].values\n",
    "    \n",
    "    # --- Encoding ---\n",
    "    # The raw data consists of strings (e.g., \"vhigh\", \"small\").\n",
    "    # We must convert these to numerical representations for the Random Forest.\n",
    "    \n",
    "    # 1. Feature Encoding\n",
    "    # OrdinalEncoder transforms each categorical feature into integers (0, 1, 2...)\n",
    "    feature_encoder = OrdinalEncoder()\n",
    "    X_encoded = feature_encoder.fit_transform(X_raw)\n",
    "    \n",
    "    # 2. Target Encoding\n",
    "    # LabelEncoder transforms class labels (\"unacc\", \"acc\", etc.) into integers.\n",
    "    target_encoder = LabelEncoder()\n",
    "    y_encoded = target_encoder.fit_transform(y_raw)\n",
    "    \n",
    "    return X_encoded, y_encoded, feature_encoder, target_encoder\n",
    "\n",
    "# ==========================================\n",
    "# Random Forest Training\n",
    "# ==========================================\n",
    "\n",
    "# 1. Data Loading & Preprocessing\n",
    "# We assume load_and_preprocess_data() is defined as in the previous step\n",
    "X, y, enc_X, enc_y = load_and_preprocess_data()\n",
    "\n",
    "# Extract class names for legible reporting (e.g., 'unacc', 'vgood')\n",
    "target_classes = enc_y.classes_ \n",
    "\n",
    "# 2. Experimental Setup: Stratified Train/Test Split\n",
    "# We use a 70/30 split. Stratification is crucial here due to class imbalance.\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X, y, test_size=0.3, random_state=42, stratify=y\n",
    ")\n",
    "\n",
    "# 3. Model Initialization: Random Forest\n",
    "# Random Forests are robust baselines for tabular data with ordinal features.\n",
    "clf = RandomForestClassifier(n_estimators=100, random_state=42)\n",
    "\n",
    "print(\"Training model...\")\n",
    "clf.fit(X_train, y_train)\n",
    "\n",
    "# 4. Inference\n",
    "y_pred = clf.predict(X_test)\n",
    "\n",
    "# 5. Performance Reporting\n",
    "acc = accuracy_score(y_test, y_pred)\n",
    "\n",
    "print(\"\\n\" + \"=\"*40)\n",
    "print(f\"RESULTS (Accuracy: {acc:.4f})\")\n",
    "print(\"=\"*40)\n",
    "\n",
    "# Detailed classification report\n",
    "print(\"\\n--- Classification Report ---\")\n",
    "print(classification_report(y_test, y_pred, target_names=target_classes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9fddcc7c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 421 ms, sys: 28.2 ms, total: 449 ms\n",
      "Wall time: 450 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# ==========================================\n",
    "# Functional ANOVA Decomposition\n",
    "# ==========================================\n",
    "\n",
    "d = X.shape[1] # dimension\n",
    "N = np.array([ X[: , j].max() + 1 for j in range(d) ]) # list of categories\n",
    "r = int(np.prod(N)) # full dimension\n",
    "P = 1/r * np.ones( r ) # vector of probabilities\n",
    "\n",
    "def f_1(x): # class 1\n",
    "    return(clf.predict_proba(x)[:,0])\n",
    "\n",
    "def f_2(x): # class 2\n",
    "    return(clf.predict_proba(x)[:,1])\n",
    "\n",
    "def f_3(x): # class 3\n",
    "    return(clf.predict_proba(x)[:,2])\n",
    "\n",
    "def f_4(x): # class 4\n",
    "    return(clf.predict_proba(x)[:,3])\n",
    "\n",
    "F = [f_1 , f_2 , f_3 , f_4] # list of functions for each class\n",
    "\n",
    "\n",
    "anova_shap = [] # list of generalized shapley values based on functional anova\n",
    "for f_model in F:\n",
    "    A = FullSupportAnova(N , P , f_model)\n",
    "    S , Matrix = A.get_anova_full() # sets and f_A(X_A)\n",
    "    shap_i = batch_shapley_values(d , S , Matrix) # generalized shapley values matrix for all obs\n",
    "    anova_shap.append(shap_i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "837f611d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==========================================\n",
    "# KernelSHAP\n",
    "# ==========================================\n",
    "\n",
    "n_sample_background = 200\n",
    "background = X[:n_sample_background]\n",
    "\n",
    "# KernelSHAP\n",
    "def kernel_shap(f , X_explain):\n",
    "    explainer = shap.KernelExplainer(f , background)\n",
    "    shap_values = explainer.shap_values(X_explain)\n",
    "    return(shap_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "734f8cd7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using 200 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a29d19095b2a40e88371b1b61ad1db26",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1727 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using 200 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f86e41d051c3433a9f76803743efded9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1727 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using 200 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e823ecd94cbb411298accb2807682379",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1727 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using 200 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "22ef4a31d0e5407796cc524f7555f974",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1727 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1min 39s, sys: 2.45 s, total: 1min 42s\n",
      "Wall time: 1min 42s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# ==========================================\n",
    "# KernelSHAP\n",
    "# ==========================================\n",
    "\n",
    "number = r - 1\n",
    "X_explain = A._generate_tuples()[:number]\n",
    "\n",
    "all_kernel_shap = []\n",
    "for g in F:\n",
    "    kernel_shap_g = kernel_shap(g , X_explain)\n",
    "    all_kernel_shap.append(kernel_shap_g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90235669",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[2.01123261e-02, 1.29533867e-02, 9.62057670e-05, 2.75037594e-03,\n",
       "        2.50105313e-04, 2.98341343e-03],\n",
       "       [1.58470991e-03, 1.46948070e-03, 2.69347934e-05, 4.09998114e-04,\n",
       "        1.18857332e-04, 5.11330245e-04],\n",
       "       [2.89357167e-02, 1.45756384e-02, 1.70867649e-04, 7.08700651e-03,\n",
       "        4.56296966e-04, 7.07322515e-03],\n",
       "       [2.07139881e-03, 8.71936191e-04, 6.46691690e-05, 3.68443860e-04,\n",
       "        3.97373917e-04, 1.20824268e-03]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ==========================================\n",
    "# Table of ISE\n",
    "# ==========================================\n",
    "\n",
    "P_red = 1/number * np.ones(number)\n",
    "\n",
    "np.array([np.sum(((anova_shap[i][:number , :] - all_kernel_shap[i])**2).T * P_red , axis=1) for i in range(4)])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hfd_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
