{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use case: TabEval models for classification tasks\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Prepare environment\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Set up the runtime\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Project directory: /home/xj265/phd/codebase/Euphratica/Euphratica-dev\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Walk upwards from the working directory until a directory containing\n",
    "# a \".git\" entry is found; that directory is used as the project root.\n",
    "project_dir = os.getcwd()\n",
    "while not os.path.exists(os.path.join(project_dir, \".git\")):\n",
    "    parent_dir = os.path.dirname(project_dir)\n",
    "    if parent_dir == project_dir:\n",
    "        # dirname() of the filesystem root is the root itself, so without this\n",
    "        # check the loop never terminates when no \".git\" exists above the cwd.\n",
    "        raise FileNotFoundError(\n",
    "            f\"No .git entry found in {os.getcwd()} or any parent directory\"\n",
    "        )\n",
    "    project_dir = parent_dir\n",
    "print(f\"Project directory: {project_dir}\")\n",
    "# Prepend the project root so it is searched first when importing.\n",
    "sys.path.insert(0, project_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Import customised libraries\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tabeval.plugins import Plugins\n",
    "from tabcamel.data.dataset import TabularDataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Specify the results of interest"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Generate the data with known causal structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1000, 20), (1000,))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the \"statlog_german_credit_data\" table as a classification task.\n",
    "dataset = TabularDataset(dataset_name=\"statlog_german_credit_data\", task_type=\"classification\")\n",
    "\n",
    "# Keep handles on the dataset's X_df (features) and y_s (target) attributes.\n",
    "X, y = dataset.X_df, dataset.y_s\n",
    "\n",
    "# Show both shapes as the cell result.\n",
    "(X.shape, y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Fit the generator and sample synthetic data\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Fit the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initial accuracy is 0.7215\n",
      "Iteration number 1 reached accuracy of 0.469.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tabeval.plugins.generic.plugin_arf.ARFPlugin at 0x71c17d6c20b0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Look up the \"arf\" plugin in the TabEval plugin registry.\n",
    "plugin_registry = Plugins()\n",
    "model = plugin_registry.get(\"arf\")\n",
    "\n",
    "# Train on dataset.data_df; the return value of fit() is shown as the cell result.\n",
    "model.fit(dataset.data_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Attribute1</th>\n",
       "      <th>Attribute2</th>\n",
       "      <th>Attribute3</th>\n",
       "      <th>Attribute4</th>\n",
       "      <th>Attribute5</th>\n",
       "      <th>Attribute6</th>\n",
       "      <th>Attribute7</th>\n",
       "      <th>Attribute8</th>\n",
       "      <th>Attribute9</th>\n",
       "      <th>Attribute10</th>\n",
       "      <th>...</th>\n",
       "      <th>Attribute12</th>\n",
       "      <th>Attribute13</th>\n",
       "      <th>Attribute14</th>\n",
       "      <th>Attribute15</th>\n",
       "      <th>Attribute16</th>\n",
       "      <th>Attribute17</th>\n",
       "      <th>Attribute18</th>\n",
       "      <th>Attribute19</th>\n",
       "      <th>Attribute20</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>A12</td>\n",
       "      <td>23</td>\n",
       "      <td>A31</td>\n",
       "      <td>A46</td>\n",
       "      <td>1154</td>\n",
       "      <td>A61</td>\n",
       "      <td>A71</td>\n",
       "      <td>4</td>\n",
       "      <td>A93</td>\n",
       "      <td>A101</td>\n",
       "      <td>...</td>\n",
       "      <td>A122</td>\n",
       "      <td>48</td>\n",
       "      <td>A143</td>\n",
       "      <td>A153</td>\n",
       "      <td>2</td>\n",
       "      <td>A172</td>\n",
       "      <td>1</td>\n",
       "      <td>A191</td>\n",
       "      <td>A201</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>A11</td>\n",
       "      <td>18</td>\n",
       "      <td>A34</td>\n",
       "      <td>A41</td>\n",
       "      <td>6121</td>\n",
       "      <td>A64</td>\n",
       "      <td>A71</td>\n",
       "      <td>2</td>\n",
       "      <td>A91</td>\n",
       "      <td>A101</td>\n",
       "      <td>...</td>\n",
       "      <td>A122</td>\n",
       "      <td>42</td>\n",
       "      <td>A143</td>\n",
       "      <td>A152</td>\n",
       "      <td>1</td>\n",
       "      <td>A173</td>\n",
       "      <td>1</td>\n",
       "      <td>A192</td>\n",
       "      <td>A201</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>A12</td>\n",
       "      <td>48</td>\n",
       "      <td>A33</td>\n",
       "      <td>A49</td>\n",
       "      <td>15270</td>\n",
       "      <td>A61</td>\n",
       "      <td>A71</td>\n",
       "      <td>2</td>\n",
       "      <td>A93</td>\n",
       "      <td>A101</td>\n",
       "      <td>...</td>\n",
       "      <td>A123</td>\n",
       "      <td>50</td>\n",
       "      <td>A143</td>\n",
       "      <td>A152</td>\n",
       "      <td>1</td>\n",
       "      <td>A171</td>\n",
       "      <td>1</td>\n",
       "      <td>A192</td>\n",
       "      <td>A201</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>A12</td>\n",
       "      <td>5</td>\n",
       "      <td>A32</td>\n",
       "      <td>A43</td>\n",
       "      <td>638</td>\n",
       "      <td>A61</td>\n",
       "      <td>A73</td>\n",
       "      <td>3</td>\n",
       "      <td>A93</td>\n",
       "      <td>A103</td>\n",
       "      <td>...</td>\n",
       "      <td>A121</td>\n",
       "      <td>48</td>\n",
       "      <td>A143</td>\n",
       "      <td>A152</td>\n",
       "      <td>1</td>\n",
       "      <td>A173</td>\n",
       "      <td>1</td>\n",
       "      <td>A191</td>\n",
       "      <td>A201</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>A11</td>\n",
       "      <td>7</td>\n",
       "      <td>A32</td>\n",
       "      <td>A49</td>\n",
       "      <td>889</td>\n",
       "      <td>A61</td>\n",
       "      <td>A73</td>\n",
       "      <td>4</td>\n",
       "      <td>A91</td>\n",
       "      <td>A101</td>\n",
       "      <td>...</td>\n",
       "      <td>A122</td>\n",
       "      <td>23</td>\n",
       "      <td>A143</td>\n",
       "      <td>A152</td>\n",
       "      <td>1</td>\n",
       "      <td>A173</td>\n",
       "      <td>1</td>\n",
       "      <td>A191</td>\n",
       "      <td>A201</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>A11</td>\n",
       "      <td>31</td>\n",
       "      <td>A32</td>\n",
       "      <td>A40</td>\n",
       "      <td>1786</td>\n",
       "      <td>A61</td>\n",
       "      <td>A72</td>\n",
       "      <td>2</td>\n",
       "      <td>A92</td>\n",
       "      <td>A102</td>\n",
       "      <td>...</td>\n",
       "      <td>A122</td>\n",
       "      <td>23</td>\n",
       "      <td>A143</td>\n",
       "      <td>A152</td>\n",
       "      <td>1</td>\n",
       "      <td>A173</td>\n",
       "      <td>1</td>\n",
       "      <td>A191</td>\n",
       "      <td>A201</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>A14</td>\n",
       "      <td>28</td>\n",
       "      <td>A34</td>\n",
       "      <td>A40</td>\n",
       "      <td>3504</td>\n",
       "      <td>A65</td>\n",
       "      <td>A75</td>\n",
       "      <td>4</td>\n",
       "      <td>A92</td>\n",
       "      <td>A101</td>\n",
       "      <td>...</td>\n",
       "      <td>A123</td>\n",
       "      <td>55</td>\n",
       "      <td>A143</td>\n",
       "      <td>A152</td>\n",
       "      <td>1</td>\n",
       "      <td>A172</td>\n",
       "      <td>1</td>\n",
       "      <td>A191</td>\n",
       "      <td>A201</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>97</th>\n",
       "      <td>A14</td>\n",
       "      <td>13</td>\n",
       "      <td>A34</td>\n",
       "      <td>A42</td>\n",
       "      <td>2053</td>\n",
       "      <td>A61</td>\n",
       "      <td>A73</td>\n",
       "      <td>2</td>\n",
       "      <td>A92</td>\n",
       "      <td>A101</td>\n",
       "      <td>...</td>\n",
       "      <td>A123</td>\n",
       "      <td>32</td>\n",
       "      <td>A143</td>\n",
       "      <td>A152</td>\n",
       "      <td>1</td>\n",
       "      <td>A173</td>\n",
       "      <td>1</td>\n",
       "      <td>A191</td>\n",
       "      <td>A201</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>A12</td>\n",
       "      <td>19</td>\n",
       "      <td>A32</td>\n",
       "      <td>A41</td>\n",
       "      <td>10204</td>\n",
       "      <td>A63</td>\n",
       "      <td>A71</td>\n",
       "      <td>4</td>\n",
       "      <td>A93</td>\n",
       "      <td>A101</td>\n",
       "      <td>...</td>\n",
       "      <td>A124</td>\n",
       "      <td>38</td>\n",
       "      <td>A143</td>\n",
       "      <td>A153</td>\n",
       "      <td>1</td>\n",
       "      <td>A174</td>\n",
       "      <td>1</td>\n",
       "      <td>A192</td>\n",
       "      <td>A201</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>99</th>\n",
       "      <td>A14</td>\n",
       "      <td>9</td>\n",
       "      <td>A32</td>\n",
       "      <td>A43</td>\n",
       "      <td>2911</td>\n",
       "      <td>A61</td>\n",
       "      <td>A75</td>\n",
       "      <td>4</td>\n",
       "      <td>A93</td>\n",
       "      <td>A101</td>\n",
       "      <td>...</td>\n",
       "      <td>A121</td>\n",
       "      <td>36</td>\n",
       "      <td>A141</td>\n",
       "      <td>A152</td>\n",
       "      <td>2</td>\n",
       "      <td>A173</td>\n",
       "      <td>1</td>\n",
       "      <td>A191</td>\n",
       "      <td>A201</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>100 rows × 21 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   Attribute1  Attribute2 Attribute3 Attribute4  Attribute5 Attribute6  \\\n",
       "0         A12          23        A31        A46        1154        A61   \n",
       "1         A11          18        A34        A41        6121        A64   \n",
       "2         A12          48        A33        A49       15270        A61   \n",
       "3         A12           5        A32        A43         638        A61   \n",
       "4         A11           7        A32        A49         889        A61   \n",
       "..        ...         ...        ...        ...         ...        ...   \n",
       "95        A11          31        A32        A40        1786        A61   \n",
       "96        A14          28        A34        A40        3504        A65   \n",
       "97        A14          13        A34        A42        2053        A61   \n",
       "98        A12          19        A32        A41       10204        A63   \n",
       "99        A14           9        A32        A43        2911        A61   \n",
       "\n",
       "   Attribute7  Attribute8 Attribute9 Attribute10  ...  Attribute12  \\\n",
       "0         A71           4        A93        A101  ...         A122   \n",
       "1         A71           2        A91        A101  ...         A122   \n",
       "2         A71           2        A93        A101  ...         A123   \n",
       "3         A73           3        A93        A103  ...         A121   \n",
       "4         A73           4        A91        A101  ...         A122   \n",
       "..        ...         ...        ...         ...  ...          ...   \n",
       "95        A72           2        A92        A102  ...         A122   \n",
       "96        A75           4        A92        A101  ...         A123   \n",
       "97        A73           2        A92        A101  ...         A123   \n",
       "98        A71           4        A93        A101  ...         A124   \n",
       "99        A75           4        A93        A101  ...         A121   \n",
       "\n",
       "   Attribute13  Attribute14 Attribute15 Attribute16  Attribute17 Attribute18  \\\n",
       "0           48         A143        A153           2         A172           1   \n",
       "1           42         A143        A152           1         A173           1   \n",
       "2           50         A143        A152           1         A171           1   \n",
       "3           48         A143        A152           1         A173           1   \n",
       "4           23         A143        A152           1         A173           1   \n",
       "..         ...          ...         ...         ...          ...         ...   \n",
       "95          23         A143        A152           1         A173           1   \n",
       "96          55         A143        A152           1         A172           1   \n",
       "97          32         A143        A152           1         A173           1   \n",
       "98          38         A143        A153           1         A174           1   \n",
       "99          36         A141        A152           2         A173           1   \n",
       "\n",
       "    Attribute19 Attribute20 target  \n",
       "0          A191        A201      2  \n",
       "1          A192        A201      1  \n",
       "2          A192        A201      1  \n",
       "3          A191        A201      1  \n",
       "4          A191        A201      2  \n",
       "..          ...         ...    ...  \n",
       "95         A191        A201      1  \n",
       "96         A191        A201      1  \n",
       "97         A191        A201      1  \n",
       "98         A192        A201      1  \n",
       "99         A191        A201      1  \n",
       "\n",
       "[100 rows x 21 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.generate(100)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
