{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cfba4cba-031f-4e17-a759-e842c1692c26",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from scipy.spatial.distance import cdist\n",
    "import random\n",
    "from sklearn.cluster import KMeans\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "from icfesl import *\n",
    "from utility_functions import *\n",
    "from sklearn.feature_selection import VarianceThreshold\n",
    "from sklearn.metrics import roc_auc_score, accuracy_score\n",
    "from xgboost import XGBClassifier\n",
    "from pytorch_tabnet.tab_model import TabNetClassifier\n",
    "import time    \n",
    "from sklearn.model_selection import train_test_split\n",
    "from catboost import CatBoostClassifier, Pool\n",
    "from sklearn.preprocessing import LabelEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "506ea2ca-8d50-4f17-9eee-5dfcb4c4044f",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv(\"../../../writing/UCI datasets/mushroom/agaricus-lepiota.data\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e23be722-daf7-4b34-bee3-024fa990fe71",
   "metadata": {},
   "source": [
    "## Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ac7bac62-f13f-4ff2-afd1-f3ec2fbca50d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#convert target to binary\n",
    "y = data['p'].apply(lambda x: 1 if x == 'p' else 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b6c44cf9-ae25-4a92-bc93-dc127b354046",
   "metadata": {},
   "outputs": [],
   "source": [
    "#drop this because it causes complete separation\n",
    "data = data.drop('p.1', axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8399b1fe-2ff3-4712-8ecb-8eb047ae6fe4",
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_vars = data.columns.tolist()\n",
    "cat_vars.remove('p')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7b58fba0-c9e9-4359-a9ee-6b4c413c84a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder = LabelEncoder()\n",
    "for var in cat_vars:\n",
    "    data[var] = encoder.fit_transform(data[var])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2acf1e79-bedf-48de-bcd5-743ac9dc8202",
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_vars = data.columns.tolist()\n",
    "cat_vars.remove('p')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "291672fa-dabd-4273-af1f-fbe5951c8778",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = data[cat_vars]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7d98142d-b77c-44f7-9bc5-437444581271",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ba11f53f-8a1a-43d6-bf7e-2e98161d789b",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = X_train.reset_index(drop=True)\n",
    "X_test = X_test.reset_index(drop=True)\n",
    "y_train = y_train.reset_index(drop=True)\n",
    "y_test = y_test.reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9cd1b154-2c5c-436e-b9ac-270fa08230aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "X2, encoder = icfesl.f_get_dummies(X_train, cat_vars)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "49c5a38e-ebc1-417e-a46e-0bad6c004271",
   "metadata": {},
   "outputs": [],
   "source": [
    "X2_test = icfesl.f_get_dummies(X_test, cat_vars, encoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "ddb7ecc3-4148-4cd4-a13a-4523c8411be1",
   "metadata": {},
   "outputs": [],
   "source": [
    "selector = VarianceThreshold(threshold=np.mean(y_train)*0.1)\n",
    "\n",
    "selector.fit(X2)\n",
    "\n",
    "selected_features_mask = selector.get_support()\n",
    "\n",
    "selected_column_names = X2.columns[selected_features_mask]\n",
    "\n",
    "X2 = X2[selected_column_names]\n",
    "\n",
    "X2_test = X2_test[selected_column_names]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b581d02a-5aa5-4d4a-8dec-33523662accb",
   "metadata": {},
   "outputs": [],
   "source": [
    "for c in X2.columns.tolist():\n",
    "    X2[c] = X2[c].astype('int')\n",
    "    X2_test[c] = X2_test[c].astype('int')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95e9d69d-e4ae-4bbb-9345-343a6cd1855c",
   "metadata": {},
   "source": [
    "## ICFESL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "53e1c2a9-2a7f-47e3-86cd-5168ead7edc2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m2025-12-02 06:16:56.848\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36micfesl\u001b[0m:\u001b[36mregularized_search_algorun\u001b[0m:\u001b[36m397\u001b[0m - \u001b[1mrunning algorithm with L2 regularization factor = 1 ------>\u001b[0m\n",
      "\u001b[32m2025-12-02 06:16:57.994\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36micfesl\u001b[0m:\u001b[36mregularized_search_algorun\u001b[0m:\u001b[36m420\u001b[0m - \u001b[1mRunning logit with ICFESL encoding\u001b[0m\n",
      "\u001b[32m2025-12-02 06:16:58.135\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36micfesl\u001b[0m:\u001b[36mregularized_search_algorun\u001b[0m:\u001b[36m438\u001b[0m - \u001b[1mRunning xgbClassifier with ICFESL encoding\u001b[0m\n",
      "\u001b[32m2025-12-02 06:16:58.408\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36micfesl\u001b[0m:\u001b[36mregularized_search_algorun\u001b[0m:\u001b[36m485\u001b[0m - \u001b[1mCompleted: running algorithm with L2 regularization factor = 1 ------>\u001b[0m\n",
      "\u001b[32m2025-12-02 06:16:58.408\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36micfesl\u001b[0m:\u001b[36mregularized_search_algorun\u001b[0m:\u001b[36m397\u001b[0m - \u001b[1mrunning algorithm with L2 regularization factor = 5 ------>\u001b[0m\n",
      "\u001b[32m2025-12-02 06:16:59.272\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36micfesl\u001b[0m:\u001b[36mregularized_search_algorun\u001b[0m:\u001b[36m420\u001b[0m - \u001b[1mRunning logit with ICFESL encoding\u001b[0m\n",
      "\u001b[32m2025-12-02 06:16:59.397\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36micfesl\u001b[0m:\u001b[36mregularized_search_algorun\u001b[0m:\u001b[36m438\u001b[0m - \u001b[1mRunning xgbClassifier with ICFESL encoding\u001b[0m\n",
      "\u001b[32m2025-12-02 06:16:59.568\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36micfesl\u001b[0m:\u001b[36mregularized_search_algorun\u001b[0m:\u001b[36m485\u001b[0m - \u001b[1mCompleted: running algorithm with L2 regularization factor = 5 ------>\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "fit_info_panel, best_index, fit_figs, cluster_groups, criterions, inertias, gap_statss = icfesl.regularized_search_algorun(\n",
    "    X2, pd.Series(y_train), X2_test, pd.Series(y_test), cat_vars, 'classification', alphas = [1, 5], cbine_column=False,\n",
    "    distance_threshold=0.002, figure=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2f27aeac-8890-4402-aee9-ee56b7cfed90",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Experiment</th>\n",
       "      <th>dof</th>\n",
       "      <th>reg_fit_time</th>\n",
       "      <th>reg_training_auroc</th>\n",
       "      <th>reg_testing_auroc</th>\n",
       "      <th>xgb_fit_time</th>\n",
       "      <th>xgb_training_auroc</th>\n",
       "      <th>xgb_testing_auroc</th>\n",
       "      <th>var_inf</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>34</td>\n",
       "      <td>0.1344</td>\n",
       "      <td>0.999660</td>\n",
       "      <td>0.999229</td>\n",
       "      <td>0.2636</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.000054</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>34</td>\n",
       "      <td>0.1153</td>\n",
       "      <td>0.999366</td>\n",
       "      <td>0.998926</td>\n",
       "      <td>0.1499</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.000046</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Experiment  dof  reg_fit_time  reg_training_auroc  reg_testing_auroc  \\\n",
       "0           0   34        0.1344            0.999660           0.999229   \n",
       "1           1   34        0.1153            0.999366           0.998926   \n",
       "\n",
       "   xgb_fit_time  xgb_training_auroc  xgb_testing_auroc   var_inf  \n",
       "0        0.2636                 1.0                1.0  0.000054  \n",
       "1        0.1499                 1.0                1.0  0.000046  "
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit_info_panel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "09c5ad8a-555c-418f-9594-dfa75504ccf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "X3 = icfesl.combine_features(X2, cluster_groups[best_index])\n",
    "X3_test = icfesl.combine_features(X2_test, cluster_groups[best_index])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "ca8394c1-e89e-4aaf-9c3c-5b58fdda63dd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "best_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "7ad23cdd-c393-4084-960a-b2997fc5a023",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'x+0': ['x::3', 'x::5'],\n",
       "  'x+1': ['x::2'],\n",
       "  's+0': ['s::2', 's::3'],\n",
       "  'n+0': ['n::3', 'n::4', 'n::9'],\n",
       "  'n+1': ['n::8'],\n",
       "  'n+2': ['n::2'],\n",
       "  't+0': ['t::1'],\n",
       "  'c+0': ['c::1'],\n",
       "  'n.1+0': ['n.1::1'],\n",
       "  'k+0': ['k::2', 'k::3', 'k::5', 'k::7', 'k::9', 'k::10'],\n",
       "  'e+0': ['e::1'],\n",
       "  'e.1+0': ['e.1::2', 'e.1::3'],\n",
       "  'e.1+1': ['e.1::1'],\n",
       "  's.1+0': ['s.1::2'],\n",
       "  's.1+1': ['s.1::1'],\n",
       "  's.2+0': ['s.2::2'],\n",
       "  's.2+1': ['s.2::1'],\n",
       "  'w+0': ['w::6', 'w::7'],\n",
       "  'w+1': ['w::3'],\n",
       "  'w+2': ['w::4'],\n",
       "  'w.1+0': ['w.1::6', 'w.1::7'],\n",
       "  'w.1+1': ['w.1::4'],\n",
       "  'w.1+2': ['w.1::3'],\n",
       "  'o+0': ['o::2'],\n",
       "  'o+1': ['o::1'],\n",
       "  'p.3+0': ['p.3::4'],\n",
       "  'p.3+1': ['p.3::2'],\n",
       "  'k.1+0': ['k.1::2', 'k.1::3'],\n",
       "  'k.1+1': ['k.1::1', 'k.1::7'],\n",
       "  's.3+0': ['s.3::4'],\n",
       "  's.3+1': ['s.3::5'],\n",
       "  's.3+2': ['s.3::3'],\n",
       "  'u+0': ['u::2', 'u::4'],\n",
       "  'u+1': ['u::1']},\n",
       " {'x+0': ['x::2', 'x::3', 'x::5'],\n",
       "  's+0': ['s::3'],\n",
       "  's+1': ['s::2'],\n",
       "  'n+0': ['n::3', 'n::4', 'n::9'],\n",
       "  'n+1': ['n::8'],\n",
       "  'n+2': ['n::2'],\n",
       "  't+0': ['t::1'],\n",
       "  'c+0': ['c::1'],\n",
       "  'n.1+0': ['n.1::1'],\n",
       "  'k+0': ['k::5', 'k::7', 'k::9', 'k::10'],\n",
       "  'k+1': ['k::2', 'k::3'],\n",
       "  'e+0': ['e::1'],\n",
       "  'e.1+0': ['e.1::2', 'e.1::3'],\n",
       "  'e.1+1': ['e.1::1'],\n",
       "  's.1+0': ['s.1::2'],\n",
       "  's.1+1': ['s.1::1'],\n",
       "  's.2+0': ['s.2::2'],\n",
       "  's.2+1': ['s.2::1'],\n",
       "  'w+0': ['w::4', 'w::6'],\n",
       "  'w+1': ['w::3'],\n",
       "  'w+2': ['w::7'],\n",
       "  'w.1+0': ['w.1::6', 'w.1::7'],\n",
       "  'w.1+1': ['w.1::4'],\n",
       "  'w.1+2': ['w.1::3'],\n",
       "  'o+0': ['o::2'],\n",
       "  'o+1': ['o::1'],\n",
       "  'p.3+0': ['p.3::4'],\n",
       "  'p.3+1': ['p.3::2'],\n",
       "  'k.1+0': ['k.1::2', 'k.1::3'],\n",
       "  'k.1+1': ['k.1::1', 'k.1::7'],\n",
       "  's.3+0': ['s.3::3', 's.3::4'],\n",
       "  's.3+1': ['s.3::5'],\n",
       "  'u+0': ['u::2'],\n",
       "  'u+1': ['u::1', 'u::4']}]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cluster_groups"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd487b13-b989-442b-9d57-4d2dde847b4a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d686a721-91c0-433b-8834-b575613019ec",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f597b250-7f14-4fa2-889b-014c17def955",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52622bc8-e6c6-4d92-93b0-1ad0326afb35",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
