{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0bc8ef17",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Move to the parent directory (presumably the repo root) so that `import transtab`\n",
    "# and relative paths such as './checkpoint' resolve from there.\n",
    "# Guarded so that re-running this cell in the same kernel does not climb another level.\n",
    "if '_NOTEBOOK_DIR' not in globals():\n",
    "    _NOTEBOOK_DIR = os.getcwd()  # remember where the kernel started\n",
    "    os.chdir('../')\n",
    "\n",
    "# imported after the chdir on purpose: it may be resolved from the parent directory\n",
    "import transtab\n",
    "\n",
    "# set random seed for reproducible splits / initialisation\n",
    "transtab.random_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e06b2eb3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "########################################\n",
      "openml data index: 31\n",
      "load data from credit-g\n",
      "# data: 1000, # feat: 20, # cate: 11,  # bin: 2, # numerical: 7, pos rate: 0.70\n",
      "########################################\n",
      "openml data index: 29\n",
      "load data from credit-approval\n",
      "# data: 690, # feat: 15, # cate: 9,  # bin: 0, # numerical: 6, pos rate: 0.56\n"
     ]
    }
   ],
   "source": [
    "# Load several datasets in one call by passing a list of dataset names;\n",
    "# the categorical / numerical / binary column lists come back alongside the splits.\n",
    "(allset, trainset, valset, testset,\n",
    " cat_cols, num_cols, bin_cols) = transtab.load_data(['credit-g', 'credit-approval'])\n",
    "\n",
    "# Build the TransTab classifier from the column-type lists returned above.\n",
    "model = transtab.build_classifier(cat_cols, num_cols, bin_cols)\n",
    "\n",
    "# Training configuration: early stopping monitors validation loss (lower is better).\n",
    "training_arguments = dict(\n",
    "    num_epoch=50,\n",
    "    batch_size=128,\n",
    "    lr=1e-4,\n",
    "    eval_metric='val_loss',\n",
    "    eval_less_is_better=True,\n",
    "    output_dir='./checkpoint',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f0c84e5f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>own_telephone</th>\n",
       "      <th>foreign_worker</th>\n",
       "      <th>duration</th>\n",
       "      <th>credit_amount</th>\n",
       "      <th>installment_commitment</th>\n",
       "      <th>residence_since</th>\n",
       "      <th>age</th>\n",
       "      <th>existing_credits</th>\n",
       "      <th>num_dependents</th>\n",
       "      <th>checking_status</th>\n",
       "      <th>credit_history</th>\n",
       "      <th>purpose</th>\n",
       "      <th>savings_status</th>\n",
       "      <th>employment</th>\n",
       "      <th>personal_status</th>\n",
       "      <th>other_parties</th>\n",
       "      <th>property_magnitude</th>\n",
       "      <th>other_payment_plans</th>\n",
       "      <th>housing</th>\n",
       "      <th>job</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>636</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.294118</td>\n",
       "      <td>0.061957</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.160714</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>no checking</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>500&lt;=X&lt;1000</td>\n",
       "      <td>4&lt;=X&lt;7</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>none</td>\n",
       "      <td>car</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>182</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.250000</td>\n",
       "      <td>0.076868</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.375000</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>1.0</td>\n",
       "      <td>&lt;0</td>\n",
       "      <td>all paid</td>\n",
       "      <td>new car</td>\n",
       "      <td>no known savings</td>\n",
       "      <td>1&lt;=X&lt;4</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>life insurance</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>unskilled resident</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>736</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.294118</td>\n",
       "      <td>0.622318</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.071429</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0&lt;=X&lt;200</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>used car</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>1&lt;=X&lt;4</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>none</td>\n",
       "      <td>car</td>\n",
       "      <td>none</td>\n",
       "      <td>rent</td>\n",
       "      <td>high qualif/self emp/mgmt</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>922</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.073529</td>\n",
       "      <td>0.061406</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.053571</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>&lt;0</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>&lt;1</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>none</td>\n",
       "      <td>life insurance</td>\n",
       "      <td>none</td>\n",
       "      <td>rent</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>511</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.470588</td>\n",
       "      <td>0.244085</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.232143</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>no checking</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>used car</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>1&lt;=X&lt;4</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>no known property</td>\n",
       "      <td>none</td>\n",
       "      <td>for free</td>\n",
       "      <td>high qualif/self emp/mgmt</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>845</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.250000</td>\n",
       "      <td>0.205018</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>0.285714</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0&lt;=X&lt;200</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>furniture/equipment</td>\n",
       "      <td>no known savings</td>\n",
       "      <td>4&lt;=X&lt;7</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>car</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>492</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.029412</td>\n",
       "      <td>0.054308</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.142857</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.0</td>\n",
       "      <td>no checking</td>\n",
       "      <td>critical/other existing credit</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>100&lt;=X&lt;500</td>\n",
       "      <td>1&lt;=X&lt;4</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>none</td>\n",
       "      <td>life insurance</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>849</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.117647</td>\n",
       "      <td>0.025256</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.678571</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>&lt;0</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>&gt;=7</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>real estate</td>\n",
       "      <td>stores</td>\n",
       "      <td>own</td>\n",
       "      <td>unskilled resident</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>297</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.088235</td>\n",
       "      <td>0.057060</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.464286</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>no checking</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>new car</td>\n",
       "      <td>no known savings</td>\n",
       "      <td>&gt;=7</td>\n",
       "      <td>male single</td>\n",
       "      <td>co applicant</td>\n",
       "      <td>life insurance</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>unskilled resident</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.470588</td>\n",
       "      <td>0.114834</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.303571</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0&lt;=X&lt;200</td>\n",
       "      <td>critical/other existing credit</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>&gt;=7</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>real estate</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>700 rows × 20 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     own_telephone  foreign_worker  duration  credit_amount  \\\n",
       "636              0               1  0.294118       0.061957   \n",
       "182              0               1  0.250000       0.076868   \n",
       "736              0               1  0.294118       0.622318   \n",
       "922              0               1  0.073529       0.061406   \n",
       "511              1               1  0.470588       0.244085   \n",
       "..             ...             ...       ...            ...   \n",
       "845              1               1  0.250000       0.205018   \n",
       "492              0               1  0.029412       0.054308   \n",
       "849              0               1  0.117647       0.025256   \n",
       "297              0               0  0.088235       0.057060   \n",
       "98               0               1  0.470588       0.114834   \n",
       "\n",
       "     installment_commitment  residence_since       age  existing_credits  \\\n",
       "636                1.000000         0.000000  0.160714          0.000000   \n",
       "182                1.000000         0.333333  0.375000          0.333333   \n",
       "736                0.000000         1.000000  0.071429          0.333333   \n",
       "922                0.666667         1.000000  0.053571          0.000000   \n",
       "511                0.333333         0.333333  0.232143          0.000000   \n",
       "..                      ...              ...       ...               ...   \n",
       "845                0.333333         0.666667  0.285714          0.000000   \n",
       "492                0.000000         0.000000  0.142857          0.333333   \n",
       "849                1.000000         1.000000  0.678571          0.000000   \n",
       "297                1.000000         0.333333  0.464286          0.000000   \n",
       "98                 1.000000         1.000000  0.303571          0.000000   \n",
       "\n",
       "     num_dependents checking_status                  credit_history  \\\n",
       "636             0.0     no checking                   existing paid   \n",
       "182             1.0              <0                        all paid   \n",
       "736             0.0        0<=X<200                   existing paid   \n",
       "922             0.0              <0                   existing paid   \n",
       "511             0.0     no checking                   existing paid   \n",
       "..              ...             ...                             ...   \n",
       "845             0.0        0<=X<200                   existing paid   \n",
       "492             0.0     no checking  critical/other existing credit   \n",
       "849             0.0              <0                   existing paid   \n",
       "297             0.0     no checking                   existing paid   \n",
       "98              0.0        0<=X<200  critical/other existing credit   \n",
       "\n",
       "                 purpose    savings_status employment     personal_status  \\\n",
       "636             radio/tv       500<=X<1000     4<=X<7  female div/dep/mar   \n",
       "182              new car  no known savings     1<=X<4         male single   \n",
       "736             used car              <100     1<=X<4  female div/dep/mar   \n",
       "922             radio/tv              <100         <1  female div/dep/mar   \n",
       "511             used car              <100     1<=X<4         male single   \n",
       "..                   ...               ...        ...                 ...   \n",
       "845  furniture/equipment  no known savings     4<=X<7         male single   \n",
       "492             radio/tv        100<=X<500     1<=X<4  female div/dep/mar   \n",
       "849             radio/tv              <100        >=7         male single   \n",
       "297              new car  no known savings        >=7         male single   \n",
       "98              radio/tv              <100        >=7         male single   \n",
       "\n",
       "    other_parties property_magnitude other_payment_plans   housing  \\\n",
       "636          none                car                none       own   \n",
       "182          none     life insurance                none       own   \n",
       "736          none                car                none      rent   \n",
       "922          none     life insurance                none      rent   \n",
       "511          none  no known property                none  for free   \n",
       "..            ...                ...                 ...       ...   \n",
       "845          none                car                none       own   \n",
       "492          none     life insurance                none       own   \n",
       "849          none        real estate              stores       own   \n",
       "297  co applicant     life insurance                none       own   \n",
       "98           none        real estate                none       own   \n",
       "\n",
       "                           job  \n",
       "636                    skilled  \n",
       "182         unskilled resident  \n",
       "736  high qualif/self emp/mgmt  \n",
       "922                    skilled  \n",
       "511  high qualif/self emp/mgmt  \n",
       "..                         ...  \n",
       "845                    skilled  \n",
       "492                    skilled  \n",
       "849         unskilled resident  \n",
       "297         unskilled resident  \n",
       "98                     skilled  \n",
       "\n",
       "[700 rows x 20 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Features of the first loaded dataset (credit-g), training split\n",
    "train_features_credit_g = trainset[0][0]\n",
    "train_features_credit_g"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "058f667e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>own_telephone</th>\n",
       "      <th>foreign_worker</th>\n",
       "      <th>duration</th>\n",
       "      <th>credit_amount</th>\n",
       "      <th>installment_commitment</th>\n",
       "      <th>residence_since</th>\n",
       "      <th>age</th>\n",
       "      <th>existing_credits</th>\n",
       "      <th>num_dependents</th>\n",
       "      <th>checking_status</th>\n",
       "      <th>credit_history</th>\n",
       "      <th>purpose</th>\n",
       "      <th>savings_status</th>\n",
       "      <th>employment</th>\n",
       "      <th>personal_status</th>\n",
       "      <th>other_parties</th>\n",
       "      <th>property_magnitude</th>\n",
       "      <th>other_payment_plans</th>\n",
       "      <th>housing</th>\n",
       "      <th>job</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.205882</td>\n",
       "      <td>0.309013</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.196429</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0&lt;=X&lt;200</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>new car</td>\n",
       "      <td>100&lt;=X&lt;500</td>\n",
       "      <td>1&lt;=X&lt;4</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>car</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>924</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.294118</td>\n",
       "      <td>0.364367</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.642857</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>&lt;0</td>\n",
       "      <td>all paid</td>\n",
       "      <td>furniture/equipment</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>&lt;1</td>\n",
       "      <td>male div/sep</td>\n",
       "      <td>none</td>\n",
       "      <td>life insurance</td>\n",
       "      <td>bank</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>931</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.073529</td>\n",
       "      <td>0.078134</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.053571</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0&lt;=X&lt;200</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>&lt;1</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>none</td>\n",
       "      <td>car</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>796</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.205882</td>\n",
       "      <td>0.399527</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.571429</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>&lt;0</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>used car</td>\n",
       "      <td>no known savings</td>\n",
       "      <td>&gt;=7</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>life insurance</td>\n",
       "      <td>none</td>\n",
       "      <td>for free</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>226</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.647059</td>\n",
       "      <td>0.589358</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.142857</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0&lt;=X&lt;200</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>&gt;=1000</td>\n",
       "      <td>4&lt;=X&lt;7</td>\n",
       "      <td>male single</td>\n",
       "      <td>co applicant</td>\n",
       "      <td>no known property</td>\n",
       "      <td>bank</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>380</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.235294</td>\n",
       "      <td>0.107956</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.357143</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>&lt;0</td>\n",
       "      <td>existing paid</td>\n",
       "      <td>furniture/equipment</td>\n",
       "      <td>no known savings</td>\n",
       "      <td>4&lt;=X&lt;7</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>car</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>768</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.117647</td>\n",
       "      <td>0.185265</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.160714</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0&lt;=X&lt;200</td>\n",
       "      <td>critical/other existing credit</td>\n",
       "      <td>furniture/equipment</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>&gt;=7</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>car</td>\n",
       "      <td>none</td>\n",
       "      <td>rent</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>85</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.117647</td>\n",
       "      <td>0.063937</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.178571</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.0</td>\n",
       "      <td>no checking</td>\n",
       "      <td>critical/other existing credit</td>\n",
       "      <td>business</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>1&lt;=X&lt;4</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>guarantor</td>\n",
       "      <td>real estate</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>high qualif/self emp/mgmt</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>527</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.068945</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.410714</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>1.0</td>\n",
       "      <td>no checking</td>\n",
       "      <td>critical/other existing credit</td>\n",
       "      <td>radio/tv</td>\n",
       "      <td>&lt;100</td>\n",
       "      <td>4&lt;=X&lt;7</td>\n",
       "      <td>male single</td>\n",
       "      <td>none</td>\n",
       "      <td>real estate</td>\n",
       "      <td>none</td>\n",
       "      <td>own</td>\n",
       "      <td>unskilled resident</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>117</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.088235</td>\n",
       "      <td>0.103555</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>0.142857</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.0</td>\n",
       "      <td>&lt;0</td>\n",
       "      <td>critical/other existing credit</td>\n",
       "      <td>furniture/equipment</td>\n",
       "      <td>no known savings</td>\n",
       "      <td>&lt;1</td>\n",
       "      <td>female div/dep/mar</td>\n",
       "      <td>co applicant</td>\n",
       "      <td>real estate</td>\n",
       "      <td>none</td>\n",
       "      <td>rent</td>\n",
       "      <td>skilled</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>100 rows × 20 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     own_telephone  foreign_worker  duration  credit_amount  \\\n",
       "32               1               1  0.205882       0.309013   \n",
       "924              1               1  0.294118       0.364367   \n",
       "931              1               1  0.073529       0.078134   \n",
       "796              1               1  0.205882       0.399527   \n",
       "226              1               1  0.647059       0.589358   \n",
       "..             ...             ...       ...            ...   \n",
       "380              1               1  0.235294       0.107956   \n",
       "768              1               1  0.117647       0.185265   \n",
       "85               1               1  0.117647       0.063937   \n",
       "527              0               1  0.000000       0.068945   \n",
       "117              0               0  0.088235       0.103555   \n",
       "\n",
       "     installment_commitment  residence_since       age  existing_credits  \\\n",
       "32                 0.333333         0.333333  0.196429          0.333333   \n",
       "924                0.333333         0.000000  0.642857          0.000000   \n",
       "931                1.000000         0.333333  0.053571          0.000000   \n",
       "796                0.000000         1.000000  0.571429          0.000000   \n",
       "226                0.000000         0.333333  0.142857          0.333333   \n",
       "..                      ...              ...       ...               ...   \n",
       "380                1.000000         1.000000  0.357143          0.000000   \n",
       "768                0.000000         1.000000  0.160714          0.666667   \n",
       "85                 1.000000         0.333333  0.178571          0.333333   \n",
       "527                0.333333         0.000000  0.410714          0.333333   \n",
       "117                0.333333         0.666667  0.142857          0.333333   \n",
       "\n",
       "     num_dependents checking_status                  credit_history  \\\n",
       "32              0.0        0<=X<200                   existing paid   \n",
       "924             0.0              <0                        all paid   \n",
       "931             0.0        0<=X<200                   existing paid   \n",
       "796             1.0              <0                   existing paid   \n",
       "226             0.0        0<=X<200                   existing paid   \n",
       "..              ...             ...                             ...   \n",
       "380             0.0              <0                   existing paid   \n",
       "768             0.0        0<=X<200  critical/other existing credit   \n",
       "85              0.0     no checking  critical/other existing credit   \n",
       "527             1.0     no checking  critical/other existing credit   \n",
       "117             0.0              <0  critical/other existing credit   \n",
       "\n",
       "                 purpose    savings_status employment     personal_status  \\\n",
       "32               new car        100<=X<500     1<=X<4         male single   \n",
       "924  furniture/equipment              <100         <1        male div/sep   \n",
       "931             radio/tv              <100         <1  female div/dep/mar   \n",
       "796             used car  no known savings        >=7         male single   \n",
       "226             radio/tv            >=1000     4<=X<7         male single   \n",
       "..                   ...               ...        ...                 ...   \n",
       "380  furniture/equipment  no known savings     4<=X<7         male single   \n",
       "768  furniture/equipment              <100        >=7         male single   \n",
       "85              business              <100     1<=X<4  female div/dep/mar   \n",
       "527             radio/tv              <100     4<=X<7         male single   \n",
       "117  furniture/equipment  no known savings         <1  female div/dep/mar   \n",
       "\n",
       "    other_parties property_magnitude other_payment_plans   housing  \\\n",
       "32           none                car                none       own   \n",
       "924          none     life insurance                bank       own   \n",
       "931          none                car                none       own   \n",
       "796          none     life insurance                none  for free   \n",
       "226  co applicant  no known property                bank       own   \n",
       "..            ...                ...                 ...       ...   \n",
       "380          none                car                none       own   \n",
       "768          none                car                none      rent   \n",
       "85      guarantor        real estate                none       own   \n",
       "527          none        real estate                none       own   \n",
       "117  co applicant        real estate                none      rent   \n",
       "\n",
       "                           job  \n",
       "32                     skilled  \n",
       "924                    skilled  \n",
       "931                    skilled  \n",
       "796                    skilled  \n",
       "226                    skilled  \n",
       "..                         ...  \n",
       "380                    skilled  \n",
       "768                    skilled  \n",
       "85   high qualif/self emp/mgmt  \n",
       "527         unskilled resident  \n",
       "117                    skilled  \n",
       "\n",
       "[100 rows x 20 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "valset[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "af2eed94",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3018579e308d4eb995ed65b3581b7f06",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, test val_loss: 0.624792\n",
      "epoch: 0, train loss: 6.6127, lr: 0.000100, spent: 0.5 secs\n",
      "epoch: 1, test val_loss: 0.599838\n",
      "epoch: 1, train loss: 6.3586, lr: 0.000100, spent: 1.2 secs\n",
      "epoch: 2, test val_loss: 0.593658\n",
      "epoch: 2, train loss: 6.0999, lr: 0.000100, spent: 1.8 secs\n",
      "epoch: 3, test val_loss: 0.550265\n",
      "epoch: 3, train loss: 5.8295, lr: 0.000100, spent: 2.3 secs\n",
      "epoch: 4, test val_loss: 0.527351\n",
      "epoch: 4, train loss: 5.6347, lr: 0.000100, spent: 2.8 secs\n",
      "epoch: 5, test val_loss: 0.508950\n",
      "epoch: 5, train loss: 5.5123, lr: 0.000100, spent: 3.3 secs\n",
      "epoch: 6, test val_loss: 0.485854\n",
      "epoch: 6, train loss: 5.4929, lr: 0.000100, spent: 3.9 secs\n",
      "epoch: 7, test val_loss: 0.522198\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 7, train loss: 5.6552, lr: 0.000100, spent: 4.3 secs\n",
      "epoch: 8, test val_loss: 0.478467\n",
      "epoch: 8, train loss: 5.7420, lr: 0.000100, spent: 4.7 secs\n",
      "epoch: 9, test val_loss: 0.515104\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 9, train loss: 5.3993, lr: 0.000100, spent: 5.3 secs\n",
      "epoch: 10, test val_loss: 0.474058\n",
      "epoch: 10, train loss: 5.3141, lr: 0.000100, spent: 5.8 secs\n",
      "epoch: 11, test val_loss: 0.473926\n",
      "epoch: 11, train loss: 5.2754, lr: 0.000100, spent: 6.3 secs\n",
      "epoch: 12, test val_loss: 0.470752\n",
      "epoch: 12, train loss: 5.1095, lr: 0.000100, spent: 6.8 secs\n",
      "epoch: 13, test val_loss: 0.478428\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 13, train loss: 5.0845, lr: 0.000100, spent: 7.3 secs\n",
      "epoch: 14, test val_loss: 0.454532\n",
      "epoch: 14, train loss: 5.1003, lr: 0.000100, spent: 8.0 secs\n",
      "epoch: 15, test val_loss: 0.462518\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 15, train loss: 5.0139, lr: 0.000100, spent: 8.5 secs\n",
      "epoch: 16, test val_loss: 0.453442\n",
      "epoch: 16, train loss: 4.9912, lr: 0.000100, spent: 9.1 secs\n",
      "epoch: 17, test val_loss: 0.459327\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 17, train loss: 4.9310, lr: 0.000100, spent: 9.5 secs\n",
      "epoch: 18, test val_loss: 0.442287\n",
      "epoch: 18, train loss: 4.8740, lr: 0.000100, spent: 10.2 secs\n",
      "epoch: 19, test val_loss: 0.466330\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 19, train loss: 4.8456, lr: 0.000100, spent: 10.8 secs\n",
      "epoch: 20, test val_loss: 0.436802\n",
      "epoch: 20, train loss: 4.7808, lr: 0.000100, spent: 11.2 secs\n",
      "epoch: 21, test val_loss: 0.472410\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 21, train loss: 4.7860, lr: 0.000100, spent: 11.6 secs\n",
      "epoch: 22, test val_loss: 0.448208\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 22, train loss: 4.9795, lr: 0.000100, spent: 12.2 secs\n",
      "epoch: 23, test val_loss: 0.426601\n",
      "epoch: 23, train loss: 4.8747, lr: 0.000100, spent: 12.8 secs\n",
      "epoch: 24, test val_loss: 0.556543\n",
      "EarlyStopping counter: 1 out of 5\n",
      "epoch: 24, train loss: 4.9586, lr: 0.000100, spent: 13.2 secs\n",
      "epoch: 25, test val_loss: 0.455203\n",
      "EarlyStopping counter: 2 out of 5\n",
      "epoch: 25, train loss: 4.9627, lr: 0.000100, spent: 13.8 secs\n",
      "epoch: 26, test val_loss: 0.581238\n",
      "EarlyStopping counter: 3 out of 5\n",
      "epoch: 26, train loss: 5.0275, lr: 0.000100, spent: 14.2 secs\n",
      "epoch: 27, test val_loss: 0.501105\n",
      "EarlyStopping counter: 4 out of 5\n",
      "epoch: 27, train loss: 5.2915, lr: 0.000100, spent: 14.7 secs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-08-31 10:57:39.041 | INFO     | transtab.trainer:train:132 - load best at last from ./checkpoint\n",
      "2022-08-31 10:57:39.057 | INFO     | transtab.trainer:save_model:239 - saving model checkpoint to ./checkpoint\n",
      "2022-08-31 10:57:39.187 | INFO     | transtab.trainer:train:137 - training complete, cost 15.3 secs.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 28, test val_loss: 0.461543\n",
      "EarlyStopping counter: 5 out of 5\n",
      "early stopped\n"
     ]
    }
   ],
   "source": [
    "# start training; the averaged validation loss is the evaluation metric used for early stopping\n",
    "transtab.train(model, trainset, valset, **training_arguments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9b65a489",
   "metadata": {},
   "outputs": [],
   "source": [
    "# make predictions on the test split of the first dataset, 'credit-g'\n",
    "x_test, y_test = testset[0]\n",
    "ypred = transtab.predict(model, x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6eefaa05",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "auc 0.95 mean/interval 0.7399(0.06)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.7399011920073604]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# evaluate the predictions with a bootstrap estimate of the AUC\n",
    "transtab.evaluate(ypred, y_test, seed=123, metric='auc')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34d19852",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
