{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "34f98c97",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import openml\n",
    "import data_preprocess as dp\n",
    "import warnings\n",
    "\n",
    "# Fetch dataset 846 from OpenML.\n",
    "DATASET_ID = 846\n",
    "data = openml.datasets.get_dataset(DATASET_ID)\n",
    "\n",
    "# Pull the table as pandas objects, separating the dataset's default target column\n",
    "# from the features. `categorical_indicator` and `attribute_names` are parallel,\n",
    "# per-feature sequences returned alongside the data.\n",
    "target_column = data.default_target_attribute\n",
    "X, y, categorical_indicator, attribute_names = data.get_data(\n",
    "    dataset_format=\"dataframe\", target=target_column\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a854335e",
   "metadata": {},
   "source": [
    "# Data Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b83124ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "import numpy as np\n",
    "\n",
    "# Partition the feature names using the per-feature categorical flags.\n",
    "nominal = [name for is_cat, name in zip(categorical_indicator, attribute_names) if is_cat]\n",
    "numerical = [name for is_cat, name in zip(categorical_indicator, attribute_names) if not is_cat]\n",
    "\n",
    "# Work on a copy so the raw frame `X` is never modified and this cell can be re-run.\n",
    "encoded_data = deepcopy(X)\n",
    "\n",
    "for col in nominal:\n",
    "    # Integer-encode the column: distinct values are numbered 1..k in order of first\n",
    "    # appearance and missing values get the reserved code 0. pd.factorize marks missing\n",
    "    # entries with -1, so adding one produces exactly that layout. This replaces the\n",
    "    # former replace()/add_categories() pair, whose bare `except` hid any failure and\n",
    "    # which left the column with a categorical dtype.\n",
    "    codes = pd.factorize(encoded_data[col])[0]\n",
    "    encoded_data[col] = codes + 1\n",
    "\n",
    "# Order the columns as numerical first, then nominal; NaNs still present in the\n",
    "# numerical columns are filled with 0.\n",
    "encoded_data = encoded_data[numerical + nominal].fillna(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e5c57b9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number the target classes 0, 1, ... in the order in which they first occur in `y`.\n",
    "mapping = dict((label, code) for code, label in enumerate(y.unique()))\n",
    "\n",
    "# Substitute every original label with its integer code.\n",
    "y = y.replace(to_replace=mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "747e1b0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Stage 1: hold out 30% of the rows, stratified on the label; the rest is for training.\n",
    "train_data, left_out, train_label, y_left_out = train_test_split(\n",
    "    encoded_data,\n",
    "    y,\n",
    "    test_size=0.3,\n",
    "    stratify=y,\n",
    "    random_state=42,\n",
    ")\n",
    "\n",
    "# Stage 2: cut the held-out rows in half, again stratified, into a test and a dev split.\n",
    "test_data, dev_data, test_label, dev_label = train_test_split(\n",
    "    left_out,\n",
    "    y_left_out,\n",
    "    test_size=0.5,\n",
    "    stratify=y_left_out,\n",
    "    random_state=42,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cfa7fd8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import Normalizer, StandardScaler\n",
    "\n",
    "# Fit the scaler on the training rows only, so the dev and test rows do not\n",
    "# contribute to the statistics used for scaling.\n",
    "nn = StandardScaler().fit(train_data[numerical])\n",
    "\n",
    "# Apply that same fitted scaling to the numerical columns of every split.\n",
    "for split in (train_data, dev_data, test_data):\n",
    "    split[numerical] = nn.transform(split[numerical])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ba008d69",
   "metadata": {},
   "outputs": [],
   "source": [
    "from imblearn.over_sampling import RandomOverSampler\n",
    "\n",
    "# Resample the minority class of the training split only; the dev and test splits\n",
    "# are left as they are. random_state pins the random draw so that Restart & Run All\n",
    "# rebuilds the same resampled training set (it was unseeded before, which made every\n",
    "# downstream metric vary from run to run).\n",
    "oversample = RandomOverSampler(sampling_strategy='minority', random_state=42)\n",
    "X_over, y_over = oversample.fit_resample(train_data, train_label)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00a1d08a",
   "metadata": {},
   "source": [
    "## TabNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dde976a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 0  | loss: 0.63984 | val_0_auc: 0.79251 |  0:00:03s\n",
      "epoch 1  | loss: 0.54791 | val_0_auc: 0.81473 |  0:00:05s\n",
      "epoch 2  | loss: 0.52802 | val_0_auc: 0.82344 |  0:00:06s\n",
      "epoch 3  | loss: 0.51133 | val_0_auc: 0.83073 |  0:00:07s\n",
      "epoch 4  | loss: 0.49654 | val_0_auc: 0.84365 |  0:00:08s\n",
      "epoch 5  | loss: 0.48282 | val_0_auc: 0.84679 |  0:00:09s\n",
      "epoch 6  | loss: 0.46772 | val_0_auc: 0.85469 |  0:00:10s\n",
      "epoch 7  | loss: 0.45423 | val_0_auc: 0.86379 |  0:00:11s\n",
      "epoch 8  | loss: 0.4379  | val_0_auc: 0.89178 |  0:00:12s\n",
      "epoch 9  | loss: 0.4079  | val_0_auc: 0.90698 |  0:00:14s\n",
      "epoch 10 | loss: 0.38636 | val_0_auc: 0.91231 |  0:00:15s\n",
      "epoch 11 | loss: 0.36134 | val_0_auc: 0.92157 |  0:00:16s\n",
      "epoch 12 | loss: 0.3539  | val_0_auc: 0.90968 |  0:00:17s\n",
      "epoch 13 | loss: 0.34717 | val_0_auc: 0.93247 |  0:00:18s\n",
      "epoch 14 | loss: 0.32291 | val_0_auc: 0.93311 |  0:00:19s\n",
      "epoch 15 | loss: 0.32151 | val_0_auc: 0.93353 |  0:00:20s\n",
      "epoch 16 | loss: 0.31649 | val_0_auc: 0.93592 |  0:00:21s\n",
      "epoch 17 | loss: 0.30743 | val_0_auc: 0.94073 |  0:00:23s\n",
      "epoch 18 | loss: 0.3056  | val_0_auc: 0.93802 |  0:00:24s\n",
      "epoch 19 | loss: 0.30045 | val_0_auc: 0.94068 |  0:00:25s\n",
      "epoch 20 | loss: 0.29045 | val_0_auc: 0.94383 |  0:00:26s\n",
      "epoch 21 | loss: 0.29508 | val_0_auc: 0.9435  |  0:00:27s\n",
      "epoch 22 | loss: 0.29396 | val_0_auc: 0.94346 |  0:00:28s\n",
      "epoch 23 | loss: 0.28716 | val_0_auc: 0.93735 |  0:00:29s\n",
      "epoch 24 | loss: 0.28694 | val_0_auc: 0.94295 |  0:00:30s\n",
      "epoch 25 | loss: 0.28375 | val_0_auc: 0.93958 |  0:00:32s\n",
      "epoch 26 | loss: 0.27988 | val_0_auc: 0.94483 |  0:00:33s\n",
      "epoch 27 | loss: 0.27863 | val_0_auc: 0.94325 |  0:00:34s\n",
      "epoch 28 | loss: 0.2815  | val_0_auc: 0.9401  |  0:00:35s\n",
      "epoch 29 | loss: 0.27965 | val_0_auc: 0.94225 |  0:00:36s\n",
      "epoch 30 | loss: 0.27335 | val_0_auc: 0.94156 |  0:00:37s\n",
      "epoch 31 | loss: 0.27233 | val_0_auc: 0.94286 |  0:00:38s\n",
      "epoch 32 | loss: 0.27011 | val_0_auc: 0.94474 |  0:00:39s\n",
      "epoch 33 | loss: 0.27719 | val_0_auc: 0.94517 |  0:00:40s\n",
      "epoch 34 | loss: 0.27108 | val_0_auc: 0.9387  |  0:00:42s\n",
      "epoch 35 | loss: 0.27161 | val_0_auc: 0.94569 |  0:00:43s\n",
      "epoch 36 | loss: 0.2676  | val_0_auc: 0.94229 |  0:00:44s\n",
      "epoch 37 | loss: 0.27112 | val_0_auc: 0.94441 |  0:00:45s\n",
      "epoch 38 | loss: 0.27506 | val_0_auc: 0.94089 |  0:00:46s\n",
      "epoch 39 | loss: 0.27328 | val_0_auc: 0.93615 |  0:00:47s\n",
      "epoch 40 | loss: 0.26855 | val_0_auc: 0.94116 |  0:00:48s\n",
      "epoch 41 | loss: 0.26501 | val_0_auc: 0.94433 |  0:00:49s\n",
      "epoch 42 | loss: 0.26643 | val_0_auc: 0.94241 |  0:00:50s\n",
      "epoch 43 | loss: 0.2619  | val_0_auc: 0.94576 |  0:00:52s\n",
      "epoch 44 | loss: 0.26434 | val_0_auc: 0.94273 |  0:00:53s\n",
      "epoch 45 | loss: 0.27282 | val_0_auc: 0.93722 |  0:00:54s\n",
      "epoch 46 | loss: 0.27301 | val_0_auc: 0.93865 |  0:00:55s\n",
      "epoch 47 | loss: 0.27032 | val_0_auc: 0.94261 |  0:00:56s\n",
      "epoch 48 | loss: 0.26629 | val_0_auc: 0.9441  |  0:00:57s\n",
      "epoch 49 | loss: 0.27001 | val_0_auc: 0.94216 |  0:00:58s\n",
      "epoch 50 | loss: 0.26542 | val_0_auc: 0.94448 |  0:01:00s\n",
      "epoch 51 | loss: 0.26077 | val_0_auc: 0.94381 |  0:01:01s\n",
      "epoch 52 | loss: 0.25953 | val_0_auc: 0.94257 |  0:01:02s\n",
      "epoch 53 | loss: 0.26436 | val_0_auc: 0.94366 |  0:01:03s\n",
      "\n",
      "Early stopping occurred at epoch 53 with best_epoch = 43 and best_val_0_auc = 0.94576\n",
      "roc is : 0.9421035940803384\n",
      "prec 0.8632093674290946\n",
      "recall 0.8603103292056781\n",
      "f-score 0.8617342789712804\n",
      "epoch 0  | loss: 0.60561 | val_0_auc: 0.80747 |  0:00:01s\n",
      "epoch 1  | loss: 0.52531 | val_0_auc: 0.8249  |  0:00:02s\n",
      "epoch 2  | loss: 0.50206 | val_0_auc: 0.84358 |  0:00:03s\n",
      "epoch 3  | loss: 0.48468 | val_0_auc: 0.85387 |  0:00:04s\n",
      "epoch 4  | loss: 0.44572 | val_0_auc: 0.88213 |  0:00:06s\n",
      "epoch 5  | loss: 0.38366 | val_0_auc: 0.90626 |  0:00:07s\n",
      "epoch 6  | loss: 0.34717 | val_0_auc: 0.9249  |  0:00:08s\n",
      "epoch 7  | loss: 0.32513 | val_0_auc: 0.9271  |  0:00:09s\n",
      "epoch 8  | loss: 0.3201  | val_0_auc: 0.93423 |  0:00:10s\n",
      "epoch 9  | loss: 0.31592 | val_0_auc: 0.92683 |  0:00:11s\n",
      "epoch 10 | loss: 0.29929 | val_0_auc: 0.93611 |  0:00:12s\n",
      "epoch 11 | loss: 0.30415 | val_0_auc: 0.93944 |  0:00:14s\n",
      "epoch 12 | loss: 0.29734 | val_0_auc: 0.94154 |  0:00:15s\n",
      "epoch 13 | loss: 0.30032 | val_0_auc: 0.93781 |  0:00:16s\n",
      "epoch 14 | loss: 0.28834 | val_0_auc: 0.93865 |  0:00:17s\n",
      "epoch 15 | loss: 0.28495 | val_0_auc: 0.93843 |  0:00:18s\n",
      "epoch 16 | loss: 0.28445 | val_0_auc: 0.94028 |  0:00:19s\n",
      "epoch 17 | loss: 0.27799 | val_0_auc: 0.94083 |  0:00:21s\n",
      "epoch 18 | loss: 0.28023 | val_0_auc: 0.94092 |  0:00:22s\n",
      "epoch 19 | loss: 0.27567 | val_0_auc: 0.94134 |  0:00:23s\n",
      "epoch 20 | loss: 0.27737 | val_0_auc: 0.94368 |  0:00:24s\n",
      "epoch 21 | loss: 0.27957 | val_0_auc: 0.93935 |  0:00:25s\n",
      "epoch 22 | loss: 0.28097 | val_0_auc: 0.94272 |  0:00:26s\n",
      "epoch 23 | loss: 0.27248 | val_0_auc: 0.94492 |  0:00:27s\n",
      "epoch 24 | loss: 0.27992 | val_0_auc: 0.94325 |  0:00:29s\n",
      "epoch 25 | loss: 0.27261 | val_0_auc: 0.94196 |  0:00:30s\n",
      "epoch 26 | loss: 0.27622 | val_0_auc: 0.9442  |  0:00:31s\n",
      "epoch 27 | loss: 0.27476 | val_0_auc: 0.9433  |  0:00:32s\n",
      "epoch 28 | loss: 0.27129 | val_0_auc: 0.93914 |  0:00:33s\n",
      "epoch 29 | loss: 0.27391 | val_0_auc: 0.9395  |  0:00:35s\n",
      "epoch 30 | loss: 0.27541 | val_0_auc: 0.94629 |  0:00:36s\n",
      "epoch 31 | loss: 0.26543 | val_0_auc: 0.94459 |  0:00:37s\n",
      "epoch 32 | loss: 0.26924 | val_0_auc: 0.94035 |  0:00:38s\n",
      "epoch 33 | loss: 0.26419 | val_0_auc: 0.94676 |  0:00:39s\n",
      "epoch 34 | loss: 0.2634  | val_0_auc: 0.94401 |  0:00:41s\n",
      "epoch 35 | loss: 0.26797 | val_0_auc: 0.94585 |  0:00:42s\n",
      "epoch 36 | loss: 0.26961 | val_0_auc: 0.93667 |  0:00:43s\n",
      "epoch 37 | loss: 0.27829 | val_0_auc: 0.94015 |  0:00:44s\n",
      "epoch 38 | loss: 0.27264 | val_0_auc: 0.94108 |  0:00:46s\n",
      "epoch 39 | loss: 0.27388 | val_0_auc: 0.93997 |  0:00:47s\n",
      "epoch 40 | loss: 0.2643  | val_0_auc: 0.94479 |  0:00:48s\n",
      "epoch 41 | loss: 0.26166 | val_0_auc: 0.94067 |  0:00:49s\n",
      "epoch 42 | loss: 0.26437 | val_0_auc: 0.94008 |  0:00:50s\n",
      "epoch 43 | loss: 0.25969 | val_0_auc: 0.94482 |  0:00:52s\n",
      "\n",
      "Early stopping occurred at epoch 43 with best_epoch = 33 and best_val_0_auc = 0.94676\n",
      "roc is : 0.9413742071881607\n",
      "prec 0.8549583759310647\n",
      "recall 0.8670190274841437\n",
      "f-score 0.8604886736795235\n",
      "epoch 0  | loss: 0.59931 | val_0_auc: 0.81559 |  0:00:01s\n",
      "epoch 1  | loss: 0.51787 | val_0_auc: 0.83227 |  0:00:02s\n",
      "epoch 2  | loss: 0.51047 | val_0_auc: 0.83299 |  0:00:03s\n",
      "epoch 3  | loss: 0.4973  | val_0_auc: 0.84325 |  0:00:05s\n",
      "epoch 4  | loss: 0.48425 | val_0_auc: 0.84969 |  0:00:06s\n",
      "epoch 5  | loss: 0.46943 | val_0_auc: 0.85522 |  0:00:07s\n",
      "epoch 6  | loss: 0.45976 | val_0_auc: 0.8727  |  0:00:08s\n",
      "epoch 7  | loss: 0.43881 | val_0_auc: 0.88166 |  0:00:10s\n",
      "epoch 8  | loss: 0.40974 | val_0_auc: 0.89056 |  0:00:11s\n",
      "epoch 9  | loss: 0.38247 | val_0_auc: 0.90745 |  0:00:12s\n",
      "epoch 10 | loss: 0.36857 | val_0_auc: 0.91888 |  0:00:13s\n",
      "epoch 11 | loss: 0.34122 | val_0_auc: 0.92938 |  0:00:16s\n",
      "epoch 12 | loss: 0.32203 | val_0_auc: 0.9352  |  0:00:19s\n",
      "epoch 13 | loss: 0.31098 | val_0_auc: 0.93658 |  0:00:23s\n",
      "epoch 14 | loss: 0.30695 | val_0_auc: 0.93717 |  0:00:26s\n",
      "epoch 15 | loss: 0.30306 | val_0_auc: 0.93988 |  0:00:29s\n",
      "epoch 16 | loss: 0.29938 | val_0_auc: 0.93689 |  0:00:32s\n",
      "epoch 17 | loss: 0.30033 | val_0_auc: 0.93874 |  0:00:34s\n",
      "epoch 18 | loss: 0.2972  | val_0_auc: 0.94325 |  0:00:36s\n",
      "epoch 19 | loss: 0.29549 | val_0_auc: 0.937   |  0:00:39s\n",
      "epoch 20 | loss: 0.28233 | val_0_auc: 0.94494 |  0:00:42s\n",
      "epoch 21 | loss: 0.28192 | val_0_auc: 0.94121 |  0:00:44s\n",
      "epoch 22 | loss: 0.28193 | val_0_auc: 0.94344 |  0:00:46s\n",
      "epoch 23 | loss: 0.28199 | val_0_auc: 0.94437 |  0:00:49s\n",
      "epoch 24 | loss: 0.27804 | val_0_auc: 0.93974 |  0:00:51s\n",
      "epoch 25 | loss: 0.28092 | val_0_auc: 0.93751 |  0:00:54s\n",
      "epoch 26 | loss: 0.28241 | val_0_auc: 0.94215 |  0:00:57s\n",
      "epoch 27 | loss: 0.27963 | val_0_auc: 0.93301 |  0:00:59s\n",
      "epoch 28 | loss: 0.28825 | val_0_auc: 0.93725 |  0:01:02s\n",
      "epoch 29 | loss: 0.27915 | val_0_auc: 0.94287 |  0:01:04s\n",
      "epoch 30 | loss: 0.27484 | val_0_auc: 0.94334 |  0:01:06s\n",
      "\n",
      "Early stopping occurred at epoch 30 with best_epoch = 20 and best_val_0_auc = 0.94494\n",
      "roc is : 0.9424875415282392\n",
      "prec 0.8623346846154756\n",
      "recall 0.87037148897614\n",
      "f-score 0.866143425438125\n",
      "epoch 0  | loss: 0.60705 | val_0_auc: 0.79818 |  0:00:01s\n",
      "epoch 1  | loss: 0.53744 | val_0_auc: 0.82724 |  0:00:02s\n",
      "epoch 2  | loss: 0.50487 | val_0_auc: 0.84633 |  0:00:04s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3  | loss: 0.48736 | val_0_auc: 0.84995 |  0:00:05s\n",
      "epoch 4  | loss: 0.47227 | val_0_auc: 0.85893 |  0:00:06s\n",
      "epoch 5  | loss: 0.46204 | val_0_auc: 0.88006 |  0:00:08s\n",
      "epoch 6  | loss: 0.43237 | val_0_auc: 0.87107 |  0:00:09s\n",
      "epoch 7  | loss: 0.40461 | val_0_auc: 0.90178 |  0:00:11s\n",
      "epoch 8  | loss: 0.37902 | val_0_auc: 0.90339 |  0:00:12s\n",
      "epoch 9  | loss: 0.35    | val_0_auc: 0.92705 |  0:00:14s\n",
      "epoch 10 | loss: 0.33631 | val_0_auc: 0.92884 |  0:00:15s\n",
      "epoch 11 | loss: 0.32257 | val_0_auc: 0.91848 |  0:00:17s\n",
      "epoch 12 | loss: 0.3185  | val_0_auc: 0.93085 |  0:00:18s\n",
      "epoch 13 | loss: 0.31436 | val_0_auc: 0.92976 |  0:00:20s\n",
      "epoch 14 | loss: 0.306   | val_0_auc: 0.92915 |  0:00:21s\n",
      "epoch 15 | loss: 0.3062  | val_0_auc: 0.93528 |  0:00:23s\n",
      "epoch 16 | loss: 0.29977 | val_0_auc: 0.91744 |  0:00:24s\n",
      "epoch 17 | loss: 0.29831 | val_0_auc: 0.93978 |  0:00:26s\n",
      "epoch 18 | loss: 0.29117 | val_0_auc: 0.94194 |  0:00:27s\n",
      "epoch 19 | loss: 0.28691 | val_0_auc: 0.92405 |  0:00:29s\n",
      "epoch 20 | loss: 0.29112 | val_0_auc: 0.93988 |  0:00:30s\n",
      "epoch 21 | loss: 0.28717 | val_0_auc: 0.92862 |  0:00:32s\n",
      "epoch 22 | loss: 0.29311 | val_0_auc: 0.92939 |  0:00:33s\n",
      "epoch 23 | loss: 0.28852 | val_0_auc: 0.94229 |  0:00:35s\n",
      "epoch 24 | loss: 0.28413 | val_0_auc: 0.92937 |  0:00:36s\n",
      "epoch 25 | loss: 0.2845  | val_0_auc: 0.93173 |  0:00:38s\n",
      "epoch 26 | loss: 0.29574 | val_0_auc: 0.92421 |  0:00:39s\n",
      "epoch 27 | loss: 0.28495 | val_0_auc: 0.94336 |  0:00:40s\n",
      "epoch 28 | loss: 0.28983 | val_0_auc: 0.92977 |  0:00:42s\n",
      "epoch 29 | loss: 0.28781 | val_0_auc: 0.94045 |  0:00:43s\n",
      "epoch 30 | loss: 0.28055 | val_0_auc: 0.94394 |  0:00:45s\n",
      "epoch 31 | loss: 0.27736 | val_0_auc: 0.93359 |  0:00:46s\n",
      "epoch 32 | loss: 0.28249 | val_0_auc: 0.93329 |  0:00:48s\n",
      "epoch 33 | loss: 0.28953 | val_0_auc: 0.94071 |  0:00:49s\n",
      "epoch 34 | loss: 0.283   | val_0_auc: 0.93995 |  0:00:50s\n",
      "epoch 35 | loss: 0.28565 | val_0_auc: 0.93269 |  0:00:52s\n",
      "epoch 36 | loss: 0.28033 | val_0_auc: 0.9409  |  0:00:53s\n",
      "epoch 37 | loss: 0.28336 | val_0_auc: 0.93362 |  0:00:55s\n",
      "epoch 38 | loss: 0.27899 | val_0_auc: 0.92297 |  0:00:56s\n",
      "epoch 39 | loss: 0.27516 | val_0_auc: 0.93761 |  0:00:57s\n",
      "epoch 40 | loss: 0.27672 | val_0_auc: 0.93852 |  0:00:59s\n",
      "\n",
      "Early stopping occurred at epoch 40 with best_epoch = 30 and best_val_0_auc = 0.94394\n",
      "roc is : 0.9397160978556327\n",
      "prec 0.8408385093167702\n",
      "recall 0.8646179401993355\n",
      "f-score 0.8504504504504504\n",
      "epoch 0  | loss: 0.64094 | val_0_auc: 0.79015 |  0:00:01s\n",
      "epoch 1  | loss: 0.54939 | val_0_auc: 0.82454 |  0:00:02s\n",
      "epoch 2  | loss: 0.51781 | val_0_auc: 0.84228 |  0:00:04s\n",
      "epoch 3  | loss: 0.49524 | val_0_auc: 0.84694 |  0:00:05s\n",
      "epoch 4  | loss: 0.48083 | val_0_auc: 0.8552  |  0:00:07s\n",
      "epoch 5  | loss: 0.47076 | val_0_auc: 0.85745 |  0:00:08s\n",
      "epoch 6  | loss: 0.46706 | val_0_auc: 0.85896 |  0:00:09s\n",
      "epoch 7  | loss: 0.44979 | val_0_auc: 0.86547 |  0:00:11s\n",
      "epoch 8  | loss: 0.42592 | val_0_auc: 0.88356 |  0:00:12s\n",
      "epoch 9  | loss: 0.39081 | val_0_auc: 0.89783 |  0:00:14s\n",
      "epoch 10 | loss: 0.37872 | val_0_auc: 0.91143 |  0:00:15s\n",
      "epoch 11 | loss: 0.35353 | val_0_auc: 0.92092 |  0:00:17s\n",
      "epoch 12 | loss: 0.34467 | val_0_auc: 0.92072 |  0:00:18s\n",
      "epoch 13 | loss: 0.32463 | val_0_auc: 0.93007 |  0:00:19s\n",
      "epoch 14 | loss: 0.31591 | val_0_auc: 0.91611 |  0:00:21s\n",
      "epoch 15 | loss: 0.31185 | val_0_auc: 0.93142 |  0:00:22s\n",
      "epoch 16 | loss: 0.30325 | val_0_auc: 0.93865 |  0:00:24s\n",
      "epoch 17 | loss: 0.29743 | val_0_auc: 0.93874 |  0:00:25s\n",
      "epoch 18 | loss: 0.29291 | val_0_auc: 0.93966 |  0:00:26s\n",
      "epoch 19 | loss: 0.28621 | val_0_auc: 0.9396  |  0:00:28s\n",
      "epoch 20 | loss: 0.28605 | val_0_auc: 0.93448 |  0:00:29s\n",
      "epoch 21 | loss: 0.28578 | val_0_auc: 0.93958 |  0:00:30s\n",
      "epoch 22 | loss: 0.27731 | val_0_auc: 0.93919 |  0:00:32s\n",
      "epoch 23 | loss: 0.27671 | val_0_auc: 0.93743 |  0:00:33s\n",
      "epoch 24 | loss: 0.28498 | val_0_auc: 0.93866 |  0:00:35s\n",
      "epoch 25 | loss: 0.27711 | val_0_auc: 0.93315 |  0:00:36s\n",
      "epoch 26 | loss: 0.28234 | val_0_auc: 0.9389  |  0:00:37s\n",
      "epoch 27 | loss: 0.28034 | val_0_auc: 0.93597 |  0:00:39s\n",
      "epoch 28 | loss: 0.28028 | val_0_auc: 0.94285 |  0:00:40s\n",
      "epoch 29 | loss: 0.27585 | val_0_auc: 0.93986 |  0:00:41s\n",
      "epoch 30 | loss: 0.27351 | val_0_auc: 0.93908 |  0:00:43s\n",
      "epoch 31 | loss: 0.27242 | val_0_auc: 0.94422 |  0:00:44s\n",
      "epoch 32 | loss: 0.27278 | val_0_auc: 0.94107 |  0:00:45s\n",
      "epoch 33 | loss: 0.26693 | val_0_auc: 0.94026 |  0:00:47s\n",
      "epoch 34 | loss: 0.26946 | val_0_auc: 0.94116 |  0:00:48s\n",
      "epoch 35 | loss: 0.27459 | val_0_auc: 0.9416  |  0:00:49s\n",
      "epoch 36 | loss: 0.26612 | val_0_auc: 0.94177 |  0:00:51s\n",
      "epoch 37 | loss: 0.26926 | val_0_auc: 0.94237 |  0:00:52s\n",
      "epoch 38 | loss: 0.26689 | val_0_auc: 0.94041 |  0:00:54s\n",
      "epoch 39 | loss: 0.27039 | val_0_auc: 0.94167 |  0:00:55s\n",
      "epoch 40 | loss: 0.26779 | val_0_auc: 0.93924 |  0:00:56s\n",
      "epoch 41 | loss: 0.27185 | val_0_auc: 0.93037 |  0:00:58s\n",
      "\n",
      "Early stopping occurred at epoch 41 with best_epoch = 31 and best_val_0_auc = 0.94422\n",
      "roc is : 0.9399018423437028\n",
      "prec 0.8596290688705073\n",
      "recall 0.8692766535789791\n",
      "f-score 0.8641441933313583\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score\n",
    "from pytorch_tabnet.tab_model import TabNetClassifier\n",
    "\n",
    "# Train one TabNet per seed (0-4) on the oversampled training set. The dev split is\n",
    "# passed as the evaluation set with a patience of 10; every metric reported below is\n",
    "# computed on the test split.\n",
    "aucs = []\n",
    "for r in range(5):\n",
    "    clf = TabNetClassifier(seed=r)\n",
    "    clf.fit(\n",
    "        X_over.values,\n",
    "        y_over.values,\n",
    "        eval_set=[(dev_data.values, dev_label.values)],\n",
    "        patience=10,\n",
    "    )\n",
    "\n",
    "    test_features = test_data.values\n",
    "    true_labels = test_label.tolist()\n",
    "    preds = clf.predict(test_features).tolist()\n",
    "\n",
    "    # Keep only the second probability column as the score handed to ROC AUC.\n",
    "    positive_scores = clf.predict_proba(test_features)[:, 1]\n",
    "\n",
    "    roc = roc_auc_score(test_label, positive_scores)\n",
    "    prec = precision_score(true_labels, preds, average='macro')\n",
    "    recall = recall_score(true_labels, preds, average='macro')\n",
    "    f_score = f1_score(true_labels, preds, average='macro')\n",
    "\n",
    "    aucs.append(roc)\n",
    "    print('roc is : {}\\nprec {}\\nrecall {}\\nf-score {}'.format(roc, prec, recall, f_score))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8885a5e5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.9421035940803384, 0.9413742071881607, 0.9424875415282392, 0.9397160978556327, 0.9399018423437028]\n",
      "0.9411166565992148 +- 0.0011275674574728301\n"
     ]
    }
   ],
   "source": [
    "# Per-seed test ROC AUCs, then their mean and standard deviation (np.std default).\n",
    "print(aucs)\n",
    "auc_mean, auc_std = np.mean(aucs), np.std(aucs)\n",
    "print(f\"{auc_mean} +- {auc_std}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d144549c",
   "metadata": {},
   "source": [
    "## Random Forest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3dca035e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "roc is : 0.9089021443672607\n",
      "prec 0.8466624941482974\n",
      "recall 0.8086416490486258\n",
      "f-score 0.823595574486512\n",
      "\n",
      "roc is : 0.9094008607671399\n",
      "prec 0.8425906258850184\n",
      "recall 0.8068974630021142\n",
      "f-score 0.8210843316099036\n",
      "\n",
      "roc is : 0.9090557988523106\n",
      "prec 0.843231172853929\n",
      "recall 0.8105708245243128\n",
      "f-score 0.8238117939486223\n",
      "\n",
      "roc is : 0.9115844910903051\n",
      "prec 0.8472310833226999\n",
      "recall 0.8123150105708246\n",
      "f-score 0.8263234986398827\n",
      "\n",
      "roc is : 0.9089538659015404\n",
      "prec 0.8418369932432432\n",
      "recall 0.8055987617034128\n",
      "f-score 0.81994159365874\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "# The dev split is not used for model selection in this cell, so its rows are\n",
    "# appended to the oversampled training rows.\n",
    "comb_data = pd.concat([X_over, dev_data])\n",
    "comb_labels = pd.concat([y_over, dev_label])\n",
    "\n",
    "# One forest per seed (0-4); only the ROC AUC of each run is kept for the summary.\n",
    "r_aucs = []\n",
    "for r in range(5):\n",
    "    rf_clf = RandomForestClassifier(random_state=r).fit(comb_data, comb_labels)\n",
    "\n",
    "    # Hard predictions feed the macro-averaged precision / recall / F1.\n",
    "    rf_pred = rf_clf.predict(test_data)\n",
    "    f_score = f1_score(test_label, rf_pred, average='macro')\n",
    "    prec = precision_score(test_label, rf_pred, average='macro')\n",
    "    recall = recall_score(test_label, rf_pred, average='macro')\n",
    "\n",
    "    # The second probability column is the score used for ROC AUC. The accuracy that\n",
    "    # used to be computed here was never read, so that extra prediction pass is gone.\n",
    "    rf_pred_prob = rf_clf.predict_proba(test_data)[:, 1]\n",
    "    weighted_roc = roc_auc_score(test_label, rf_pred_prob)\n",
    "\n",
    "    r_aucs.append(weighted_roc)\n",
    "    print('roc is : {}\\nprec {}\\nrecall {}\\nf-score {}\\n'.format(weighted_roc, prec, recall, f_score))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5279bd13",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.9089021443672607, 0.9094008607671399, 0.9090557988523106, 0.9115844910903051, 0.9089538659015404]\n",
      "0.9095794321957114 +- 0.0010174863925559942\n"
     ]
    }
   ],
   "source": [
    "# Per-seed test ROC AUCs, then their mean and standard deviation (np.std default).\n",
    "print(r_aucs)\n",
    "rf_auc_mean, rf_auc_std = np.mean(r_aucs), np.std(r_aucs)\n",
    "print(f\"{rf_auc_mean} +- {rf_auc_std}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9196ea7e",
   "metadata": {},
   "source": [
    "## XGBoost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "87e27cad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[11:44:46] WARNING: C:/Users/Administrator/workspace/xgboost-win64_release_1.4.0/src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n",
      "XGB test_acc: 0.8835341365461847\n",
      "recall: 0.8525747508305648\n",
      "precision: 0.8701088093743461\n",
      "f_score: 0.8604895404829562\n",
      " ROC: 0.9342615524010873\n"
     ]
    }
   ],
   "source": [
    "import xgboost as xgb\n",
    "from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score\n",
    "\n",
    "# eval_metric is spelled out: XGBoost >= 1.3 already defaults to logloss for\n",
    "# binary:logistic but prints a warning whenever the metric is left implicit.\n",
    "xg_boost = xgb.XGBClassifier(\n",
    "    objective='binary:logistic',\n",
    "    eval_metric='logloss',\n",
    "    use_label_encoder=False,\n",
    ")\n",
    "\n",
    "# The dev split is not used for model selection in this cell, so its rows are\n",
    "# appended to the oversampled training rows.\n",
    "comb_data = pd.concat([X_over, dev_data])\n",
    "comb_labels = pd.concat([y_over, dev_label])\n",
    "\n",
    "# Convert each frame to a float matrix once instead of on every call below.\n",
    "train_matrix = comb_data.values.astype(float)\n",
    "test_matrix = test_data.values.astype(float)\n",
    "\n",
    "xgb_clf = xg_boost.fit(train_matrix, comb_labels)\n",
    "test_acc = xgb_clf.score(test_matrix, test_label)\n",
    "test_pred = xgb_clf.predict(test_matrix)\n",
    "f_score = f1_score(test_label, test_pred, average='macro')\n",
    "prec = precision_score(test_label, test_pred, average='macro')\n",
    "recall = recall_score(test_label, test_pred, average='macro')\n",
    "\n",
    "# The second probability column is the score used for ROC AUC.\n",
    "pred_score = xgb_clf.predict_proba(test_matrix)[:, 1]\n",
    "weighted_roc = roc_auc_score(test_label, pred_score)\n",
    "\n",
    "print(f'XGB test_acc: {test_acc}\\nrecall: {recall}\\nprecision: {prec}\\nf_score: {f_score}\\n ROC: {weighted_roc}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7178908",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
