{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "671b19d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==========================================\n",
    "# Imports\n",
    "# ==========================================\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from ucimlrepo import fetch_ucirepo\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OrdinalEncoder, LabelEncoder\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.metrics import classification_report, accuracy_score\n",
    "from anova_module import ModelAnalysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9950d3db",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading the Connect-4 dataset via UCI Repo (ID: 26)...\n",
      "Encoding data into integers...\n",
      "Feature matrix shape (n x d): (67557, 42)\n",
      "Data type: int64\n",
      "Training the Random Forest...\n",
      "\n",
      "Accuracy: 0.8192\n",
      "\n",
      "Classification Report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "        draw       0.54      0.13      0.21      1290\n",
      "        loss       0.81      0.70      0.75      3327\n",
      "         win       0.83      0.96      0.89      8895\n",
      "\n",
      "    accuracy                           0.82     13512\n",
      "   macro avg       0.72      0.60      0.62     13512\n",
      "weighted avg       0.80      0.82      0.79     13512\n",
      "\n",
      "\n",
      "--- Useful Information ---\n",
      "Cell mapping (X): {0: 'b', 1: 'o', 2: 'x'}\n",
      "Target mapping (y): {0: 'draw', 1: 'loss', 2: 'win'}\n"
     ]
    }
   ],
   "source": [
    "# ==========================================\n",
    "# Data Loading & Preprocessing\n",
    "# ==========================================\n",
    "\n",
    "# 1. Download the UCI dataset (ID 26)\n",
    "print(\"Downloading the Connect-4 dataset via UCI Repo (ID: 26)...\")\n",
    "connect_4 = fetch_ucirepo(id=26)\n",
    "\n",
    "# Accessing raw data (Pandas DataFrames)\n",
    "X = connect_4.data.features\n",
    "y = connect_4.data.targets\n",
    "\n",
    "# 2. Integer array encoding (n x d)\n",
    "print(\"Encoding data into integers...\")\n",
    "\n",
    "# OrdinalEncoder for the 42 cells (categories: 'b', 'o', 'x')\n",
    "encoder_features = OrdinalEncoder()\n",
    "X_encoded = encoder_features.fit_transform(X).astype(np.int64)\n",
    "\n",
    "# LabelEncoder for the target (categories: 'win', 'loss', 'draw')\n",
    "encoder_target = LabelEncoder()\n",
    "# Using .ravel() to transform the y DataFrame into a flat vector\n",
    "y_encoded = encoder_target.fit_transform(y.values.ravel()).astype(np.int64)\n",
    "\n",
    "print(f\"Feature matrix shape (n x d): {X_encoded.shape}\")\n",
    "print(f\"Data type: {X_encoded.dtype}\")\n",
    "\n",
    "# 3. Train / Test Split\n",
    "# Using random_state=42 for reproducibility and stratify for class imbalance\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X_encoded, y_encoded, test_size=0.2, random_state=42, stratify=y_encoded\n",
    ")\n",
    "\n",
    "# 4. Random Forest Training\n",
    "print(\"Training the Random Forest...\")\n",
    "# n_estimators=100 is a good speed/accuracy trade-off\n",
    "# n_jobs=-1 uses all available CPU cores\n",
    "clf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)\n",
    "clf.fit(X_train, y_train)\n",
    "\n",
    "# 5. Evaluation\n",
    "y_pred = clf.predict(X_test)\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "\n",
    "print(f\"\\nAccuracy: {accuracy:.4f}\")\n",
    "print(\"\\nClassification Report:\")\n",
    "print(classification_report(y_test, y_pred, target_names=encoder_target.classes_))\n",
    "\n",
    "# ==========================================\n",
    "# Data Inspection & Mapping\n",
    "# ==========================================\n",
    "print(\"\\n--- Useful Information ---\")\n",
    "# To see what the numbers 0, 1, 2 in X_encoded correspond to:\n",
    "for i, col_name in enumerate(X.columns[:1]): # Looking at the first column only as an example\n",
    "    mapping_X = dict(zip(range(len(encoder_features.categories_[i])), encoder_features.categories_[i]))\n",
    "    print(f\"Cell mapping (X): {mapping_X}\")\n",
    "\n",
    "# To see what the numbers in y_encoded correspond to:\n",
    "mapping_y = dict(zip(range(len(encoder_target.classes_)), encoder_target.classes_))\n",
    "print(f\"Target mapping (y): {mapping_y}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd4713e7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Constructing Basis Matrix: 100%|\u001b[32m██████████\u001b[0m| 83/83 [00:01<00:00, 78.34it/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computations complete. Results ready.\n",
      "0.4491710151975612 0.07301981505256158 0.1291350369813769\n",
      "CPU times: user 51.9 s, sys: 8.06 s, total: 60 s\n",
      "Wall time: 9.43 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# ==============================================\n",
    "# Functional ANOVA Decomposition (MAIN EFFECTS)\n",
    "# ==============================================\n",
    "\n",
    "def f(x): # class 2\n",
    "    return(clf.predict_proba(x)[:,2]) # Proba of winning\n",
    "\n",
    "A = ModelAnalysis(X_encoded , f , 0.124 , 1 , 1e-4) #0.124% of total dimension to have all main effects\n",
    "S , Matrix = A.functional_anova() # sets and f_A(X_A)\n",
    "print(A.get_R2() , A.get_L2_Error() , A.get_L2_Error_rel())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2b7f8d75",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Constructing Basis Matrix: 100%|\u001b[32m██████████\u001b[0m| 5066/5066 [34:12<00:00,  2.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computations complete. Results ready.\n",
      "0.7032825124247271 0.03933390700813979 0.06956174200745231\n",
      "CPU times: user 1h 20min 54s, sys: 39min 3s, total: 1h 59min 58s\n",
      "Wall time: 37min 21s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# ==========================================\n",
    "# Functional ANOVA Decomposition\n",
    "# ==========================================\n",
    "\n",
    "A = ModelAnalysis(X_encoded , f , 7.5 , 1 , 1e-4)\n",
    "S , Matrix = A.functional_anova() # sets and f_A(X_A)\n",
    "print(A.get_R2() , A.get_L2_Error() , A.get_L2_Error_rel())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hfd_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
