{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "46a74b66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==========================================\n",
    "# Imports\n",
    "# ==========================================\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score\n",
    "from anova_module import FullSupportAnova, batch_shapley_values\n",
    "import shap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4592e872",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset downloaded.\n",
      "X matrix shape: torch.Size([12960, 8])\n",
      "\n",
      "Starting training...\n",
      "Epoch 10/30 - Loss: 0.0672\n",
      "Epoch 20/30 - Loss: 0.0319\n",
      "Epoch 30/30 - Loss: 0.0204\n",
      "\n",
      "Test set accuracy: 99.73%\n"
     ]
    }
   ],
   "source": [
    "# ==========================================\n",
    "# 1. Data Acquisition and Preprocessing\n",
    "# ==========================================\n",
    "url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/nursery/nursery.data\"\n",
    "\n",
    "# The source file has no header row, so the feature/target names are supplied explicitly.\n",
    "COLUMN_NAMES = [\"parents\", \"has_nurs\", \"form\", \"children\", \"housing\",\n",
    "                \"finance\", \"social\", \"health\", \"class\"]\n",
    "df = pd.read_csv(url, header=None, names=COLUMN_NAMES)\n",
    "\n",
    "# Split features (X) and target (y)\n",
    "X_raw = df.iloc[:, :-1]  # First 8 columns (features)\n",
    "y_raw = df.iloc[:, -1]   # Last column (target class)\n",
    "\n",
    "# Encode each categorical column to integer codes 0..k-1.\n",
    "# LabelEncoder (instead of OneHot) keeps the specific (N, 8) input dimensionality constraint.\n",
    "# NOTE: LabelEncoder numbers categories in sorted (alphabetical) order of their strings,\n",
    "# not in the natural order of the categories. One encoder is fitted per column and kept in\n",
    "# `feature_encoders`, so codes can be mapped back, e.g. feature_encoders[\"health\"].classes_.\n",
    "feature_encoders = {col: LabelEncoder().fit(X_raw[col]) for col in X_raw.columns}\n",
    "X_encoded = X_raw.apply(lambda col: feature_encoders[col.name].transform(col))\n",
    "\n",
    "# Separate encoder for the target: le.classes_[k] is the label of class index k.\n",
    "le = LabelEncoder()\n",
    "y_encoded = le.fit_transform(y_raw)\n",
    "\n",
    "# Convert to PyTorch tensors\n",
    "X = torch.tensor(X_encoded.values, dtype=torch.float32) # nn.Linear requires Float32 input\n",
    "y = torch.tensor(y_encoded, dtype=torch.long)           # CrossEntropyLoss requires Long (int64) targets\n",
    "\n",
    "print(\"Dataset downloaded.\")\n",
    "print(f\"X matrix shape: {X.shape}\")  # Should display torch.Size([12960, 8])\n",
    "\n",
    "# Train/Test Split (80% train, 20% test)\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
    "\n",
    "# ==========================================\n",
    "# 2. PyTorch MLP Architecture Definition\n",
    "# ==========================================\n",
    "class NurseryMLP(nn.Module):\n",
    "    \"\"\"Fully connected classifier: input_dim -> 64 -> 32 -> output_dim, returning raw logits.\"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        super().__init__()\n",
    "        self.network = nn.Sequential(\n",
    "            nn.Linear(input_dim, 64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64, 32),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(32, output_dim)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        # No final activation: nn.CrossEntropyLoss applies log-softmax internally.\n",
    "        return self.network(x)\n",
    "\n",
    "# Seed PyTorch's global RNG so that weight initialisation and DataLoader shuffling\n",
    "# are reproducible on Restart & Run All (same seed as the train/test split).\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Model Initialization\n",
    "input_dim = X.shape[1]           # 8 input features\n",
    "output_dim = len(set(y_encoded)) # 5 output classes\n",
    "model = NurseryMLP(input_dim, output_dim)\n",
    "\n",
    "# Loss and Optimizer configuration\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-2)\n",
    "\n",
    "# ==========================================\n",
    "# 3. Training Loop\n",
    "# ==========================================\n",
    "epochs = 30\n",
    "batch_size = 64\n",
    "train_dataset = torch.utils.data.TensorDataset(X_train, y_train)\n",
    "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "n_batches = len(train_loader)  # mini-batches per epoch, used to average the logged loss\n",
    "\n",
    "print(\"\\nStarting training...\")\n",
    "for epoch in range(epochs):\n",
    "    model.train()\n",
    "    epoch_loss = 0.0\n",
    "    for batch_x, batch_y in train_loader:\n",
    "        optimizer.zero_grad()\n",
    "        batch_loss = criterion(model(batch_x), batch_y)\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()\n",
    "        epoch_loss += batch_loss.item()\n",
    "\n",
    "    # Log the mean mini-batch loss every 10 epochs\n",
    "    if (epoch + 1) % 10 == 0:\n",
    "        print(f\"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss / n_batches:.4f}\")\n",
    "\n",
    "# ==========================================\n",
    "# 4. Evaluation (Accuracy)\n",
    "# ==========================================\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    outputs = model(X_test)\n",
    "    predicted = outputs.argmax(dim=1)  # index of the largest logit for each test row\n",
    "accuracy = accuracy_score(y_test, predicted)\n",
    "\n",
    "print(f\"\\nTest set accuracy: {accuracy*100:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d8dba6f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2min 37s, sys: 4.94 s, total: 2min 42s\n",
      "Wall time: 53.7 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# ==========================================\n",
    "# Functional ANOVA Decomposition\n",
    "# ==========================================\n",
    "\n",
    "X = X_encoded.to_numpy()   # encoded dataset as a NumPy array (rebinds the training tensor X)\n",
    "d = X.shape[1]             # number of features\n",
    "N = [col_max + 1 for col_max in X.max(axis=0)]  # number of categories of each feature\n",
    "r = np.prod(N)             # size of the full product space of categories\n",
    "P = np.full(r, 1 / r)      # uniform probability on every category combination\n",
    "\n",
    "def f(i, X_numpy):\n",
    "    \"\"\"Return P(MLP(x) = class i | x) for every row x of the 2-D array X_numpy, as a 1-D NumPy array.\"\"\"\n",
    "    batch = torch.tensor(X_numpy, dtype=torch.float32)\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        class_probs = torch.softmax(model(batch), dim=1)  # one probability row per observation\n",
    "        column_i = class_probs[:, i]\n",
    "    return column_i.cpu().numpy()\n",
    "\n",
    "def make_class_prob(i):\n",
    "    \"\"\"Return the one-argument function x -> P(MLP(x) = class i | x), built on f.\"\"\"\n",
    "    def f_i(x):\n",
    "        return f(i, x)\n",
    "    f_i.__name__ = f\"f_{i}\"  # readable name, e.g. f_0 for class 0\n",
    "    return f_i\n",
    "\n",
    "# One probability function per output class of the MLP (output_dim = 5 for Nursery),\n",
    "# so F always matches the size of the model's output layer.\n",
    "F = [make_class_prob(i) for i in range(output_dim)]\n",
    "\n",
    "anova_shap = []  # generalized Shapley values (one matrix per class) based on the functional ANOVA\n",
    "for f_model in F:\n",
    "    # NOTE: `A` from the last iteration is reused by the KernelSHAP cell (A._generate_tuples()).\n",
    "    A = FullSupportAnova(N, P, f_model)\n",
    "    S, Matrix = A.get_anova_full()  # sets and f_A(X_A)\n",
    "    anova_shap.append(batch_shapley_values(d, S, Matrix))  # generalized Shapley values for all observations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a2866888",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==========================================\n",
    "# KernelSHAP\n",
    "# ==========================================\n",
    "\n",
    "n_sample_background = 200\n",
    "# Use a seeded random subsample of rows (without replacement) as the background.\n",
    "# A leading slice X[:200] is not representative: nursery.data lists the category\n",
    "# combinations in a fixed order, so its first rows share the values of the leading\n",
    "# features and would bias the KernelSHAP baseline away from the uniform P used above.\n",
    "background_rng = np.random.default_rng(42)\n",
    "background_rows = background_rng.choice(X.shape[0], size=n_sample_background, replace=False)\n",
    "background = X[background_rows]\n",
    "\n",
    "def kernel_shap(f, X_explain, data=None, **shap_kwargs):\n",
    "    \"\"\"Compute KernelSHAP values of `f` for the rows of `X_explain`.\n",
    "\n",
    "    f             : callable taking a 2-D array of observations and returning one output per row.\n",
    "    X_explain     : observations to explain.\n",
    "    data          : background dataset; defaults to the global `background` defined above.\n",
    "    **shap_kwargs : forwarded to `KernelExplainer.shap_values` (e.g. `nsamples`).\n",
    "    \"\"\"\n",
    "    explainer = shap.KernelExplainer(f, background if data is None else data)\n",
    "    return explainer.shap_values(X_explain, **shap_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6e9e82b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using 200 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f5c1392d89844217b78529807ea6f2be",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/12959 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using 200 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "85f0871bef7041cd84ff4f9780fa6848",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/12959 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using 200 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d51247f6ff3f4ba0bb6cd7cdf440e083",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/12959 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using 200 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2c1caa0cb80e426c87f974bfc0601c63",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/12959 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using 200 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5c69b45279964b1b8da47ef3bce30437",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/12959 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 7min 12s, sys: 3min 8s, total: 10min 21s\n",
      "Wall time: 5min 6s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# ==========================================\n",
    "# KernelSHAP\n",
    "# ==========================================\n",
    "\n",
    "# Points to explain: the first `number` points generated by the ANOVA object, presumably in the\n",
    "# same order as the rows of anova_shap[i] (the MSE cell compares them row by row).\n",
    "# `A` is the last FullSupportAnova built in the ANOVA cell; every one of them used the same N and P.\n",
    "# NOTE(review): `number` leaves out the last point; confirm this is intentional.\n",
    "number = X.shape[0] - 1\n",
    "X_explain = A._generate_tuples()[:number]\n",
    "\n",
    "all_kernel_shap = [kernel_shap(g, X_explain) for g in F]  # one KernelSHAP matrix per class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "dc41cb22",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[3.04847562e-10, 2.40328281e-10, 2.22418924e-10, 7.28466824e-11,\n",
       "        2.75569893e-11, 1.40257661e-11, 4.72785170e-11, 1.11166186e-05],\n",
       "       [1.63788959e-02, 3.69209523e-02, 2.93930668e-03, 2.80880975e-03,\n",
       "        1.71145172e-03, 3.52273307e-04, 3.21433367e-03, 1.78844183e-02],\n",
       "       [2.61574330e-14, 2.95245237e-14, 2.15717993e-14, 3.42115072e-14,\n",
       "        1.18681794e-14, 1.07816684e-14, 1.90154506e-14, 2.53373432e-14],\n",
       "       [1.88924993e-02, 4.42082090e-02, 2.06204039e-03, 4.52357028e-04,\n",
       "        8.21612019e-04, 2.35157069e-04, 7.31613104e-04, 1.67232584e-02],\n",
       "       [2.48433624e-03, 2.84381425e-03, 1.42883943e-03, 1.33636015e-03,\n",
       "        2.56063193e-04, 2.38811581e-05, 9.46781525e-04, 3.20483052e-03]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ==========================================\n",
    "# Table of MSE\n",
    "# ==========================================\n",
    "# Mean squared difference between the ANOVA-based and the KernelSHAP Shapley values,\n",
    "# one row per output class and one column per feature, averaged over the explained points.\n",
    "\n",
    "P_red = np.full(number, 1 / number)  # uniform weights over the `number` explained points\n",
    "\n",
    "mse = np.array([\n",
    "    P_red @ ((anova_shap[i][:number, :] - all_kernel_shap[i]) ** 2)\n",
    "    for i in range(len(F))\n",
    "])\n",
    "pd.DataFrame(mse, index=le.classes_, columns=X_encoded.columns)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hfd_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
