{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use case: TabEval models for classification tasks\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Prepare environment\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Set up the runtime\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Project directory: /home/xj265/phd/codebase/Euphratica/Euphratica-dev\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Walk upwards from the current working directory until a directory\n",
    "# containing \".git\" is found; that directory is used as the project root.\n",
    "project_dir = os.getcwd()\n",
    "while not os.path.exists(os.path.join(project_dir, \".git\")):\n",
    "    parent_dir = os.path.dirname(project_dir)\n",
    "    if parent_dir == project_dir:\n",
    "        # dirname() of the filesystem root returns the root itself, so\n",
    "        # without this guard the loop would spin forever when the notebook\n",
    "        # is run outside a git checkout.\n",
    "        raise FileNotFoundError(\n",
    "            f\"No .git directory found in {os.getcwd()} or any parent directory\"\n",
    "        )\n",
    "    project_dir = parent_dir\n",
    "print(f\"Project directory: {project_dir}\")\n",
    "# Put the project root first on sys.path so local packages are importable.\n",
    "sys.path.insert(0, project_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Import customised libraries\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.datasets import make_classification\n",
    "from tabeval.plugins import Plugins"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Prepare the data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Generate a synthetic binary classification dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1000, 10), (1000,))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Settings for the synthetic dataset; the fixed random_state keeps the\n",
    "# generated samples identical across runs.\n",
    "dataset_kwargs = dict(\n",
    "    n_samples=1000,\n",
    "    n_features=10,\n",
    "    n_informative=5,\n",
    "    n_classes=2,\n",
    "    random_state=42,\n",
    ")\n",
    "X, y = make_classification(**dataset_kwargs)\n",
    "\n",
    "# Display the shapes of the feature matrix and the label vector.\n",
    "X.shape, y.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Fit the model and generate synthetic data\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Fit the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initial accuracy is 0.8625\n",
      "Iteration number 1 reached accuracy of 0.455.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tabeval.plugins.generic.plugin_arf.ARFPlugin at 0x7738f108d270>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Look up the \"arf\" plugin from the TabEval plugin registry.\n",
    "model = Plugins().get(\"arf\")\n",
    "\n",
    "# Stack the features and the label into a single table, with the label\n",
    "# as the last column, and fit the plugin on it.\n",
    "train_array = np.column_stack([X, y])\n",
    "train_df = pd.DataFrame(train_array)\n",
    "model.fit(train_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.042295</td>\n",
       "      <td>0.160339</td>\n",
       "      <td>-1.657675</td>\n",
       "      <td>-1.575995</td>\n",
       "      <td>0.613059</td>\n",
       "      <td>-0.320766</td>\n",
       "      <td>-2.184236</td>\n",
       "      <td>-1.431800</td>\n",
       "      <td>-0.686708</td>\n",
       "      <td>2.432317</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>-0.953402</td>\n",
       "      <td>0.503905</td>\n",
       "      <td>-0.091632</td>\n",
       "      <td>-0.971582</td>\n",
       "      <td>2.565245</td>\n",
       "      <td>-1.054052</td>\n",
       "      <td>1.025605</td>\n",
       "      <td>0.971286</td>\n",
       "      <td>1.026680</td>\n",
       "      <td>1.663101</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>-1.382051</td>\n",
       "      <td>-0.661474</td>\n",
       "      <td>-0.239147</td>\n",
       "      <td>-1.154431</td>\n",
       "      <td>-0.133984</td>\n",
       "      <td>-0.002358</td>\n",
       "      <td>-3.755475</td>\n",
       "      <td>-1.034527</td>\n",
       "      <td>0.016368</td>\n",
       "      <td>1.542894</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>-0.525810</td>\n",
       "      <td>-0.790756</td>\n",
       "      <td>-1.156567</td>\n",
       "      <td>-1.560428</td>\n",
       "      <td>-0.387313</td>\n",
       "      <td>-1.925261</td>\n",
       "      <td>-4.230447</td>\n",
       "      <td>-0.446434</td>\n",
       "      <td>-1.622859</td>\n",
       "      <td>1.894724</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.465693</td>\n",
       "      <td>0.356729</td>\n",
       "      <td>-0.692498</td>\n",
       "      <td>-0.671326</td>\n",
       "      <td>-2.498939</td>\n",
       "      <td>1.848587</td>\n",
       "      <td>-3.078555</td>\n",
       "      <td>-0.327987</td>\n",
       "      <td>-0.090019</td>\n",
       "      <td>-2.022645</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>0.932068</td>\n",
       "      <td>-0.898824</td>\n",
       "      <td>-1.727597</td>\n",
       "      <td>-0.975531</td>\n",
       "      <td>0.523537</td>\n",
       "      <td>-1.288263</td>\n",
       "      <td>-0.124227</td>\n",
       "      <td>-1.195657</td>\n",
       "      <td>0.493122</td>\n",
       "      <td>0.786079</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>0.931139</td>\n",
       "      <td>0.816138</td>\n",
       "      <td>-1.351555</td>\n",
       "      <td>-0.565883</td>\n",
       "      <td>0.976779</td>\n",
       "      <td>3.117676</td>\n",
       "      <td>0.325670</td>\n",
       "      <td>-2.065272</td>\n",
       "      <td>-0.508554</td>\n",
       "      <td>-0.455072</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>97</th>\n",
       "      <td>2.416385</td>\n",
       "      <td>0.496423</td>\n",
       "      <td>-0.425257</td>\n",
       "      <td>-1.958529</td>\n",
       "      <td>1.974365</td>\n",
       "      <td>-0.028789</td>\n",
       "      <td>2.771621</td>\n",
       "      <td>-0.048143</td>\n",
       "      <td>-0.384051</td>\n",
       "      <td>3.080874</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>-1.128855</td>\n",
       "      <td>-1.153887</td>\n",
       "      <td>-0.913614</td>\n",
       "      <td>-1.718022</td>\n",
       "      <td>0.725244</td>\n",
       "      <td>0.448233</td>\n",
       "      <td>-2.244349</td>\n",
       "      <td>0.064845</td>\n",
       "      <td>-0.053124</td>\n",
       "      <td>2.077524</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>99</th>\n",
       "      <td>1.497729</td>\n",
       "      <td>-0.669920</td>\n",
       "      <td>-1.594899</td>\n",
       "      <td>-0.456832</td>\n",
       "      <td>2.245004</td>\n",
       "      <td>0.410677</td>\n",
       "      <td>3.961171</td>\n",
       "      <td>-1.595605</td>\n",
       "      <td>-1.194712</td>\n",
       "      <td>0.340254</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>100 rows × 11 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "           0         1         2         3         4         5         6  \\\n",
       "0   1.042295  0.160339 -1.657675 -1.575995  0.613059 -0.320766 -2.184236   \n",
       "1  -0.953402  0.503905 -0.091632 -0.971582  2.565245 -1.054052  1.025605   \n",
       "2  -1.382051 -0.661474 -0.239147 -1.154431 -0.133984 -0.002358 -3.755475   \n",
       "3  -0.525810 -0.790756 -1.156567 -1.560428 -0.387313 -1.925261 -4.230447   \n",
       "4   0.465693  0.356729 -0.692498 -0.671326 -2.498939  1.848587 -3.078555   \n",
       "..       ...       ...       ...       ...       ...       ...       ...   \n",
       "95  0.932068 -0.898824 -1.727597 -0.975531  0.523537 -1.288263 -0.124227   \n",
       "96  0.931139  0.816138 -1.351555 -0.565883  0.976779  3.117676  0.325670   \n",
       "97  2.416385  0.496423 -0.425257 -1.958529  1.974365 -0.028789  2.771621   \n",
       "98 -1.128855 -1.153887 -0.913614 -1.718022  0.725244  0.448233 -2.244349   \n",
       "99  1.497729 -0.669920 -1.594899 -0.456832  2.245004  0.410677  3.961171   \n",
       "\n",
       "           7         8         9   10  \n",
       "0  -1.431800 -0.686708  2.432317  1.0  \n",
       "1   0.971286  1.026680  1.663101  0.0  \n",
       "2  -1.034527  0.016368  1.542894  0.0  \n",
       "3  -0.446434 -1.622859  1.894724  0.0  \n",
       "4  -0.327987 -0.090019 -2.022645  1.0  \n",
       "..       ...       ...       ...  ...  \n",
       "95 -1.195657  0.493122  0.786079  0.0  \n",
       "96 -2.065272 -0.508554 -0.455072  0.0  \n",
       "97 -0.048143 -0.384051  3.080874  1.0  \n",
       "98  0.064845 -0.053124  2.077524  0.0  \n",
       "99 -1.595605 -1.194712  0.340254  0.0  \n",
       "\n",
       "[100 rows x 11 columns]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.generate(100)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
