{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0c0001bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('../')\n",
    "\n",
    "import transtab\n",
    "\n",
    "# set random seed\n",
    "transtab.random_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "865b42a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "########################################\n",
      "openml data index: 31\n",
      "load data from credit-g\n",
      "# data: 1000, # feat: 20, # cate: 11,  # bin: 2, # numerical: 7, pos rate: 0.70\n",
      "########################################\n",
      "openml data index: 29\n",
      "load data from credit-approval\n",
      "# data: 690, # feat: 15, # cate: 9,  # bin: 0, # numerical: 6, pos rate: 0.56\n"
     ]
    }
   ],
   "source": [
    "# load multiple datasets by passing a list of data names\n",
    "allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \\\n",
    "    = transtab.load_data(['credit-g','credit-approval'])\n",
    "\n",
    "# build contrastive learner, set supervised=True for supervised VPCL\n",
    "model, collate_fn = transtab.build_contrastive_learner(\n",
    "    cat_cols, num_cols, bin_cols, \n",
    "    supervised=True, # if take supervised CL\n",
    "    num_partition=4, # num of column partitions for pos/neg sampling\n",
    "    overlap_ratio=0.5, # specify the overlap ratio of column partitions during the CL\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "78d0bc6c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1a6a12cd244e4672b360c68222c7b7f8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, test val_loss: 5.794664\n",
      "epoch: 0, train loss: 105.4182, lr: 0.000100, spent: 1.1 secs\n",
      "epoch: 1, test val_loss: 5.786065\n",
      "epoch: 1, train loss: 104.5511, lr: 0.000100, spent: 2.0 secs\n",
      "epoch: 2, test val_loss: 5.781867\n",
      "epoch: 2, train loss: 104.5076, lr: 0.000100, spent: 3.0 secs\n",
      "epoch: 3, test val_loss: 5.777907\n",
      "epoch: 3, train loss: 104.4728, lr: 0.000100, spent: 4.1 secs\n",
      "epoch: 4, test val_loss: 5.775703\n",
      "epoch: 4, train loss: 104.4284, lr: 0.000100, spent: 5.0 secs\n",
      "epoch: 5, test val_loss: 5.772933\n",
      "epoch: 5, train loss: 104.4126, lr: 0.000100, spent: 6.0 secs\n",
      "epoch: 6, test val_loss: 5.771537\n",
      "epoch: 6, train loss: 104.3681, lr: 0.000100, spent: 6.9 secs\n",
      "epoch: 7, test val_loss: 5.768374\n",
      "epoch: 7, train loss: 104.3112, lr: 0.000100, spent: 7.8 secs\n",
      "epoch: 8, test val_loss: 5.766492\n",
      "epoch: 8, train loss: 104.3186, lr: 0.000100, spent: 8.8 secs\n",
      "epoch: 9, test val_loss: 5.763317\n",
      "epoch: 9, train loss: 104.2437, lr: 0.000100, spent: 9.7 secs\n",
      "epoch: 10, test val_loss: 5.763273\n",
      "epoch: 10, train loss: 104.2665, lr: 0.000100, spent: 10.8 secs\n",
      "epoch: 11, test val_loss: 5.758865\n",
      "epoch: 11, train loss: 104.2031, lr: 0.000100, spent: 12.0 secs\n",
      "epoch: 12, test val_loss: 5.761363\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 12, train loss: 104.2412, lr: 0.000100, spent: 13.1 secs\n",
      "epoch: 13, test val_loss: 5.760094\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 13, train loss: 104.2192, lr: 0.000100, spent: 14.4 secs\n",
      "epoch: 14, test val_loss: 5.756854\n",
      "epoch: 14, train loss: 104.1880, lr: 0.000100, spent: 15.7 secs\n",
      "epoch: 15, test val_loss: 5.755385\n",
      "epoch: 15, train loss: 104.1087, lr: 0.000100, spent: 17.0 secs\n",
      "epoch: 16, test val_loss: 5.755942\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 16, train loss: 104.1531, lr: 0.000100, spent: 18.3 secs\n",
      "epoch: 17, test val_loss: 5.758205\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 17, train loss: 104.2000, lr: 0.000100, spent: 19.4 secs\n",
      "epoch: 18, test val_loss: 5.748805\n",
      "epoch: 18, train loss: 104.0332, lr: 0.000100, spent: 20.5 secs\n",
      "epoch: 19, test val_loss: 5.748421\n",
      "epoch: 19, train loss: 104.0516, lr: 0.000100, spent: 21.8 secs\n",
      "epoch: 20, test val_loss: 5.749574\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 20, train loss: 104.0346, lr: 0.000100, spent: 22.9 secs\n",
      "epoch: 21, test val_loss: 5.749054\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 21, train loss: 104.0557, lr: 0.000100, spent: 23.9 secs\n",
      "epoch: 22, test val_loss: 5.752270\n",
      "EarlyStopping counter: 3 out of 5\n",
      "epoch: 22, train loss: 104.0468, lr: 0.000100, spent: 25.1 secs\n",
      "epoch: 23, test val_loss: 5.749521\n",
      "EarlyStopping counter: 4 out of 5\n",
      "epoch: 23, train loss: 104.0925, lr: 0.000100, spent: 26.1 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-08-31 10:56:45.227 | INFO     | transtab.trainer:train:132 - load best at last from ./checkpoint\n",
      "2022-08-31 10:56:45.242 | INFO     | transtab.trainer:save_model:239 - saving model checkpoint to ./checkpoint\n",
      "2022-08-31 10:56:45.379 | INFO     | transtab.trainer:train:137 - training complete, cost 27.2 secs.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 24, test val_loss: 5.751015\n",
      "EarlyStopping counter: 5 out of 5\n",
      "early stopped\n"
     ]
    }
   ],
   "source": [
    "# start contrastive pretraining training\n",
    "training_arguments = {\n",
    "    'num_epoch':50,\n",
    "    'batch_size':64,\n",
    "    'lr':1e-4,\n",
    "    'eval_metric':'val_loss',\n",
    "    'eval_less_is_better':True,\n",
    "    'output_dir':'./checkpoint'\n",
    "    }\n",
    "\n",
    "transtab.train(model, trainset, valset, collate_fn=collate_fn, **training_arguments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "85e9ad3c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "########################################\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-08-31 10:56:48.450 | WARNING  | transtab.modeling_transtab:_check_column_overlap:254 - No cat/num/bin cols specified, will take ALL columns as categorical! Ignore this warning if you specify the `checkpoint` to load the model.\n",
      "2022-08-31 10:56:48.527 | INFO     | transtab.modeling_transtab:load:782 - missing keys: ['clf.fc.weight', 'clf.fc.bias', 'clf.norm.weight', 'clf.norm.bias']\n",
      "2022-08-31 10:56:48.528 | INFO     | transtab.modeling_transtab:load:783 - unexpected keys: ['projection_head.dense.weight']\n",
      "2022-08-31 10:56:48.528 | INFO     | transtab.modeling_transtab:load:784 - load model from ./checkpoint\n",
      "2022-08-31 10:56:48.542 | INFO     | transtab.modeling_transtab:load:222 - load feature extractor from ./checkpoint/extractor/extractor.json\n",
      "2022-08-31 10:56:48.556 | WARNING  | transtab.modeling_transtab:_check_column_overlap:254 - No cat/num/bin cols specified, will take ALL columns as categorical! Ignore this warning if you specify the `checkpoint` to load the model.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "openml data index: 29\n",
      "load data from credit-approval\n",
      "# data: 690, # feat: 15, # cate: 9,  # bin: 0, # numerical: 6, pos rate: 0.56\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8bc6cedea8c74fa0a79a6201160b8641",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, test val_loss: 0.683971\n",
      "epoch: 0, train loss: 5.4453, lr: 0.000100, spent: 0.3 secs\n",
      "epoch: 1, test val_loss: 0.646593\n",
      "epoch: 1, train loss: 5.2291, lr: 0.000100, spent: 0.6 secs\n",
      "epoch: 2, test val_loss: 0.598986\n",
      "epoch: 2, train loss: 4.9122, lr: 0.000100, spent: 0.8 secs\n",
      "epoch: 3, test val_loss: 0.571086\n",
      "epoch: 3, train loss: 4.6084, lr: 0.000100, spent: 1.1 secs\n",
      "epoch: 4, test val_loss: 0.500248\n",
      "epoch: 4, train loss: 4.2688, lr: 0.000100, spent: 1.3 secs\n",
      "epoch: 5, test val_loss: 0.461829\n",
      "epoch: 5, train loss: 3.8759, lr: 0.000100, spent: 1.6 secs\n",
      "epoch: 6, test val_loss: 0.418263\n",
      "epoch: 6, train loss: 3.5448, lr: 0.000100, spent: 1.9 secs\n",
      "epoch: 7, test val_loss: 0.406784\n",
      "epoch: 7, train loss: 3.3226, lr: 0.000100, spent: 2.2 secs\n",
      "epoch: 8, test val_loss: 0.415289\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 8, train loss: 3.2534, lr: 0.000100, spent: 2.5 secs\n",
      "epoch: 9, test val_loss: 0.395700\n",
      "epoch: 9, train loss: 3.1036, lr: 0.000100, spent: 2.7 secs\n",
      "epoch: 10, test val_loss: 0.477691\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 10, train loss: 2.9625, lr: 0.000100, spent: 3.2 secs\n",
      "epoch: 11, test val_loss: 0.394624\n",
      "epoch: 11, train loss: 2.9855, lr: 0.000100, spent: 3.5 secs\n",
      "epoch: 12, test val_loss: 0.395159\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 12, train loss: 3.0646, lr: 0.000100, spent: 3.7 secs\n",
      "epoch: 13, test val_loss: 0.520994\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 13, train loss: 3.0765, lr: 0.000100, spent: 4.0 secs\n",
      "epoch: 14, test val_loss: 0.388927\n",
      "epoch: 14, train loss: 3.0590, lr: 0.000100, spent: 4.3 secs\n",
      "epoch: 15, test val_loss: 0.447461\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 15, train loss: 2.8070, lr: 0.000100, spent: 4.5 secs\n",
      "epoch: 16, test val_loss: 0.402370\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 16, train loss: 2.6713, lr: 0.000100, spent: 4.7 secs\n",
      "epoch: 17, test val_loss: 0.393792\n",
      "EarlyStopping counter: 3 out of 5\n",
      "epoch: 17, train loss: 2.7131, lr: 0.000100, spent: 5.0 secs\n",
      "epoch: 18, test val_loss: 0.455256\n",
      "EarlyStopping counter: 4 out of 5\n",
      "epoch: 18, train loss: 2.7538, lr: 0.000100, spent: 5.2 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-08-31 10:56:53.974 | INFO     | transtab.trainer:train:132 - load best at last from ./checkpoint\n",
      "2022-08-31 10:56:54.000 | INFO     | transtab.trainer:save_model:239 - saving model checkpoint to ./checkpoint\n",
      "2022-08-31 10:56:54.130 | INFO     | transtab.trainer:train:137 - training complete, cost 5.6 secs.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 19, test val_loss: 0.406734\n",
      "EarlyStopping counter: 5 out of 5\n",
      "early stopped\n"
     ]
    }
   ],
   "source": [
    "# load the pretrained model and finetune on a target dataset\n",
    "allset, trainset, valset, testset, cat_cols, num_cols, bin_cols \\\n",
    "     = transtab.load_data('credit-approval')\n",
    "\n",
    "# build transtab classifier model, and load from the pretrained dir\n",
    "model = transtab.build_classifier(checkpoint='./checkpoint')\n",
    "\n",
    "# update model's categorical/numerical/binary column dict\n",
    "model.update({'cat':cat_cols,'num':num_cols,'bin':bin_cols})\n",
    "\n",
    "# start finetuning\n",
    "training_arguments = {\n",
    "    'num_epoch':50,\n",
    "    'eval_metric':'val_loss',\n",
    "    'eval_less_is_better':True,\n",
    "    'output_dir':'./checkpoint'\n",
    "    }\n",
    "transtab.train(model, trainset, valset, **training_arguments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ba5e5238",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "auc 0.95 mean/interval 0.8382(0.06)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.8382272091644043]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# evaluation\n",
    "x_test, y_test = testset\n",
    "ypred = transtab.predict(model, x_test)\n",
    "transtab.evaluate(ypred, y_test, metric='auc')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da5d6d70",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
