{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cac3f223",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==========================================\n",
    "# Imports\n",
    "# ==========================================\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from sklearn.metrics import accuracy_score, classification_report\n",
    "from anova_module import ModelAnalysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "774f8298",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading and reading data into memory...\n",
      "Starting training on cpu...\n",
      "Epoch 1: Average Loss = 1.0617\n",
      "Epoch 2: Average Loss = 0.9641\n",
      "Epoch 3: Average Loss = 0.6143\n",
      "Epoch 4: Average Loss = 0.3050\n",
      "Epoch 5: Average Loss = 0.1807\n",
      "Epoch 6: Average Loss = 0.1381\n",
      "Epoch 7: Average Loss = 0.1227\n",
      "Epoch 8: Average Loss = 0.1086\n",
      "Epoch 9: Average Loss = 0.0990\n",
      "Epoch 10: Average Loss = 0.0893\n",
      "\n",
      "--- Test Set Results ---\n",
      "Accuracy: 0.9843\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       0.99      1.00      0.99    501209\n",
      "           1       1.00      1.00      1.00    422498\n",
      "           2       0.86      0.97      0.91     47622\n",
      "           3       0.88      0.64      0.74     21121\n",
      "           4       0.00      0.00      0.00      3885\n",
      "           5       0.00      0.00      0.00      1996\n",
      "           6       0.83      0.75      0.79      1424\n",
      "           7       0.00      0.00      0.00       230\n",
      "           8       0.00      0.00      0.00        12\n",
      "           9       0.00      0.00      0.00         3\n",
      "\n",
      "    accuracy                           0.98   1000000\n",
      "   macro avg       0.46      0.44      0.44   1000000\n",
      "weighted avg       0.98      0.98      0.98   1000000\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# ==========================================\n",
    "# 1. LOADING AND PREPARATION (Direct URL)\n",
    "# ==========================================\n",
    "# UCI Poker Hand data: each row is S1, C1, ..., S5, C5 (suit 1..4, rank 1..13) followed by the hand class.\n",
    "UCI_POKER_URL = \"https://archive.ics.uci.edu/ml/machine-learning-databases/poker\"\n",
    "SEED = 0\n",
    "BATCH_SIZE = 256\n",
    "N_SUITS, N_RANKS, N_CLASSES = 4, 13, 10\n",
    "\n",
    "# Seed torch's global RNG so weight initialisation and DataLoader shuffling are reproducible.\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "print(\"Downloading and reading data into memory...\")\n",
    "# Reading directly from UCI URLs\n",
    "train_df = pd.read_csv(f\"{UCI_POKER_URL}/poker-hand-training-true.data\", header=None)\n",
    "test_df = pd.read_csv(f\"{UCI_POKER_URL}/poker-hand-testing.data\", header=None)\n",
    "\n",
    "# Splitting X/y and shifting every feature code to start at 0 (required for PyTorch Embedding):\n",
    "# suits (even columns) 1..4 -> 0..3, ranks (odd columns) 1..13 -> 0..12.\n",
    "# `- 1` builds a new array; an in-place `-=` on `.values` can mutate the DataFrame it views,\n",
    "# or fail outright when pandas Copy-on-Write hands back a read-only view.\n",
    "X_train = train_df.iloc[:, :-1].to_numpy() - 1\n",
    "y_train = train_df.iloc[:, -1].to_numpy()\n",
    "X_test = test_df.iloc[:, :-1].to_numpy() - 1\n",
    "y_test = test_df.iloc[:, -1].to_numpy()\n",
    "\n",
    "# Sanity-check the encoding the embeddings and the classifier rely on.\n",
    "for split_name, feats, targets in ((\"train\", X_train, y_train), (\"test\", X_test, y_test)):\n",
    "    assert feats.shape[1] == 10, f\"{split_name}: expected 10 feature columns\"\n",
    "    assert 0 <= feats[:, 0::2].min() and feats[:, 0::2].max() < N_SUITS, f\"{split_name}: suit code out of range\"\n",
    "    assert 0 <= feats[:, 1::2].min() and feats[:, 1::2].max() < N_RANKS, f\"{split_name}: rank code out of range\"\n",
    "    assert 0 <= targets.min() and targets.max() < N_CLASSES, f\"{split_name}: class label out of range\"\n",
    "\n",
    "# Conversion to Tensors and DataLoaders\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "train_loader = DataLoader(TensorDataset(torch.LongTensor(X_train), torch.LongTensor(y_train)), batch_size=BATCH_SIZE, shuffle=True)\n",
    "test_loader = DataLoader(TensorDataset(torch.LongTensor(X_test), torch.LongTensor(y_test)), batch_size=BATCH_SIZE)\n",
    "\n",
    "# ==========================================\n",
    "# 2. TRANSFORMER MODEL\n",
    "# ==========================================\n",
    "class TabularTransformer(nn.Module):\n",
    "    \"\"\"Transformer classifier over the 10 categorical features of a poker hand.\n",
    "\n",
    "    Input rows are [S1, C1, ..., S5, C5] as 0-indexed codes: suits (even positions)\n",
    "    use a 4-entry embedding table and ranks (odd positions) a 13-entry table. The\n",
    "    embedded tokens pass through `num_layers` encoder layers, are flattened, and an\n",
    "    MLP maps them to `num_classes` logits.\n",
    "\n",
    "    No positional encoding is added: inside attention a token is described only by\n",
    "    its suit or rank embedding, not by which card it belongs to. Absolute positions\n",
    "    only matter at the flatten before the classifier.\n",
    "\n",
    "    Args:\n",
    "        embed_dim: Token embedding size (must be divisible by num_heads).\n",
    "        num_heads: Attention heads per encoder layer.\n",
    "        ff_dim: Hidden width of the encoder feed-forward block and of the classifier MLP.\n",
    "        num_layers: Number of stacked encoder layers.\n",
    "        num_classes: Number of output logits (10 poker hand classes by default).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, embed_dim=32, num_heads=4, ff_dim=64, num_layers=3, num_classes=10):\n",
    "        super().__init__()\n",
    "        self.suit_embed = nn.Embedding(4, embed_dim)   # 4 suits\n",
    "        self.rank_embed = nn.Embedding(13, embed_dim)  # 13 ranks\n",
    "\n",
    "        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=ff_dim, batch_first=True)\n",
    "        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)\n",
    "\n",
    "        # 10 tokens * embed_dim input for the classifier\n",
    "        self.classifier = nn.Sequential(\n",
    "            nn.Linear(10 * embed_dim, ff_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(ff_dim, num_classes),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map int64 codes of shape (batch, 10) to logits of shape (batch, num_classes).\"\"\"\n",
    "        s_emb = self.suit_embed(x[:, 0::2])  # (batch, 5, dim)\n",
    "        r_emb = self.rank_embed(x[:, 1::2])  # (batch, 5, dim)\n",
    "\n",
    "        # Interleave back into [S1, C1, S2, C2, ...]: stacking on a new axis gives\n",
    "        # (batch, 5, 2, dim) and the reshape lays each (suit, rank) pair out consecutively.\n",
    "        # Unlike filling a torch.zeros buffer, this keeps the embeddings' dtype\n",
    "        # (e.g. after model.half() or model.double()).\n",
    "        batch, n_cards, dim = s_emb.shape\n",
    "        x_emb = torch.stack((s_emb, r_emb), dim=2).reshape(batch, 2 * n_cards, dim)\n",
    "\n",
    "        # Transformer -> Flatten -> MLP\n",
    "        out = self.transformer(x_emb)\n",
    "        return self.classifier(out.flatten(start_dim=1))\n",
    "\n",
    "NUM_EPOCHS = 10  # 10 epochs are sufficient for testing\n",
    "LEARNING_RATE = 1e-3\n",
    "\n",
    "model = TabularTransformer().to(device)\n",
    "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
    "criterion = nn.CrossEntropyLoss()  # default reduction='mean': loss.item() is a per-batch mean\n",
    "\n",
    "# ==========================================\n",
    "# 3. TRAINING\n",
    "# ==========================================\n",
    "print(f\"Starting training on {device}...\")\n",
    "model.train()\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    # Weight each batch mean by its size so the printed value is a true per-sample\n",
    "    # average; the final batch can be smaller than the others.\n",
    "    total_loss, n_samples = 0.0, 0\n",
    "    for inputs, labels in train_loader:\n",
    "        inputs, labels = inputs.to(device), labels.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        loss = criterion(model(inputs), labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item() * labels.size(0)\n",
    "        n_samples += labels.size(0)\n",
    "    print(f\"Epoch {epoch+1}: Average Loss = {total_loss / max(n_samples, 1):.4f}\")\n",
    "\n",
    "# ==========================================\n",
    "# 4. FINAL EVALUATION\n",
    "# ==========================================\n",
    "print(\"\\n--- Test Set Results ---\")\n",
    "model.eval()\n",
    "all_preds, all_targets = [], []\n",
    "\n",
    "# Predicted class = index of the largest logit; labels stay on the CPU throughout.\n",
    "with torch.no_grad():\n",
    "    for inputs, labels in test_loader:\n",
    "        batch_preds = model(inputs.to(device)).argmax(dim=1)\n",
    "        all_preds += batch_preds.cpu().tolist()\n",
    "        all_targets += labels.tolist()\n",
    "\n",
    "test_accuracy = accuracy_score(all_targets, all_preds)\n",
    "print(f\"Accuracy: {test_accuracy:.4f}\")\n",
    "print(classification_report(all_targets, all_preds, zero_division=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b9c8314",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==========================================\n",
    "# Function of interest for the functional ANOVA\n",
    "# ==========================================\n",
    "\n",
    "X_numpy = X_train  # sample on which the decomposition is computed (0-indexed codes)\n",
    "r, d = X_numpy.shape  # number of samples, number of features\n",
    "\n",
    "def f_model(X_numpy, class_idx=0, batch_size=50_000):\n",
    "    \"\"\"Return the model's predicted probability of class `class_idx` for each row.\n",
    "\n",
    "    Args:\n",
    "        X_numpy: Integer array of shape (n, 10) in the same 0-indexed suit/rank\n",
    "            encoding as X_train.\n",
    "        class_idx: Class whose softmax probability is returned (default 0).\n",
    "        batch_size: Maximum rows per forward pass, bounding peak memory on large inputs.\n",
    "\n",
    "    Returns:\n",
    "        1-D float numpy array of length n.\n",
    "    \"\"\"\n",
    "    X_tensor = torch.LongTensor(X_numpy)\n",
    "    if X_tensor.shape[0] == 0:\n",
    "        # Short-circuit so no zero-row batch ever reaches the model.\n",
    "        return torch.empty(0).numpy()\n",
    "\n",
    "    model.eval()\n",
    "    class_probs = []\n",
    "    with torch.no_grad():\n",
    "        for chunk in X_tensor.split(batch_size):\n",
    "            probs = torch.softmax(model(chunk.to(device)), dim=1)\n",
    "            class_probs.append(probs[:, class_idx].cpu())\n",
    "\n",
    "    return torch.cat(class_probs).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "473f98d0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Constructing Basis Matrix: 100%|\u001b[32m██████████\u001b[0m| 76/76 [00:00<00:00, 363.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computations complete. Results ready.\n",
      "0.0029647047048213526 0.24470471391357562 0.4925174685588371\n",
      "CPU times: user 3.86 s, sys: 723 ms, total: 4.58 s\n",
      "Wall time: 1.77 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# =============================================\n",
    "# Functional ANOVA Decomposition (MAIN EFFECTS)\n",
    "# =============================================\n",
    "\n",
    "percentage = 0.305  # chosen to have exactly all main effects\n",
    "A = ModelAnalysis(X_numpy, f_model, percentage, 1, 1e-4)\n",
    "S, Matrix = A.functional_anova()  # sets and f_A(X_A)\n",
    "r2, l2_error, l2_error_rel = A.get_R2(), A.get_L2_Error(), A.get_L2_Error_rel()\n",
    "print(r2, l2_error, l2_error_rel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "92caef24",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Constructing Basis Matrix: 100%|\u001b[32m██████████\u001b[0m| 5001/5001 [11:07<00:00,  7.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computations complete. Results ready.\n",
      "0.7870539727935567 0.052263843529388815 0.10519149999431907\n",
      "CPU times: user 11min 41s, sys: 7min 52s, total: 19min 34s\n",
      "Wall time: 11min 45s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# =============================================\n",
    "# Functional ANOVA Decomposition\n",
    "# =============================================\n",
    "\n",
    "percentage = 20\n",
    "A = ModelAnalysis(X_numpy, f_model, percentage, 1, 1e-4)\n",
    "S, Matrix = A.functional_anova()  # sets and f_A(X_A)\n",
    "r2, l2_error, l2_error_rel = A.get_R2(), A.get_L2_Error(), A.get_L2_Error_rel()\n",
    "print(r2, l2_error, l2_error_rel)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hfd_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
