{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "07c4865a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/user/embedding_fusion\n"
     ]
    }
   ],
   "source": [
    "# Change to the parent directory (presumably the project root) so that the\n",
    "# `utils` import and the relative checkpoint path used in later cells resolve.\n",
    "# NOTE(review): re-running this cell moves up one more level; run it once per kernel session.\n",
    "%cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8ebf7709",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✓ Task extensions applied: BaseTask.mask_features\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/user/anaconda3/envs/tabicl/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/home/user/anaconda3/envs/tabicl/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/user/anaconda3/envs/tabicl/lib/python3.9/site-packages/torchvision/image.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?\n",
      "  warn(\n"
     ]
    }
   ],
   "source": [
    "from utils import TableData"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1833a7a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Build the dataset location from the working directory instead of hard-coding a\n",
    "# user-specific absolute path, so the notebook runs from any checkout.\n",
    "# NOTE: relies on the working directory being the project root.\n",
    "data_dir = os.path.abspath(os.path.join(\"data\", \"dfs-fs-table\", \"ratebeer-user-active\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0c2e7c03",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " ==> load material dataset from /home/user/embedding_fusion/data/dfs-fs-table/ratebeer-user-active\n"
     ]
    }
   ],
   "source": [
    "# Load the dataset stored under data_dir; the splits, target column and task type\n",
    "# used in later cells are all read from this object.\n",
    "table_data = TableData.load_from_dir(data_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "625a95db",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>beer_rating_count</th>\n",
       "      <th>avg_beer_rating</th>\n",
       "      <th>max_beer_rating</th>\n",
       "      <th>min_beer_rating</th>\n",
       "      <th>place_rating_count</th>\n",
       "      <th>favorite_count</th>\n",
       "      <th>user_type</th>\n",
       "      <th>COUNT(place_ratings)</th>\n",
       "      <th>MAX(place_ratings.ambiance_score)</th>\n",
       "      <th>...</th>\n",
       "      <th>NUM_UNIQUE(beer_ratings.YEAR(created_at))</th>\n",
       "      <th>PERCENT_TRUE(beer_ratings.beers.is_one_off)</th>\n",
       "      <th>PERCENT_TRUE(beer_ratings.beers.is_seasonal)</th>\n",
       "      <th>SUM(beer_ratings.beers.ibu)</th>\n",
       "      <th>SUM(beer_ratings.beers.last_9m_count)</th>\n",
       "      <th>SUM(beer_ratings.beers.rating_count)</th>\n",
       "      <th>SUM(beer_ratings.beers.rating_std_dev)</th>\n",
       "      <th>SUM(beer_ratings.beers.view_count)</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>is_active</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1755</td>\n",
       "      <td>2321.0</td>\n",
       "      <td>3.607324</td>\n",
       "      <td>4.6</td>\n",
       "      <td>0.5</td>\n",
       "      <td>77.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>multi_rater</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>5.0</td>\n",
       "      <td>0.382937</td>\n",
       "      <td>0.250000</td>\n",
       "      <td>4630.0</td>\n",
       "      <td>1981.0</td>\n",
       "      <td>67606.0</td>\n",
       "      <td>20.654880</td>\n",
       "      <td>22627.0</td>\n",
       "      <td>2022-09-04</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>88534</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2022-12-03</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>79126</td>\n",
       "      <td>303.0</td>\n",
       "      <td>2.581518</td>\n",
       "      <td>5.0</td>\n",
       "      <td>0.6</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>beer_rater</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>3.0</td>\n",
       "      <td>0.301205</td>\n",
       "      <td>0.056225</td>\n",
       "      <td>2644.0</td>\n",
       "      <td>2319.0</td>\n",
       "      <td>74191.0</td>\n",
       "      <td>39.775076</td>\n",
       "      <td>38353.0</td>\n",
       "      <td>2022-06-06</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2182</td>\n",
       "      <td>12466.0</td>\n",
       "      <td>3.524202</td>\n",
       "      <td>4.8</td>\n",
       "      <td>0.5</td>\n",
       "      <td>379.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>multi_rater</td>\n",
       "      <td>25</td>\n",
       "      <td>5.0</td>\n",
       "      <td>...</td>\n",
       "      <td>5.0</td>\n",
       "      <td>0.564909</td>\n",
       "      <td>0.365529</td>\n",
       "      <td>30914.0</td>\n",
       "      <td>8265.0</td>\n",
       "      <td>98809.0</td>\n",
       "      <td>23.079120</td>\n",
       "      <td>6366.0</td>\n",
       "      <td>2022-06-06</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4699</td>\n",
       "      <td>3045.0</td>\n",
       "      <td>3.452808</td>\n",
       "      <td>4.8</td>\n",
       "      <td>0.6</td>\n",
       "      <td>22.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>multi_rater</td>\n",
       "      <td>4</td>\n",
       "      <td>5.0</td>\n",
       "      <td>...</td>\n",
       "      <td>5.0</td>\n",
       "      <td>0.453608</td>\n",
       "      <td>0.184923</td>\n",
       "      <td>14797.0</td>\n",
       "      <td>10684.0</td>\n",
       "      <td>210924.0</td>\n",
       "      <td>55.147950</td>\n",
       "      <td>32112.0</td>\n",
       "      <td>2022-06-06</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 206 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   user_id  beer_rating_count  avg_beer_rating  max_beer_rating  \\\n",
       "0     1755             2321.0         3.607324              4.6   \n",
       "1    88534                NaN              NaN              NaN   \n",
       "2    79126              303.0         2.581518              5.0   \n",
       "3     2182            12466.0         3.524202              4.8   \n",
       "4     4699             3045.0         3.452808              4.8   \n",
       "\n",
       "   min_beer_rating  place_rating_count  favorite_count    user_type  \\\n",
       "0              0.5                77.0             0.0  multi_rater   \n",
       "1              NaN                 NaN             NaN          NaN   \n",
       "2              0.6                 0.0             0.0   beer_rater   \n",
       "3              0.5               379.0             0.0  multi_rater   \n",
       "4              0.6                22.0             0.0  multi_rater   \n",
       "\n",
       "   COUNT(place_ratings)  MAX(place_ratings.ambiance_score)  ...  \\\n",
       "0                     0                                NaN  ...   \n",
       "1                     0                                NaN  ...   \n",
       "2                     0                                NaN  ...   \n",
       "3                    25                                5.0  ...   \n",
       "4                     4                                5.0  ...   \n",
       "\n",
       "   NUM_UNIQUE(beer_ratings.YEAR(created_at))  \\\n",
       "0                                        5.0   \n",
       "1                                        NaN   \n",
       "2                                        3.0   \n",
       "3                                        5.0   \n",
       "4                                        5.0   \n",
       "\n",
       "   PERCENT_TRUE(beer_ratings.beers.is_one_off)  \\\n",
       "0                                     0.382937   \n",
       "1                                          NaN   \n",
       "2                                     0.301205   \n",
       "3                                     0.564909   \n",
       "4                                     0.453608   \n",
       "\n",
       "   PERCENT_TRUE(beer_ratings.beers.is_seasonal)  SUM(beer_ratings.beers.ibu)  \\\n",
       "0                                      0.250000                       4630.0   \n",
       "1                                           NaN                          0.0   \n",
       "2                                      0.056225                       2644.0   \n",
       "3                                      0.365529                      30914.0   \n",
       "4                                      0.184923                      14797.0   \n",
       "\n",
       "   SUM(beer_ratings.beers.last_9m_count)  \\\n",
       "0                                 1981.0   \n",
       "1                                    0.0   \n",
       "2                                 2319.0   \n",
       "3                                 8265.0   \n",
       "4                                10684.0   \n",
       "\n",
       "   SUM(beer_ratings.beers.rating_count)  \\\n",
       "0                               67606.0   \n",
       "1                                   0.0   \n",
       "2                               74191.0   \n",
       "3                               98809.0   \n",
       "4                              210924.0   \n",
       "\n",
       "   SUM(beer_ratings.beers.rating_std_dev)  SUM(beer_ratings.beers.view_count)  \\\n",
       "0                               20.654880                             22627.0   \n",
       "1                                0.000000                                 0.0   \n",
       "2                               39.775076                             38353.0   \n",
       "3                               23.079120                              6366.0   \n",
       "4                               55.147950                             32112.0   \n",
       "\n",
       "    timestamp  is_active  \n",
       "0  2022-09-04          1  \n",
       "1  2022-12-03          0  \n",
       "2  2022-06-06          0  \n",
       "3  2022-06-06          1  \n",
       "4  2022-06-06          0  \n",
       "\n",
       "[5 rows x 206 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first rows of the training split.\n",
    "table_data.train_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d5e54e34",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tabicl import TabICLClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b8c3bae5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Checkpoint location, relative to the working directory.\n",
    "model_path = \"./tabICL/tabicl-classifier-v1.1-0506.ckpt\"\n",
    "\n",
    "# Prefer GPU 5, but fall back to CPU when that device does not exist so the\n",
    "# notebook still runs on machines with fewer (or no) GPUs.\n",
    "gpu_index = 5\n",
    "device = f\"cuda:{gpu_index}\" if torch.cuda.device_count() > gpu_index else \"cpu\"\n",
    "\n",
    "clf = TabICLClassifier(\n",
    "    model_path=model_path,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cbecdcce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the three splits and the name of the label column from the loaded dataset.\n",
    "train_df, val_df, test_df = table_data.train_df, table_data.val_df, table_data.test_df\n",
    "target_col = table_data.target_col\n",
    "\n",
    "# Features: every column of each split except the target.\n",
    "X_train, X_val, X_test = (\n",
    "    frame.drop(columns=[target_col]) for frame in (train_df, val_df, test_df)\n",
    ")\n",
    "\n",
    "# Labels: the target column's underlying values for each split.\n",
    "y_train, y_val, y_test = (\n",
    "    frame[target_col].values for frame in (train_df, val_df, test_df)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "28c9310f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-1 {\n",
       "  /* Definition of color scheme common for light and dark mode */\n",
       "  --sklearn-color-text: #000;\n",
       "  --sklearn-color-text-muted: #666;\n",
       "  --sklearn-color-line: gray;\n",
       "  /* Definition of color scheme for unfitted estimators */\n",
       "  --sklearn-color-unfitted-level-0: #fff5e6;\n",
       "  --sklearn-color-unfitted-level-1: #f6e4d2;\n",
       "  --sklearn-color-unfitted-level-2: #ffe0b3;\n",
       "  --sklearn-color-unfitted-level-3: chocolate;\n",
       "  /* Definition of color scheme for fitted estimators */\n",
       "  --sklearn-color-fitted-level-0: #f0f8ff;\n",
       "  --sklearn-color-fitted-level-1: #d4ebff;\n",
       "  --sklearn-color-fitted-level-2: #b3dbfd;\n",
       "  --sklearn-color-fitted-level-3: cornflowerblue;\n",
       "\n",
       "  /* Specific color for light theme */\n",
       "  --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
       "  --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-icon: #696969;\n",
       "\n",
       "  @media (prefers-color-scheme: dark) {\n",
       "    /* Redefinition of color scheme for dark theme */\n",
       "    --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
       "    --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-icon: #878787;\n",
       "  }\n",
       "}\n",
       "\n",
       "#sk-container-id-1 {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 pre {\n",
       "  padding: 0;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-hidden--visually {\n",
       "  border: 0;\n",
       "  clip: rect(1px 1px 1px 1px);\n",
       "  clip: rect(1px, 1px, 1px, 1px);\n",
       "  height: 1px;\n",
       "  margin: -1px;\n",
       "  overflow: hidden;\n",
       "  padding: 0;\n",
       "  position: absolute;\n",
       "  width: 1px;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-dashed-wrapped {\n",
       "  border: 1px dashed var(--sklearn-color-line);\n",
       "  margin: 0 0.4em 0.5em 0.4em;\n",
       "  box-sizing: border-box;\n",
       "  padding-bottom: 0.4em;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-container {\n",
       "  /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
       "     but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
       "     so we also need the `!important` here to be able to override the\n",
       "     default hidden behavior on the sphinx rendered scikit-learn.org.\n",
       "     See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
       "  display: inline-block !important;\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-text-repr-fallback {\n",
       "  display: none;\n",
       "}\n",
       "\n",
       "div.sk-parallel-item,\n",
       "div.sk-serial,\n",
       "div.sk-item {\n",
       "  /* draw centered vertical line to link estimators */\n",
       "  background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
       "  background-size: 2px 100%;\n",
       "  background-repeat: no-repeat;\n",
       "  background-position: center center;\n",
       "}\n",
       "\n",
       "/* Parallel-specific style estimator block */\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item::after {\n",
       "  content: \"\";\n",
       "  width: 100%;\n",
       "  border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
       "  flex-grow: 1;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel {\n",
       "  display: flex;\n",
       "  align-items: stretch;\n",
       "  justify-content: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:first-child::after {\n",
       "  align-self: flex-end;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:last-child::after {\n",
       "  align-self: flex-start;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:only-child::after {\n",
       "  width: 0;\n",
       "}\n",
       "\n",
       "/* Serial-specific style estimator block */\n",
       "\n",
       "#sk-container-id-1 div.sk-serial {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "  align-items: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  padding-right: 1em;\n",
       "  padding-left: 1em;\n",
       "}\n",
       "\n",
       "\n",
       "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
       "clickable and can be expanded/collapsed.\n",
       "- Pipeline and ColumnTransformer use this feature and define the default style\n",
       "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
       "*/\n",
       "\n",
       "/* Pipeline and ColumnTransformer style (default) */\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable {\n",
       "  /* Default theme specific background. It is overwritten whether we have a\n",
       "  specific estimator or a Pipeline/ColumnTransformer */\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "/* Toggleable label */\n",
       "#sk-container-id-1 label.sk-toggleable__label {\n",
       "  cursor: pointer;\n",
       "  display: flex;\n",
       "  width: 100%;\n",
       "  margin-bottom: 0;\n",
       "  padding: 0.5em;\n",
       "  box-sizing: border-box;\n",
       "  text-align: center;\n",
       "  align-items: start;\n",
       "  justify-content: space-between;\n",
       "  gap: 0.5em;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 label.sk-toggleable__label .caption {\n",
       "  font-size: 0.6rem;\n",
       "  font-weight: lighter;\n",
       "  color: var(--sklearn-color-text-muted);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n",
       "  /* Arrow on the left of the label */\n",
       "  content: \"▸\";\n",
       "  float: left;\n",
       "  margin-right: 0.25em;\n",
       "  color: var(--sklearn-color-icon);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "/* Toggleable content - dropdown */\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content {\n",
       "  max-height: 0;\n",
       "  max-width: 0;\n",
       "  overflow: hidden;\n",
       "  text-align: left;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content pre {\n",
       "  margin: 0.2em;\n",
       "  border-radius: 0.25em;\n",
       "  color: var(--sklearn-color-text);\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
       "  /* Expand drop-down */\n",
       "  max-height: 200px;\n",
       "  max-width: 100%;\n",
       "  overflow: auto;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
       "  content: \"▾\";\n",
       "}\n",
       "\n",
       "/* Pipeline/ColumnTransformer-specific style */\n",
       "\n",
       "#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator-specific style */\n",
       "\n",
       "/* Colorize estimator box */\n",
       "#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n",
       "#sk-container-id-1 div.sk-label label {\n",
       "  /* The background is the default theme color */\n",
       "  color: var(--sklearn-color-text-on-default-background);\n",
       "}\n",
       "\n",
       "/* On hover, darken the color of the background */\n",
       "#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "/* Label box, darken color on hover, fitted */\n",
       "#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator label */\n",
       "\n",
       "#sk-container-id-1 div.sk-label label {\n",
       "  font-family: monospace;\n",
       "  font-weight: bold;\n",
       "  display: inline-block;\n",
       "  line-height: 1.2em;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label-container {\n",
       "  text-align: center;\n",
       "}\n",
       "\n",
       "/* Estimator-specific */\n",
       "#sk-container-id-1 div.sk-estimator {\n",
       "  font-family: monospace;\n",
       "  border: 1px dotted var(--sklearn-color-border-box);\n",
       "  border-radius: 0.25em;\n",
       "  box-sizing: border-box;\n",
       "  margin-bottom: 0.5em;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "/* on hover */\n",
       "#sk-container-id-1 div.sk-estimator:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
       "\n",
       "/* Common style for \"i\" and \"?\" */\n",
       "\n",
       ".sk-estimator-doc-link,\n",
       "a:link.sk-estimator-doc-link,\n",
       "a:visited.sk-estimator-doc-link {\n",
       "  float: right;\n",
       "  font-size: smaller;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1em;\n",
       "  height: 1em;\n",
       "  width: 1em;\n",
       "  text-decoration: none !important;\n",
       "  margin-left: 0.5em;\n",
       "  text-align: center;\n",
       "  /* unfitted */\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted,\n",
       "a:link.sk-estimator-doc-link.fitted,\n",
       "a:visited.sk-estimator-doc-link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "/* Span, style for the box shown on hovering the info icon */\n",
       ".sk-estimator-doc-link span {\n",
       "  display: none;\n",
       "  z-index: 9999;\n",
       "  position: relative;\n",
       "  font-weight: normal;\n",
       "  right: .2ex;\n",
       "  padding: .5ex;\n",
       "  margin: .5ex;\n",
       "  width: min-content;\n",
       "  min-width: 20ex;\n",
       "  max-width: 50ex;\n",
       "  color: var(--sklearn-color-text);\n",
       "  box-shadow: 2pt 2pt 4pt #999;\n",
       "  /* unfitted */\n",
       "  background: var(--sklearn-color-unfitted-level-0);\n",
       "  border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted span {\n",
       "  /* fitted */\n",
       "  background: var(--sklearn-color-fitted-level-0);\n",
       "  border: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link:hover span {\n",
       "  display: block;\n",
       "}\n",
       "\n",
       "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link {\n",
       "  float: right;\n",
       "  font-size: 1rem;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1rem;\n",
       "  height: 1rem;\n",
       "  width: 1rem;\n",
       "  text-decoration: none;\n",
       "  /* unfitted */\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "#sk-container-id-1 a.estimator_doc_link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>TabICLClassifier(device=&#x27;cuda:5&#x27;,\n",
       "                 model_path=&#x27;./tabICL/tabicl-classifier-v1.1-0506.ckpt&#x27;)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>TabICLClassifier</div></div><div><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></div></label><div class=\"sk-toggleable__content fitted\"><pre>TabICLClassifier(device=&#x27;cuda:5&#x27;,\n",
       "                 model_path=&#x27;./tabICL/tabicl-classifier-v1.1-0506.ckpt&#x27;)</pre></div> </div></div></div></div>"
      ],
      "text/plain": [
       "TabICLClassifier(device='cuda:5',\n",
       "                 model_path='./tabICL/tabicl-classifier-v1.1-0506.ckpt')"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit the classifier on the training features and labels.\n",
    "clf.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "efe0c0c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the test split with positive-class probabilities rather than hard labels.\n",
    "# ROC-AUC is a ranking metric: thresholded 0/1 predictions collapse the curve to a\n",
    "# single operating point and distort the reported AUC.\n",
    "# Column 1 of predict_proba corresponds to clf.classes_[1], the larger label, which\n",
    "# roc_auc_score treats as the positive class for a binary target.\n",
    "test_pred = clf.predict_proba(X_test)[:, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9d410201",
   "metadata": {},
   "outputs": [],
   "source": [
    "from relbench.base import TaskType\n",
    "from sklearn.metrics import mean_absolute_error, roc_auc_score, accuracy_score\n",
    "\n",
    "# Select the evaluation metric that matches the dataset's task type.\n",
    "if table_data.task_type == TaskType.REGRESSION:\n",
    "    evaluate_metric_func = mean_absolute_error\n",
    "elif table_data.task_type == TaskType.BINARY_CLASSIFICATION:\n",
    "    evaluate_metric_func = roc_auc_score\n",
    "else:\n",
    "    # Fail here with a clear message instead of leaving evaluate_metric_func\n",
    "    # undefined (or stale from an earlier run) and erroring in a later cell.\n",
    "    raise ValueError(f\"Unsupported task type: {table_data.task_type!r}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "1c3ce1ed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8306133899721104"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Apply the metric chosen for this task type to the test-split predictions.\n",
    "evaluate_metric_func(y_test, test_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d23aed23",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tabicl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}