{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "134f979d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Run from the parent directory (presumably the repo root) so the local `transtab`\n",
    "# package and relative paths such as ./checkpoint and ./ckpt resolve from there.\n",
    "# Remember the directory the kernel started in and always step up from *that*:\n",
    "# a bare os.chdir('../') walks one level higher every time this cell is re-run.\n",
    "NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', os.getcwd())\n",
    "os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))\n",
    "\n",
    "import transtab\n",
    "\n",
    "# set the random seed once, before any data loading or model construction\n",
    "transtab.random_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "42c60011",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "########################################\n",
      "openml data index: 31\n",
      "load data from credit-g\n",
      "# data: 1000, # feat: 20, # cate: 11,  # bin: 2, # numerical: 7, pos rate: 0.70\n",
      "########################################\n",
      "openml data index: 29\n",
      "load data from credit-approval\n",
      "# data: 690, # feat: 15, # cate: 9,  # bin: 0, # numerical: 6, pos rate: 0.56\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dd62a8df24d14e22a69d77088bd1b220",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, test val_loss: 0.574102\n",
      "epoch: 0, train loss: 3.9759, lr: 0.000100, spent: 0.4 secs\n",
      "epoch: 1, test val_loss: 0.565162\n",
      "epoch: 1, train loss: 3.7812, lr: 0.000100, spent: 0.9 secs\n",
      "epoch: 2, test val_loss: 0.576745\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 2, train loss: 3.6560, lr: 0.000100, spent: 1.1 secs\n",
      "epoch: 3, test val_loss: 0.566665\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 3, train loss: 3.6539, lr: 0.000100, spent: 1.4 secs\n",
      "epoch: 4, test val_loss: 0.548929\n",
      "epoch: 4, train loss: 3.6118, lr: 0.000100, spent: 1.7 secs\n",
      "epoch: 5, test val_loss: 0.545800\n",
      "epoch: 5, train loss: 3.5634, lr: 0.000100, spent: 2.2 secs\n",
      "epoch: 6, test val_loss: 0.545121\n",
      "epoch: 6, train loss: 3.5035, lr: 0.000100, spent: 2.4 secs\n",
      "epoch: 7, test val_loss: 0.529130\n",
      "epoch: 7, train loss: 3.4372, lr: 0.000100, spent: 2.7 secs\n",
      "epoch: 8, test val_loss: 0.525149\n",
      "epoch: 8, train loss: 3.3768, lr: 0.000100, spent: 3.0 secs\n",
      "epoch: 9, test val_loss: 0.518042\n",
      "epoch: 9, train loss: 3.3204, lr: 0.000100, spent: 3.5 secs\n",
      "epoch: 10, test val_loss: 0.508209\n",
      "epoch: 10, train loss: 3.2816, lr: 0.000100, spent: 3.8 secs\n",
      "epoch: 11, test val_loss: 0.497027\n",
      "epoch: 11, train loss: 3.1952, lr: 0.000100, spent: 4.1 secs\n",
      "epoch: 12, test val_loss: 0.495085\n",
      "epoch: 12, train loss: 3.1852, lr: 0.000100, spent: 4.6 secs\n",
      "epoch: 13, test val_loss: 0.479123\n",
      "epoch: 13, train loss: 3.0853, lr: 0.000100, spent: 4.9 secs\n",
      "epoch: 14, test val_loss: 0.492737\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 14, train loss: 3.0682, lr: 0.000100, spent: 5.2 secs\n",
      "epoch: 15, test val_loss: 0.477266\n",
      "epoch: 15, train loss: 2.9653, lr: 0.000100, spent: 5.5 secs\n",
      "epoch: 16, test val_loss: 0.503946\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 16, train loss: 2.9797, lr: 0.000100, spent: 5.7 secs\n",
      "epoch: 17, test val_loss: 0.484869\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 17, train loss: 2.9767, lr: 0.000100, spent: 6.0 secs\n",
      "epoch: 18, test val_loss: 0.467354\n",
      "epoch: 18, train loss: 2.8925, lr: 0.000100, spent: 6.5 secs\n",
      "epoch: 19, test val_loss: 0.471429\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 19, train loss: 2.8963, lr: 0.000100, spent: 6.7 secs\n",
      "epoch: 20, test val_loss: 0.460370\n",
      "epoch: 20, train loss: 2.8847, lr: 0.000100, spent: 7.0 secs\n",
      "epoch: 21, test val_loss: 0.498306\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 21, train loss: 2.8389, lr: 0.000100, spent: 7.4 secs\n",
      "epoch: 22, test val_loss: 0.441738\n",
      "epoch: 22, train loss: 2.8077, lr: 0.000100, spent: 7.7 secs\n",
      "epoch: 23, test val_loss: 0.479452\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 23, train loss: 2.8506, lr: 0.000100, spent: 8.0 secs\n",
      "epoch: 24, test val_loss: 0.450146\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 24, train loss: 2.7006, lr: 0.000100, spent: 8.5 secs\n",
      "epoch: 25, test val_loss: 0.460931\n",
      "EarlyStopping counter: 3 out of 5\n",
      "epoch: 25, train loss: 2.7361, lr: 0.000100, spent: 8.7 secs\n",
      "epoch: 26, test val_loss: 0.482305\n",
      "EarlyStopping counter: 4 out of 5\n",
      "epoch: 26, train loss: 2.6959, lr: 0.000100, spent: 9.0 secs\n",
      "epoch: 27, test val_loss: 0.440060\n",
      "epoch: 27, train loss: 2.7485, lr: 0.000100, spent: 9.3 secs\n",
      "epoch: 28, test val_loss: 0.450090\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 28, train loss: 2.7765, lr: 0.000100, spent: 9.6 secs\n",
      "epoch: 29, test val_loss: 0.472720\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 29, train loss: 2.6344, lr: 0.000100, spent: 9.8 secs\n",
      "epoch: 30, test val_loss: 0.438471\n",
      "epoch: 30, train loss: 2.5639, lr: 0.000100, spent: 10.3 secs\n",
      "epoch: 31, test val_loss: 0.498057\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 31, train loss: 2.7224, lr: 0.000100, spent: 10.6 secs\n",
      "epoch: 32, test val_loss: 0.463493\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 32, train loss: 2.6888, lr: 0.000100, spent: 11.0 secs\n",
      "epoch: 33, test val_loss: 0.435828\n",
      "epoch: 33, train loss: 2.6895, lr: 0.000100, spent: 11.3 secs\n",
      "epoch: 34, test val_loss: 0.495953\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 34, train loss: 2.5385, lr: 0.000100, spent: 11.6 secs\n",
      "epoch: 35, test val_loss: 0.444737\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 35, train loss: 2.5663, lr: 0.000100, spent: 12.1 secs\n",
      "epoch: 36, test val_loss: 0.449832\n",
      "EarlyStopping counter: 3 out of 5\n",
      "epoch: 36, train loss: 2.6015, lr: 0.000100, spent: 12.4 secs\n",
      "epoch: 37, test val_loss: 0.441197\n",
      "EarlyStopping counter: 4 out of 5\n",
      "epoch: 37, train loss: 2.5011, lr: 0.000100, spent: 12.6 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-10-05 08:35:04.023 | INFO     | transtab.trainer:train:136 - load best at last from ./checkpoint\n",
      "2022-10-05 08:35:04.042 | INFO     | transtab.trainer:save_model:243 - saving model checkpoint to ./checkpoint\n",
      "2022-10-05 08:35:04.167 | INFO     | transtab.trainer:train:141 - training complete, cost 13.1 secs.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 38, test val_loss: 0.503903\n",
      "EarlyStopping counter: 5 out of 5\n",
      "early stopped\n"
     ]
    }
   ],
   "source": [
    "# Load both datasets in one call. Index 0 is credit-g (pretrained on below) and\n",
    "# index 1 is credit-approval (finetuned on in the following cells).\n",
    "(allset, trainset, valset, testset,\n",
    " cat_cols, num_cols, bin_cols) = transtab.load_data(['credit-g', 'credit-approval'])\n",
    "\n",
    "# transtab classifier over the categorical / numerical / binary column lists\n",
    "model = transtab.build_classifier(cat_cols, num_cols, bin_cols)\n",
    "\n",
    "# Vanilla supervised training on credit-g; early stopping watches the\n",
    "# validation loss, so smaller is better.\n",
    "training_arguments = dict(\n",
    "    num_epoch=50,\n",
    "    eval_metric='val_loss',\n",
    "    eval_less_is_better=True,\n",
    "    output_dir='./checkpoint',\n",
    "    batch_size=128,\n",
    "    lr=1e-4,\n",
    "    weight_decay=1e-4,\n",
    ")\n",
    "transtab.train(model, trainset[0], valset[0], **training_arguments)\n",
    "\n",
    "# persist the pretrained weights so the finetuning step can reload them\n",
    "model.save('./ckpt/pretrained')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d6bdc971",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-10-05 08:35:04.352 | INFO     | transtab.modeling_transtab:load:773 - missing keys: []\n",
      "2022-10-05 08:35:04.354 | INFO     | transtab.modeling_transtab:load:774 - unexpected keys: []\n",
      "2022-10-05 08:35:04.355 | INFO     | transtab.modeling_transtab:load:775 - load model from ./ckpt/pretrained\n",
      "2022-10-05 08:35:04.370 | INFO     | transtab.modeling_transtab:load:222 - load feature extractor from ./ckpt/pretrained/extractor/extractor.json\n",
      "2022-10-05 08:35:04.372 | INFO     | transtab.modeling_transtab:update:832 - Build a new classifier with num 2 classes outputs, need further finetune to work.\n"
     ]
    }
   ],
   "source": [
    "# Switch to another dataset and reuse the pretrained weights for finetuning.\n",
    "# credit-approval was already loaded together with credit-g above, so no reload is needed.\n",
    "# NOTE(review): assumes the cat/num/bin column lists from load_data cover both datasets - confirm.\n",
    "\n",
    "# restore the checkpoint written at the end of pretraining\n",
    "model.load('./ckpt/pretrained')\n",
    "\n",
    "# Hand the model the column lists for the new data. `num_class` has to be given\n",
    "# whenever the new dataset's number of classes differs from the pretraining one;\n",
    "# transtab then builds a new classifier head that still needs finetuning.\n",
    "model.update(dict(cat=cat_cols, num=num_cols, bin=bin_cols, num_class=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f399d02e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9b64fb45097e4061af5a0186c17d98a6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, test auc: 0.282251\n",
      "epoch: 0, train loss: 3.3862, lr: 0.000200, spent: 0.2 secs\n",
      "epoch: 1, test auc: 0.865801\n",
      "epoch: 1, train loss: 2.8794, lr: 0.000200, spent: 0.3 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2, test auc: 0.865801\n",
      "epoch: 2, train loss: 2.5943, lr: 0.000200, spent: 0.7 secs\n",
      "epoch: 3, test auc: 0.865801\n",
      "epoch: 3, train loss: 2.4300, lr: 0.000200, spent: 0.8 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 4, test auc: 0.872727\n",
      "epoch: 4, train loss: 2.2617, lr: 0.000200, spent: 1.0 secs\n",
      "epoch: 5, test auc: 0.879654\n",
      "epoch: 5, train loss: 2.0867, lr: 0.000200, spent: 1.1 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 6, test auc: 0.880519\n",
      "epoch: 6, train loss: 1.9774, lr: 0.000200, spent: 1.3 secs\n",
      "epoch: 7, test auc: 0.883117\n",
      "epoch: 7, train loss: 1.8739, lr: 0.000200, spent: 1.4 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 8, test auc: 0.889177\n",
      "epoch: 8, train loss: 1.8919, lr: 0.000200, spent: 1.5 secs\n",
      "epoch: 9, test auc: 0.890909\n",
      "epoch: 9, train loss: 1.8794, lr: 0.000200, spent: 1.7 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 10, test auc: 0.896970\n",
      "epoch: 10, train loss: 1.8456, lr: 0.000200, spent: 2.0 secs\n",
      "epoch: 11, test auc: 0.897835\n",
      "epoch: 11, train loss: 1.8213, lr: 0.000200, spent: 2.2 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 12, test auc: 0.896104\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 12, train loss: 1.8219, lr: 0.000200, spent: 2.3 secs\n",
      "epoch: 13, test auc: 0.903896\n",
      "epoch: 13, train loss: 1.7924, lr: 0.000200, spent: 2.4 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 14, test auc: 0.905628\n",
      "epoch: 14, train loss: 1.7964, lr: 0.000200, spent: 2.6 secs\n",
      "epoch: 15, test auc: 0.904762\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 15, train loss: 1.7641, lr: 0.000200, spent: 2.7 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 16, test auc: 0.904762\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 16, train loss: 1.7788, lr: 0.000200, spent: 2.8 secs\n",
      "epoch: 17, test auc: 0.909091\n",
      "epoch: 17, train loss: 1.7456, lr: 0.000200, spent: 2.9 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 18, test auc: 0.910823\n",
      "epoch: 18, train loss: 1.7438, lr: 0.000200, spent: 3.3 secs\n",
      "epoch: 19, test auc: 0.912554\n",
      "epoch: 19, train loss: 1.7569, lr: 0.000200, spent: 3.4 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 20, test auc: 0.912554\n",
      "epoch: 20, train loss: 1.7533, lr: 0.000200, spent: 3.5 secs\n",
      "epoch: 21, test auc: 0.915152\n",
      "epoch: 21, train loss: 1.7439, lr: 0.000200, spent: 3.7 secs\n",
      "epoch: 22, test auc: 0.915152\n",
      "epoch: 22, train loss: 1.7020, lr: 0.000200, spent: 3.9 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 23, test auc: 0.916883\n",
      "epoch: 23, train loss: 1.7017, lr: 0.000200, spent: 4.0 secs\n",
      "epoch: 24, test auc: 0.917749\n",
      "epoch: 24, train loss: 1.6625, lr: 0.000200, spent: 4.1 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 25, test auc: 0.918615\n",
      "epoch: 25, train loss: 1.6432, lr: 0.000200, spent: 4.3 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 26, test auc: 0.922944\n",
      "epoch: 26, train loss: 1.6299, lr: 0.000200, spent: 4.7 secs\n",
      "epoch: 27, test auc: 0.922944\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 27, train loss: 1.6158, lr: 0.000200, spent: 4.8 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 28, test auc: 0.925541\n",
      "epoch: 28, train loss: 1.5971, lr: 0.000200, spent: 4.9 secs\n",
      "epoch: 29, test auc: 0.926407\n",
      "epoch: 29, train loss: 1.5771, lr: 0.000200, spent: 5.0 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 30, test auc: 0.927273\n",
      "epoch: 30, train loss: 1.5763, lr: 0.000200, spent: 5.2 secs\n",
      "epoch: 31, test auc: 0.933333\n",
      "epoch: 31, train loss: 1.6021, lr: 0.000200, spent: 5.3 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 32, test auc: 0.936797\n",
      "epoch: 32, train loss: 1.5513, lr: 0.000200, spent: 5.5 secs\n",
      "epoch: 33, test auc: 0.938528\n",
      "epoch: 33, train loss: 1.5160, lr: 0.000200, spent: 5.6 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 34, test auc: 0.938528\n",
      "epoch: 34, train loss: 1.5250, lr: 0.000200, spent: 5.8 secs\n",
      "epoch: 35, test auc: 0.938528\n",
      "epoch: 35, train loss: 1.4732, lr: 0.000200, spent: 6.0 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 36, test auc: 0.934199\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 36, train loss: 1.4738, lr: 0.000200, spent: 6.1 secs\n",
      "epoch: 37, test auc: 0.934199\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 37, train loss: 1.4667, lr: 0.000200, spent: 6.2 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 38, test auc: 0.933333\n",
      "EarlyStopping counter: 3 out of 5\n",
      "epoch: 38, train loss: 1.4209, lr: 0.000200, spent: 6.3 secs\n",
      "epoch: 39, test auc: 0.933333\n",
      "EarlyStopping counter: 4 out of 5\n",
      "epoch: 39, train loss: 1.4371, lr: 0.000200, spent: 6.4 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/outcome_predict/transtab/transtab/trainer.py:169: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only.\n",
      "  y_test = pd.concat(y_test, 0)\n",
      "2022-10-05 08:35:10.982 | INFO     | transtab.trainer:train:136 - load best at last from ./checkpoint\n",
      "2022-10-05 08:35:10.994 | INFO     | transtab.trainer:save_model:243 - saving model checkpoint to ./checkpoint\n",
      "2022-10-05 08:35:11.142 | INFO     | transtab.trainer:train:141 - training complete, cost 6.7 secs.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 40, test auc: 0.929870\n",
      "EarlyStopping counter: 5 out of 5\n",
      "early stopped\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "\n",
    "# Finetune on the second dataset (credit-approval). Early stopping now tracks\n",
    "# validation AUC, where larger is better.\n",
    "training_arguments = {\n",
    "    'num_epoch':50,\n",
    "    'eval_metric':'auc',\n",
    "    'eval_less_is_better':False,\n",
    "    'output_dir':'./checkpoint',\n",
    "    'batch_size':128,\n",
    "    'lr':2e-4,\n",
    "    }\n",
    "\n",
    "# transtab's trainer calls pd.concat(y_test, 0) with a positional axis, which makes\n",
    "# pandas emit a FutureWarning repeatedly during training and buries the epoch log.\n",
    "# Silence only that warning, and only for the duration of this call.\n",
    "with warnings.catch_warnings():\n",
    "    warnings.filterwarnings(\n",
    "        'ignore',\n",
    "        message='In a future version of pandas all arguments of concat',\n",
    "        category=FutureWarning,\n",
    "    )\n",
    "    transtab.train(model, trainset[1], valset[1], **training_arguments)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3aa87021",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "auc 0.95 mean/interval 0.8757(0.05)\n",
      "0.8807749627421758\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "# Evaluate the finetuned model on the credit-approval test split.\n",
    "x_test, y_test = testset[1]\n",
    "ypred = transtab.predict(model, x_test)\n",
    "\n",
    "# transtab's own AUC report (printed as a mean/interval summary)\n",
    "transtab.evaluate(ypred, y_test, metric='auc')\n",
    "\n",
    "# cross-check: a single scikit-learn AUC over the whole test split\n",
    "print(roc_auc_score(y_test, ypred))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('pytrial': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "2f00ab411e3cfe281b54106f98420bd06c3920b043d7b3741a63d2a4ac576305"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
