{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9aa34ef4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Move to the parent directory (presumably the repo root, so the local\n",
    "# `transtab` package is importable; './checkpoint' below also resolves there).\n",
    "# Remember the notebook's own directory on the first run so that re-running\n",
    "# this cell does not keep walking further up the directory tree.\n",
    "NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', os.getcwd())\n",
    "os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))\n",
    "\n",
    "import transtab\n",
    "\n",
    "# set random seed\n",
    "transtab.random_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ce7052e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "########################################\n",
      "openml data index: 31\n",
      "load data from credit-g\n",
      "# data: 1000, # feat: 20, # cate: 11,  # bin: 2, # numerical: 7, pos rate: 0.70\n"
     ]
    }
   ],
   "source": [
    "# Load the credit-g dataset, unpacking the full set, the train/val/test\n",
    "# splits, and the categorical/numerical/binary column-name lists.\n",
    "(allset, trainset, valset, testset,\n",
    " cat_cols, num_cols, bin_cols) = transtab.load_data('credit-g')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4e709521",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0740c9e1a09844238618d786a971d916",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, test val_loss: 6.349929\n",
      "epoch: 0, train loss: 72.9975, lr: 0.000100, spent: 1.3 secs\n",
      "epoch: 1, test val_loss: 6.043663\n",
      "epoch: 1, train loss: 62.8806, lr: 0.000100, spent: 2.2 secs\n",
      "epoch: 2, test val_loss: 5.999826\n",
      "epoch: 2, train loss: 61.3078, lr: 0.000100, spent: 3.0 secs\n",
      "epoch: 3, test val_loss: 5.989734\n",
      "epoch: 3, train loss: 61.0470, lr: 0.000100, spent: 3.9 secs\n",
      "epoch: 4, test val_loss: 5.986117\n",
      "epoch: 4, train loss: 60.9742, lr: 0.000100, spent: 4.8 secs\n",
      "epoch: 5, test val_loss: 5.984314\n",
      "epoch: 5, train loss: 60.9454, lr: 0.000100, spent: 5.8 secs\n",
      "epoch: 6, test val_loss: 5.983197\n",
      "epoch: 6, train loss: 60.9270, lr: 0.000100, spent: 6.7 secs\n",
      "epoch: 7, test val_loss: 5.982450\n",
      "epoch: 7, train loss: 60.9164, lr: 0.000100, spent: 7.6 secs\n",
      "epoch: 8, test val_loss: 5.981885\n",
      "epoch: 8, train loss: 60.9102, lr: 0.000100, spent: 8.5 secs\n",
      "epoch: 9, test val_loss: 5.981443\n",
      "epoch: 9, train loss: 60.9047, lr: 0.000100, spent: 9.5 secs\n",
      "epoch: 10, test val_loss: 5.981087\n",
      "epoch: 10, train loss: 60.9004, lr: 0.000100, spent: 10.3 secs\n",
      "epoch: 11, test val_loss: 5.980795\n",
      "epoch: 11, train loss: 60.8956, lr: 0.000100, spent: 11.3 secs\n",
      "epoch: 12, test val_loss: 5.980557\n",
      "epoch: 12, train loss: 60.8925, lr: 0.000100, spent: 12.3 secs\n",
      "epoch: 13, test val_loss: 5.980357\n",
      "epoch: 13, train loss: 60.8902, lr: 0.000100, spent: 13.3 secs\n",
      "epoch: 14, test val_loss: 5.980191\n",
      "epoch: 14, train loss: 60.8874, lr: 0.000100, spent: 14.5 secs\n",
      "epoch: 15, test val_loss: 5.980050\n",
      "epoch: 15, train loss: 60.8863, lr: 0.000100, spent: 15.5 secs\n",
      "epoch: 16, test val_loss: 5.979930\n",
      "epoch: 16, train loss: 60.8836, lr: 0.000100, spent: 16.4 secs\n",
      "epoch: 17, test val_loss: 5.979825\n",
      "epoch: 17, train loss: 60.8822, lr: 0.000100, spent: 17.3 secs\n",
      "epoch: 18, test val_loss: 5.979736\n",
      "epoch: 18, train loss: 60.8821, lr: 0.000100, spent: 18.2 secs\n",
      "epoch: 19, test val_loss: 5.979657\n",
      "epoch: 19, train loss: 60.8804, lr: 0.000100, spent: 19.2 secs\n",
      "epoch: 20, test val_loss: 5.979586\n",
      "epoch: 20, train loss: 60.8802, lr: 0.000100, spent: 20.3 secs\n",
      "epoch: 21, test val_loss: 5.979523\n",
      "epoch: 21, train loss: 60.8798, lr: 0.000100, spent: 21.3 secs\n",
      "epoch: 22, test val_loss: 5.979466\n",
      "epoch: 22, train loss: 60.8791, lr: 0.000100, spent: 22.2 secs\n",
      "epoch: 23, test val_loss: 5.979416\n",
      "epoch: 23, train loss: 60.8778, lr: 0.000100, spent: 23.2 secs\n",
      "epoch: 24, test val_loss: 5.979372\n",
      "epoch: 24, train loss: 60.8776, lr: 0.000100, spent: 24.2 secs\n",
      "epoch: 25, test val_loss: 5.979331\n",
      "epoch: 25, train loss: 60.8773, lr: 0.000100, spent: 25.1 secs\n",
      "epoch: 26, test val_loss: 5.979294\n",
      "epoch: 26, train loss: 60.8763, lr: 0.000100, spent: 26.0 secs\n",
      "epoch: 27, test val_loss: 5.979260\n",
      "epoch: 27, train loss: 60.8761, lr: 0.000100, spent: 27.0 secs\n",
      "epoch: 28, test val_loss: 5.979229\n",
      "epoch: 28, train loss: 60.8761, lr: 0.000100, spent: 27.9 secs\n",
      "epoch: 29, test val_loss: 5.979202\n",
      "epoch: 29, train loss: 60.8752, lr: 0.000100, spent: 28.9 secs\n",
      "epoch: 30, test val_loss: 5.979175\n",
      "epoch: 30, train loss: 60.8755, lr: 0.000100, spent: 29.8 secs\n",
      "epoch: 31, test val_loss: 5.979153\n",
      "epoch: 31, train loss: 60.8744, lr: 0.000100, spent: 30.8 secs\n",
      "epoch: 32, test val_loss: 5.979130\n",
      "epoch: 32, train loss: 60.8744, lr: 0.000100, spent: 31.6 secs\n",
      "epoch: 33, test val_loss: 5.979110\n",
      "epoch: 33, train loss: 60.8743, lr: 0.000100, spent: 32.4 secs\n",
      "epoch: 34, test val_loss: 5.979090\n",
      "epoch: 34, train loss: 60.8736, lr: 0.000100, spent: 33.4 secs\n",
      "epoch: 35, test val_loss: 5.979072\n",
      "epoch: 35, train loss: 60.8720, lr: 0.000100, spent: 34.3 secs\n",
      "epoch: 36, test val_loss: 5.979054\n",
      "epoch: 36, train loss: 60.8724, lr: 0.000100, spent: 35.2 secs\n",
      "epoch: 37, test val_loss: 5.979037\n",
      "epoch: 37, train loss: 60.8735, lr: 0.000100, spent: 36.2 secs\n",
      "epoch: 38, test val_loss: 5.979021\n",
      "epoch: 38, train loss: 60.8723, lr: 0.000100, spent: 36.9 secs\n",
      "epoch: 39, test val_loss: 5.979005\n",
      "epoch: 39, train loss: 60.8726, lr: 0.000100, spent: 37.8 secs\n",
      "epoch: 40, test val_loss: 5.978991\n",
      "epoch: 40, train loss: 60.8719, lr: 0.000100, spent: 38.5 secs\n",
      "epoch: 41, test val_loss: 5.978974\n",
      "epoch: 41, train loss: 60.8720, lr: 0.000100, spent: 39.3 secs\n",
      "epoch: 42, test val_loss: 5.978961\n",
      "epoch: 42, train loss: 60.8717, lr: 0.000100, spent: 40.1 secs\n",
      "epoch: 43, test val_loss: 5.978946\n",
      "epoch: 43, train loss: 60.8721, lr: 0.000100, spent: 40.9 secs\n",
      "epoch: 44, test val_loss: 5.978931\n",
      "epoch: 44, train loss: 60.8710, lr: 0.000100, spent: 41.8 secs\n",
      "epoch: 45, test val_loss: 5.978916\n",
      "epoch: 45, train loss: 60.8711, lr: 0.000100, spent: 42.7 secs\n",
      "epoch: 46, test val_loss: 5.978899\n",
      "epoch: 46, train loss: 60.8713, lr: 0.000100, spent: 43.6 secs\n",
      "epoch: 47, test val_loss: 5.978884\n",
      "epoch: 47, train loss: 60.8702, lr: 0.000100, spent: 44.6 secs\n",
      "epoch: 48, test val_loss: 5.978869\n",
      "epoch: 48, train loss: 60.8705, lr: 0.000100, spent: 45.7 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-08-31 14:15:16.839 | INFO     | transtab.trainer:train:132 - load best at last from ./checkpoint\n",
      "2022-08-31 14:15:16.853 | INFO     | transtab.trainer:save_model:239 - saving model checkpoint to ./checkpoint\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 49, test val_loss: 5.978854\n",
      "epoch: 49, train loss: 60.8699, lr: 0.000100, spent: 46.8 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-08-31 14:15:17.035 | INFO     | transtab.trainer:train:137 - training complete, cost 47.0 secs.\n"
     ]
    }
   ],
   "source": [
    "# Fast pre-training of a TransTab contrastive learning model.\n",
    "# Build the contrastive learner; supervised=True selects supervised VPCL.\n",
    "model, collate_fn = transtab.build_contrastive_learner(\n",
    "    cat_cols, num_cols, bin_cols,\n",
    "    supervised=True,    # use supervised contrastive learning (VPCL)\n",
    "    num_partition=4,    # number of column partitions for pos/neg sampling\n",
    "    overlap_ratio=0.5,  # overlap ratio between column partitions during CL\n",
    ")\n",
    "\n",
    "# Contrastive pre-training; the checkpoint written to output_dir is what the\n",
    "# encoder cells below load from.\n",
    "training_arguments = {\n",
    "    'num_epoch': 50,\n",
    "    'batch_size': 64,\n",
    "    'lr': 1e-4,\n",
    "    'eval_metric': 'val_loss',\n",
    "    'eval_less_is_better': True,\n",
    "    'output_dir': './checkpoint',\n",
    "}\n",
    "\n",
    "transtab.train(model, trainset, valset, collate_fn=collate_fn, **training_arguments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5c87e48b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-08-31 14:15:17.125 | INFO     | transtab.modeling_transtab:load:773 - missing keys: []\n",
      "2022-08-31 14:15:17.126 | INFO     | transtab.modeling_transtab:load:774 - unexpected keys: ['projection_head.dense.weight']\n",
      "2022-08-31 14:15:17.126 | INFO     | transtab.modeling_transtab:load:775 - load model from ./checkpoint\n",
      "2022-08-31 14:15:17.159 | INFO     | transtab.modeling_transtab:load:222 - load feature extractor from ./checkpoint/extractor/extractor.json\n"
     ]
    }
   ],
   "source": [
    "# There are two ways to build the encoder.\n",
    "# First: load the whole pretrained model; the encoder outputs the CLS token\n",
    "# embedding taken from the last layer's outputs.\n",
    "enc = transtab.build_encoder(\n",
    "    binary_columns=bin_cols,\n",
    "    checkpoint='./checkpoint',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b8149cfa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([700, 128])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.2959e+00,  1.5239e+00, -1.2096e+00,  3.0303e-01,  7.4638e-01,\n",
       "          1.1758e+00,  1.1774e+00, -2.1921e-01,  4.2850e-01,  8.3295e-03,\n",
       "         -5.3477e-01,  1.4859e+00, -2.0534e+00, -9.4093e-01,  3.7010e-01,\n",
       "          1.3663e-01,  4.4837e-01,  1.3882e+00,  1.6472e+00, -1.2430e+00,\n",
       "         -4.8809e-01, -5.1914e-01, -3.3168e-01,  1.9889e+00, -4.9873e-01,\n",
       "          1.2286e+00,  8.6373e-01,  5.1300e-01,  6.7551e-01, -1.2021e+00,\n",
       "          6.3210e-01,  6.2366e-01,  5.6712e-01,  1.2275e-03, -1.5154e+00,\n",
       "          2.0082e+00, -1.2255e+00, -2.4254e-01, -5.1009e-01,  1.6733e+00,\n",
       "         -1.2059e+00, -7.0246e-01,  1.8980e-01, -7.8196e-01,  1.0777e+00,\n",
       "         -6.1830e-01, -1.1279e+00, -1.3290e+00,  9.6929e-01, -7.6388e-02,\n",
       "         -4.5835e-01, -1.1462e+00,  1.5084e+00,  5.7778e-01,  2.0644e-01,\n",
       "          4.3633e-01,  7.6116e-03,  5.2441e-01, -1.9919e-01, -1.9441e-01,\n",
       "          1.8144e+00,  2.7863e-01, -1.8727e+00, -9.4760e-01,  1.1152e+00,\n",
       "          3.5514e-01,  1.6321e+00,  4.3554e-01,  6.1438e-01,  2.2991e-01,\n",
       "          2.3567e-01,  1.0738e+00, -1.0689e+00,  1.1454e+00, -2.9430e-01,\n",
       "         -7.8866e-01,  1.7377e-01,  4.7786e-01, -1.1535e+00, -1.9210e+00,\n",
       "          5.6469e-01, -4.9142e-02, -6.4016e-01, -3.3013e-01, -3.1188e-01,\n",
       "         -7.4673e-01, -3.0021e-01, -2.0609e+00,  7.0935e-01, -6.6764e-01,\n",
       "          6.4810e-01, -8.1043e-02, -1.0044e+00, -2.1534e+00, -1.4149e+00,\n",
       "         -7.6418e-01,  1.9660e+00, -1.0766e+00, -5.2616e-01, -1.2752e+00,\n",
       "          1.1527e+00,  2.2518e-01,  1.7696e-01,  8.3931e-01, -3.5717e-01,\n",
       "          1.4251e-01,  1.6778e+00, -1.5331e+00, -1.5316e+00, -7.3143e-01,\n",
       "         -2.6362e-01, -5.3092e-01,  1.1220e+00,  9.4099e-01, -1.3653e+00,\n",
       "         -5.5385e-01, -2.5665e-01, -3.1621e-01, -1.3123e+00, -9.7127e-02,\n",
       "         -4.2603e-01,  1.8091e+00, -7.5452e-01,  1.9514e+00,  7.2433e-03,\n",
       "          3.7320e-02,  5.3549e-01, -3.9535e-01],\n",
       "        [ 1.4275e+00,  1.4772e+00, -1.1928e+00,  1.8642e-01,  8.1510e-01,\n",
       "          1.2602e+00,  1.2150e+00, -2.1353e-01,  3.9298e-01, -1.8265e-01,\n",
       "         -5.9739e-01,  1.2885e+00, -2.1044e+00, -1.0534e+00,  4.8087e-01,\n",
       "          1.2070e-01,  3.0839e-01,  1.2873e+00,  1.6255e+00, -1.0916e+00,\n",
       "         -3.2920e-01, -2.7017e-01, -3.4054e-01,  2.0612e+00, -6.5718e-01,\n",
       "          1.1547e+00,  9.0340e-01,  5.3138e-01,  7.4846e-01, -1.1599e+00,\n",
       "          6.1057e-01,  6.2320e-01,  6.3401e-01, -7.8121e-02, -1.5336e+00,\n",
       "          1.8799e+00, -1.4002e+00, -3.4578e-01, -8.7409e-01,  1.7005e+00,\n",
       "         -1.2923e+00, -5.9172e-01,  8.2113e-02, -7.6255e-01,  9.8186e-01,\n",
       "         -5.2740e-01, -1.1055e+00, -1.3655e+00,  8.0880e-01,  6.8788e-02,\n",
       "         -5.1715e-01, -1.2682e+00,  1.6060e+00,  5.9163e-01,  3.5197e-01,\n",
       "          6.1037e-01,  1.6449e-01,  4.7828e-01, -2.3575e-01, -2.4127e-01,\n",
       "          1.8397e+00,  3.7601e-01, -1.9676e+00, -9.4222e-01,  1.1711e+00,\n",
       "          3.2122e-01,  1.7164e+00,  4.7828e-01,  7.2740e-01,  2.1730e-01,\n",
       "          2.0191e-01,  7.4816e-01, -1.1957e+00,  1.2826e+00, -3.4407e-01,\n",
       "         -8.6727e-01,  1.4943e-01,  5.4311e-01, -1.1209e+00, -1.8852e+00,\n",
       "          5.8967e-01, -2.3814e-01, -6.1390e-01, -2.7548e-01, -2.5533e-01,\n",
       "         -8.5195e-01, -2.3613e-01, -1.9835e+00,  5.6644e-01, -5.9843e-01,\n",
       "          6.8693e-01,  3.4524e-02, -1.0214e+00, -1.8806e+00, -1.4108e+00,\n",
       "         -7.1087e-01,  1.9959e+00, -1.2109e+00, -6.3984e-01, -9.7635e-01,\n",
       "          1.1544e+00,  2.3031e-01,  2.3562e-01,  6.8024e-01, -2.9665e-01,\n",
       "          1.2141e-01,  1.7590e+00, -1.4833e+00, -1.4007e+00, -9.1892e-01,\n",
       "         -1.3863e-01, -3.3393e-01,  1.0803e+00,  1.0124e+00, -1.4227e+00,\n",
       "         -6.2524e-01, -1.6816e-01, -4.6652e-01, -1.3414e+00, -1.7069e-01,\n",
       "         -2.8513e-01,  1.7853e+00, -9.1653e-01,  1.7702e+00,  2.3768e-01,\n",
       "          9.3338e-02,  5.9862e-01, -3.1038e-01]], device='cuda:0',\n",
       "       grad_fn=<SliceBackward0>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Run the encoder on the training features to get one embedding per row.\n",
    "# trainset[0] is the feature DataFrame (previewed by df.head() below).\n",
    "# Inference only: no_grad avoids building and retaining an autograd graph.\n",
    "df = trainset[0]\n",
    "with torch.no_grad():\n",
    "    output = enc(df)\n",
    "print(output.shape)\n",
    "output[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4aadae44",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>own_telephone</th>\n",
       "      <th>foreign_worker</th>\n",
       "      <th>duration</th>\n",
       "      <th>credit_amount</th>\n",
       "      <th>installment_commitment</th>\n",
       "      <th>residence_since</th>\n",
       "      <th>age</th>\n",
       "      <th>existing_credits</th>\n",
       "      <th>num_dependents</th>\n",
       "      <th>checking_status</th>\n",
       "      <th>credit_history</th>\n",
       "      <th>purpose</th>\n",
       "      <th>savings_status</th>\n",
       "      <th>employment</th>\n",
       "      <th>personal_status</th>\n",
       "      <th>other_parties</th>\n",
       "      <th>property_magnitude</th>\n",
       "      <th>other_payment_plans</th>\n",
       "      <th>housing</th>\n",
       "      <th>job</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>636</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.294118</td>\n",
       "      <td>0.061957</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.160714</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>no checking</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>500&lt;=X&lt;1000</td>\n",
       "      <td>4&lt;=X&lt;7</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>none</td>\n",
       "      <td>car</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>182</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.250000</td>\n",
       "      <td>0.076868</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.375000</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>1.0</td>\n",
       "      <td>&lt;0</td>\n",
       "      <td>all paid</td>\n",
       "      <td>new car</td>\n",
       "      <td>no known savings</td>\n",
       "      <td>1&lt;=X&lt;4</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>life insurance</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>unskilled resident</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>736</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.294118</td>\n",
       "      <td>0.622318</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.071429</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0&lt;=X&lt;200</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>used car</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>1&lt;=X&lt;4</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>none</td>\n",
       "      <td>car</td>\n",
       "      <td>none</td>\n",
       "      <td>rent</td>\n",
       "      <td>high qualif/self emp/mgmt</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>922</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.073529</td>\n",
       "      <td>0.061406</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.053571</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>&lt;0</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>&lt;1</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>none</td>\n",
       "      <td>life insurance</td>\n",
       "      <td>none</td>\n",
       "      <td>rent</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>511</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.470588</td>\n",
       "      <td>0.244085</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.232143</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>no checking</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>used car</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>1&lt;=X&lt;4</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>no known property</td>\n",
       "      <td>none</td>\n",
       "      <td>for free</td>\n",
       "      <td>high qualif/self emp/mgmt</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     own_telephone  foreign_worker  duration  credit_amount  \\\n",
       "636              0               1  0.294118       0.061957   \n",
       "182              0               1  0.250000       0.076868   \n",
       "736              0               1  0.294118       0.622318   \n",
       "922              0               1  0.073529       0.061406   \n",
       "511              1               1  0.470588       0.244085   \n",
       "\n",
       "     installment_commitment  residence_since       age  existing_credits  \\\n",
       "636                1.000000         0.000000  0.160714          0.000000   \n",
       "182                1.000000         0.333333  0.375000          0.333333   \n",
       "736                0.000000         1.000000  0.071429          0.333333   \n",
       "922                0.666667         1.000000  0.053571          0.000000   \n",
       "511                0.333333         0.333333  0.232143          0.000000   \n",
       "\n",
       "     num_dependents checking_status credit_history   purpose  \\\n",
       "636             0.0     no checking  existing paid  radio/tv   \n",
       "182             1.0              <0       all paid   new car   \n",
       "736             0.0        0<=X<200  existing paid  used car   \n",
       "922             0.0              <0  existing paid  radio/tv   \n",
       "511             0.0     no checking  existing paid  used car   \n",
       "\n",
       "       savings_status employment     personal_status other_parties  \\\n",
       "636       500<=X<1000     4<=X<7  female div/dep/mar          none   \n",
       "182  no known savings     1<=X<4         male single          none   \n",
       "736              <100     1<=X<4  female div/dep/mar          none   \n",
       "922              <100         <1  female div/dep/mar          none   \n",
       "511              <100     1<=X<4         male single          none   \n",
       "\n",
       "    property_magnitude other_payment_plans   housing  \\\n",
       "636                car                none       own   \n",
       "182     life insurance                none       own   \n",
       "736                car                none      rent   \n",
       "922     life insurance                none      rent   \n",
       "511  no known property                none  for free   \n",
       "\n",
       "                           job  \n",
       "636                    skilled  \n",
       "182         unskilled resident  \n",
       "736  high qualif/self emp/mgmt  \n",
       "922                    skilled  \n",
       "511  high qualif/self emp/mgmt  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# preview the first rows of the feature DataFrame passed to the encoder\n",
    "df.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4f3e1e91",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-08-31 14:16:28.124 | INFO     | transtab.modeling_transtab:load:222 - load feature extractor from ./checkpoint/extractor/extractor.json\n",
      "2022-08-31 14:16:28.134 | INFO     | transtab.modeling_transtab:load:523 - missing keys: []\n",
      "2022-08-31 14:16:28.135 | INFO     | transtab.modeling_transtab:load:524 - unexpected keys: []\n",
      "2022-08-31 14:16:28.136 | INFO     | transtab.modeling_transtab:load:525 - load model from ./checkpoint\n"
     ]
    }
   ],
   "source": [
    "# Second: if we only want the token-level embeddings (the embeddings before\n",
    "# they enter the transformer layers), build the encoder with num_layer=0.\n",
    "enc = transtab.build_encoder(\n",
    "    binary_columns=bin_cols,\n",
    "    checkpoint='./checkpoint',\n",
    "    num_layer=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "39a0172b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([700, 85, 128])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[[ 0.1370,  0.0427, -0.0106,  ..., -0.0806,  0.0518, -0.1315],\n",
       "         [ 0.0657,  0.0341, -0.0128,  ..., -0.0207,  0.0102, -0.0046],\n",
       "         [ 0.1494,  0.4290,  0.2463,  ...,  0.1992, -0.0848, -0.0840],\n",
       "         ...,\n",
       "         [ 1.1575,  0.0165,  0.9202,  ..., -0.2052,  1.0815, -1.0268],\n",
       "         [ 1.1575,  0.0165,  0.9202,  ..., -0.2052,  1.0815, -1.0268],\n",
       "         [ 1.1575,  0.0165,  0.9202,  ..., -0.2052,  1.0815, -1.0268]],\n",
       "\n",
       "        [[ 0.1204,  0.0388, -0.0098,  ..., -0.0738,  0.0400, -0.1099],\n",
       "         [ 0.0752,  0.0383, -0.0145,  ..., -0.0174,  0.0190, -0.0085],\n",
       "         [ 0.1494,  0.4290,  0.2463,  ...,  0.1992, -0.0848, -0.0840],\n",
       "         ...,\n",
       "         [ 1.1575,  0.0165,  0.9202,  ..., -0.2052,  1.0815, -1.0268],\n",
       "         [ 1.1575,  0.0165,  0.9202,  ..., -0.2052,  1.0815, -1.0268],\n",
       "         [ 1.1575,  0.0165,  0.9202,  ..., -0.2052,  1.0815, -1.0268]]],\n",
       "       device='cuda:0', grad_fn=<SliceBackward0>)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# With num_layer=0 the encoder returns a dict-like result; 'embedding' holds\n",
    "# the token-level embeddings for each row. Inference only, so use no_grad.\n",
    "with torch.no_grad():\n",
    "    output = enc(df)\n",
    "print(output['embedding'].shape)\n",
    "output['embedding'][:2]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
