{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "334f6cbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "# Column names from the UCI Adult dataset description; the raw files have no header row.\n",
    "column_names = [\n",
    "    \"age\", \"workclass\", \"fnlwgt\", \"education\", \"education-num\", \"marital-status\",\n",
    "    \"occupation\", \"relationship\", \"race\", \"sex\", \"capital-gain\", \"capital-loss\",\n",
    "    \"hours-per-week\", \"native-country\", \"income\"\n",
    "]\n",
    "\n",
    "# Directory holding adult.data / adult.test (an empty string means the current working directory).\n",
    "path = r\"\"\n",
    "\n",
    "\n",
    "def load_adult(filename, skiprows=0):\n",
    "    \"\"\"Read one raw Adult file from `path` into a DataFrame labelled with `column_names`.\n",
    "\n",
    "    Fields are separated by a comma plus optional spaces, so a regex separator\n",
    "    (which requires the python engine) strips that padding. The pattern is a raw\n",
    "    string: in a plain literal its backslash escape is invalid and Python emits a\n",
    "    DeprecationWarning (SyntaxWarning from 3.12) for it.\n",
    "\n",
    "    skiprows: number of leading lines to skip before the data rows.\n",
    "    \"\"\"\n",
    "    return pd.read_csv(os.path.join(path, filename), names=column_names,\n",
    "                       skiprows=skiprows, sep=r\",\\s*\", engine='python')\n",
    "\n",
    "\n",
    "# Load adult.data into a dataframe\n",
    "df_train = load_adult(\"adult.data\")\n",
    "\n",
    "# Load adult.test into a dataframe (skipping the first row which is a comment)\n",
    "df_test = load_adult(\"adult.test\", skiprows=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "46dbeca1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>State-gov</td>\n",
       "      <td>77516</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>2174</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>83311</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>215646</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Divorced</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53</td>\n",
       "      <td>Private</td>\n",
       "      <td>234721</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>Private</td>\n",
       "      <td>338409</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>Wife</td>\n",
       "      <td>Black</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>Cuba</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32556</th>\n",
       "      <td>27</td>\n",
       "      <td>Private</td>\n",
       "      <td>257302</td>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>12</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Tech-support</td>\n",
       "      <td>Wife</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>38</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32557</th>\n",
       "      <td>40</td>\n",
       "      <td>Private</td>\n",
       "      <td>154374</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&gt;50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32558</th>\n",
       "      <td>58</td>\n",
       "      <td>Private</td>\n",
       "      <td>151910</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Widowed</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Unmarried</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32559</th>\n",
       "      <td>22</td>\n",
       "      <td>Private</td>\n",
       "      <td>201490</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>20</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32560</th>\n",
       "      <td>52</td>\n",
       "      <td>Self-emp-inc</td>\n",
       "      <td>287927</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>Wife</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>15024</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&gt;50K</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>32561 rows × 15 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       age         workclass  fnlwgt   education  education-num  \\\n",
       "0       39         State-gov   77516   Bachelors             13   \n",
       "1       50  Self-emp-not-inc   83311   Bachelors             13   \n",
       "2       38           Private  215646     HS-grad              9   \n",
       "3       53           Private  234721        11th              7   \n",
       "4       28           Private  338409   Bachelors             13   \n",
       "...    ...               ...     ...         ...            ...   \n",
       "32556   27           Private  257302  Assoc-acdm             12   \n",
       "32557   40           Private  154374     HS-grad              9   \n",
       "32558   58           Private  151910     HS-grad              9   \n",
       "32559   22           Private  201490     HS-grad              9   \n",
       "32560   52      Self-emp-inc  287927     HS-grad              9   \n",
       "\n",
       "           marital-status         occupation   relationship   race     sex  \\\n",
       "0           Never-married       Adm-clerical  Not-in-family  White    Male   \n",
       "1      Married-civ-spouse    Exec-managerial        Husband  White    Male   \n",
       "2                Divorced  Handlers-cleaners  Not-in-family  White    Male   \n",
       "3      Married-civ-spouse  Handlers-cleaners        Husband  Black    Male   \n",
       "4      Married-civ-spouse     Prof-specialty           Wife  Black  Female   \n",
       "...                   ...                ...            ...    ...     ...   \n",
       "32556  Married-civ-spouse       Tech-support           Wife  White  Female   \n",
       "32557  Married-civ-spouse  Machine-op-inspct        Husband  White    Male   \n",
       "32558             Widowed       Adm-clerical      Unmarried  White  Female   \n",
       "32559       Never-married       Adm-clerical      Own-child  White    Male   \n",
       "32560  Married-civ-spouse    Exec-managerial           Wife  White  Female   \n",
       "\n",
       "       capital-gain  capital-loss  hours-per-week native-country income  \n",
       "0              2174             0              40  United-States  <=50K  \n",
       "1                 0             0              13  United-States  <=50K  \n",
       "2                 0             0              40  United-States  <=50K  \n",
       "3                 0             0              40  United-States  <=50K  \n",
       "4                 0             0              40           Cuba  <=50K  \n",
       "...             ...           ...             ...            ...    ...  \n",
       "32556             0             0              38  United-States  <=50K  \n",
       "32557             0             0              40  United-States   >50K  \n",
       "32558             0             0              40  United-States  <=50K  \n",
       "32559             0             0              20  United-States  <=50K  \n",
       "32560         15024             0              40  United-States   >50K  \n",
       "\n",
       "[32561 rows x 15 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Raw training frame as loaded; pandas truncates the display to its first and last rows.\n",
    "df_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7016d1e2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/shenyu/miniconda3/envs/DLcourse/lib/python3.7/site-packages/ipykernel_launcher.py:1: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  \"\"\"Entry point for launching an IPython kernel.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>State-gov</td>\n",
       "      <td>77516</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>2174</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>83311</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>215646</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Divorced</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53</td>\n",
       "      <td>Private</td>\n",
       "      <td>234721</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>Private</td>\n",
       "      <td>338409</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>Wife</td>\n",
       "      <td>Black</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>0</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age         workclass  fnlwgt  education  education-num  \\\n",
       "0   39         State-gov   77516  Bachelors             13   \n",
       "1   50  Self-emp-not-inc   83311  Bachelors             13   \n",
       "2   38           Private  215646    HS-grad              9   \n",
       "3   53           Private  234721       11th              7   \n",
       "4   28           Private  338409  Bachelors             13   \n",
       "\n",
       "       marital-status         occupation   relationship   race     sex  \\\n",
       "0       Never-married       Adm-clerical  Not-in-family  White    Male   \n",
       "1  Married-civ-spouse    Exec-managerial        Husband  White    Male   \n",
       "2            Divorced  Handlers-cleaners  Not-in-family  White    Male   \n",
       "3  Married-civ-spouse  Handlers-cleaners        Husband  Black    Male   \n",
       "4  Married-civ-spouse     Prof-specialty           Wife  Black  Female   \n",
       "\n",
       "   capital-gain  capital-loss  hours-per-week  native-country income  \n",
       "0          2174             0              40               1  <=50K  \n",
       "1             0             0              13               1  <=50K  \n",
       "2             0             0              40               1  <=50K  \n",
       "3             0             0              40               1  <=50K  \n",
       "4             0             0              40               0  <=50K  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Diagnostic: columns that use '?' as their missing-value marker. Comparing the whole frame\n",
    "# avoids the elementwise-comparison FutureWarning that `'?' in col.unique()` raised on numeric\n",
    "# columns; the old `np.nan in ...unique()` test is unreliable because NaN != NaN.\n",
    "missing_value_columns = df_train.columns[(df_train == '?').any()].tolist()\n",
    "\n",
    "# Remove all missing values: '?' -> NaN, then drop every row that contains a NaN.\n",
    "df_train = df_train.replace('?', np.nan).dropna(how='any')\n",
    "\n",
    "# Binary-encode native country: 'United-States' -> US_LABEL, every other country -> NON_US_LABEL.\n",
    "US_LABEL, NON_US_LABEL = (1, 0)\n",
    "df_train['native-country'] = np.where(df_train['native-country'] == 'United-States', US_LABEL, NON_US_LABEL)\n",
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6d6cb6ba",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>State-gov</td>\n",
       "      <td>77516</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>1</td>\n",
       "      <td>2174</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>83311</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>215646</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Divorced</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53</td>\n",
       "      <td>Private</td>\n",
       "      <td>234721</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>Private</td>\n",
       "      <td>338409</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>Wife</td>\n",
       "      <td>Black</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age         workclass  fnlwgt  education  education-num  \\\n",
       "0   39         State-gov   77516  Bachelors             13   \n",
       "1   50  Self-emp-not-inc   83311  Bachelors             13   \n",
       "2   38           Private  215646    HS-grad              9   \n",
       "3   53           Private  234721       11th              7   \n",
       "4   28           Private  338409  Bachelors             13   \n",
       "\n",
       "       marital-status         occupation   relationship   race  sex  \\\n",
       "0       Never-married       Adm-clerical  Not-in-family  White    1   \n",
       "1  Married-civ-spouse    Exec-managerial        Husband  White    1   \n",
       "2            Divorced  Handlers-cleaners  Not-in-family  White    1   \n",
       "3  Married-civ-spouse  Handlers-cleaners        Husband  Black    1   \n",
       "4  Married-civ-spouse     Prof-specialty           Wife  Black    0   \n",
       "\n",
       "   capital-gain  capital-loss  hours-per-week  native-country  income  \n",
       "0          2174             0              40               1       0  \n",
       "1             0             0              13               1       0  \n",
       "2             0             0              40               1       0  \n",
       "3             0             0              40               1       0  \n",
       "4             0             0              40               0       0  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Encode the target (income) and sex as integer labels.\n",
    "FEMALE_LABEL, MALE_LABEL = (0, 1)\n",
    "HIGH_SALARY_LABEL, LOW_SALARY_LABEL = (1, 0)\n",
    "income_encoded = df_train['income'].map({'>50K':HIGH_SALARY_LABEL,'<=50K':LOW_SALARY_LABEL})\n",
    "sex_encoded = df_train['sex'].map({'Male':MALE_LABEL,'Female':FEMALE_LABEL})\n",
    "# Series.map turns any label missing from the dict into NaN. Check before overwriting so an\n",
    "# unexpected label (or re-running this cell on already-numeric columns) fails loudly and\n",
    "# leaves df_train untouched.\n",
    "assert income_encoded.notna().all(), 'unmapped income label'\n",
    "assert sex_encoded.notna().all(), 'unmapped sex label'\n",
    "df_train['income'] = income_encoded\n",
    "df_train['sex'] = sex_encoded\n",
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fb67bea1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>State-gov</td>\n",
       "      <td>77516</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Single</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>1</td>\n",
       "      <td>2174</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>83311</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Couple</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>215646</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Single</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53</td>\n",
       "      <td>Private</td>\n",
       "      <td>234721</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Couple</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>Private</td>\n",
       "      <td>338409</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Couple</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>Wife</td>\n",
       "      <td>Black</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age         workclass  fnlwgt  education  education-num marital-status  \\\n",
       "0   39         State-gov   77516  Bachelors             13         Single   \n",
       "1   50  Self-emp-not-inc   83311  Bachelors             13         Couple   \n",
       "2   38           Private  215646    HS-grad              9         Single   \n",
       "3   53           Private  234721       11th              7         Couple   \n",
       "4   28           Private  338409  Bachelors             13         Couple   \n",
       "\n",
       "          occupation   relationship   race  sex  capital-gain  capital-loss  \\\n",
       "0       Adm-clerical  Not-in-family  White    1          2174             0   \n",
       "1    Exec-managerial        Husband  White    1             0             0   \n",
       "2  Handlers-cleaners  Not-in-family  White    1             0             0   \n",
       "3  Handlers-cleaners        Husband  Black    1             0             0   \n",
       "4     Prof-specialty           Wife  Black    0             0             0   \n",
       "\n",
       "   hours-per-week  native-country  income  \n",
       "0              40               1       0  \n",
       "1              13               1       0  \n",
       "2              40               1       0  \n",
       "3              40               1       0  \n",
       "4              40               0       0  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Collapse the detailed marital statuses into two groups: 'Single' and 'Couple'.\n",
    "df_train['marital-status'] = df_train['marital-status'].replace({\n",
    "    'Divorced': 'Single',\n",
    "    'Married-spouse-absent': 'Single',\n",
    "    'Never-married': 'Single',\n",
    "    'Separated': 'Single',\n",
    "    'Widowed': 'Single',\n",
    "    'Married-AF-spouse': 'Couple',\n",
    "    'Married-civ-spouse': 'Couple',\n",
    "})\n",
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "011b62f9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>State-gov</td>\n",
       "      <td>77516</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>1</td>\n",
       "      <td>2174</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>83311</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>215646</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>1</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53</td>\n",
       "      <td>Private</td>\n",
       "      <td>234721</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>0</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>Private</td>\n",
       "      <td>338409</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>Wife</td>\n",
       "      <td>Black</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age         workclass  fnlwgt  education  education-num  marital-status  \\\n",
       "0   39         State-gov   77516  Bachelors             13               1   \n",
       "1   50  Self-emp-not-inc   83311  Bachelors             13               0   \n",
       "2   38           Private  215646    HS-grad              9               1   \n",
       "3   53           Private  234721       11th              7               0   \n",
       "4   28           Private  338409  Bachelors             13               0   \n",
       "\n",
       "          occupation   relationship   race  sex  capital-gain  capital-loss  \\\n",
       "0       Adm-clerical  Not-in-family  White    1          2174             0   \n",
       "1    Exec-managerial        Husband  White    1             0             0   \n",
       "2  Handlers-cleaners  Not-in-family  White    1             0             0   \n",
       "3  Handlers-cleaners        Husband  Black    1             0             0   \n",
       "4     Prof-specialty           Wife  Black    0             0             0   \n",
       "\n",
       "   hours-per-week  native-country  income  \n",
       "0              40               1       0  \n",
       "1              13               1       0  \n",
       "2              40               1       0  \n",
       "3              40               1       0  \n",
       "4              40               0       0  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Encode the collapsed marital status as a binary label.\n",
    "COUPLE_STATUS_LABEL, SINGLE_STATUS_LABEL = (0, 1)\n",
    "marital_status_labels = {'Couple': COUPLE_STATUS_LABEL, 'Single': SINGLE_STATUS_LABEL}\n",
    "# Skip if already encoded: re-running .map() on the integer labels would turn the whole column into NaN.\n",
    "if not pd.api.types.is_numeric_dtype(df_train['marital-status']):\n",
    "    df_train['marital-status'] = df_train['marital-status'].map(marital_status_labels)\n",
    "# .map() silently yields NaN for any category missing from the dict; fail loudly instead.\n",
    "assert df_train['marital-status'].isin(list(marital_status_labels.values())).all(), 'unmapped marital-status values'\n",
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "909d5427",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>State-gov</td>\n",
       "      <td>77516</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>3</td>\n",
       "      <td>White</td>\n",
       "      <td>1</td>\n",
       "      <td>2174</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>83311</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>2</td>\n",
       "      <td>White</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>215646</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>1</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>3</td>\n",
       "      <td>White</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53</td>\n",
       "      <td>Private</td>\n",
       "      <td>234721</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>0</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>2</td>\n",
       "      <td>Black</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>Private</td>\n",
       "      <td>338409</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>1</td>\n",
       "      <td>Black</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>37</td>\n",
       "      <td>Private</td>\n",
       "      <td>284582</td>\n",
       "      <td>Masters</td>\n",
       "      <td>14</td>\n",
       "      <td>0</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>1</td>\n",
       "      <td>White</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>49</td>\n",
       "      <td>Private</td>\n",
       "      <td>160187</td>\n",
       "      <td>9th</td>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>Other-service</td>\n",
       "      <td>3</td>\n",
       "      <td>Black</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>16</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>52</td>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>209642</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>0</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>2</td>\n",
       "      <td>White</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>45</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>31</td>\n",
       "      <td>Private</td>\n",
       "      <td>45781</td>\n",
       "      <td>Masters</td>\n",
       "      <td>14</td>\n",
       "      <td>1</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>3</td>\n",
       "      <td>White</td>\n",
       "      <td>0</td>\n",
       "      <td>14084</td>\n",
       "      <td>0</td>\n",
       "      <td>50</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>42</td>\n",
       "      <td>Private</td>\n",
       "      <td>159449</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>2</td>\n",
       "      <td>White</td>\n",
       "      <td>1</td>\n",
       "      <td>5178</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age         workclass  fnlwgt  education  education-num  marital-status  \\\n",
       "0   39         State-gov   77516  Bachelors             13               1   \n",
       "1   50  Self-emp-not-inc   83311  Bachelors             13               0   \n",
       "2   38           Private  215646    HS-grad              9               1   \n",
       "3   53           Private  234721       11th              7               0   \n",
       "4   28           Private  338409  Bachelors             13               0   \n",
       "5   37           Private  284582    Masters             14               0   \n",
       "6   49           Private  160187        9th              5               1   \n",
       "7   52  Self-emp-not-inc  209642    HS-grad              9               0   \n",
       "8   31           Private   45781    Masters             14               1   \n",
       "9   42           Private  159449  Bachelors             13               0   \n",
       "\n",
       "          occupation  relationship   race  sex  capital-gain  capital-loss  \\\n",
       "0       Adm-clerical             3  White    1          2174             0   \n",
       "1    Exec-managerial             2  White    1             0             0   \n",
       "2  Handlers-cleaners             3  White    1             0             0   \n",
       "3  Handlers-cleaners             2  Black    1             0             0   \n",
       "4     Prof-specialty             1  Black    0             0             0   \n",
       "5    Exec-managerial             1  White    0             0             0   \n",
       "6      Other-service             3  Black    0             0             0   \n",
       "7    Exec-managerial             2  White    1             0             0   \n",
       "8     Prof-specialty             3  White    0         14084             0   \n",
       "9    Exec-managerial             2  White    1          5178             0   \n",
       "\n",
       "   hours-per-week  native-country  income  \n",
       "0              40               1       0  \n",
       "1              13               1       0  \n",
       "2              40               1       0  \n",
       "3              40               1       0  \n",
       "4              40               0       0  \n",
       "5              40               1       0  \n",
       "6              16               0       0  \n",
       "7              45               1       1  \n",
       "8              50               1       1  \n",
       "9              40               1       1  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Encode relationship as integer codes; the next cell one-hot encodes these codes.\n",
    "rel_map = {'Unmarried':0,'Wife':1,'Husband':2,'Not-in-family':3,'Own-child':4,'Other-relative':5}\n",
    "# Skip if already encoded: re-running .map() on the integer codes would turn the column into NaN.\n",
    "if not pd.api.types.is_numeric_dtype(df_train['relationship']):\n",
    "    df_train['relationship'] = df_train['relationship'].map(rel_map)\n",
    "# .map() silently yields NaN for categories missing from rel_map; fail loudly instead.\n",
    "assert df_train['relationship'].isin(list(rel_map.values())).all(), 'unmapped relationship values'\n",
    "df_train.head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f5cd8516",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "      <th>relationship_0</th>\n",
       "      <th>relationship_1</th>\n",
       "      <th>relationship_2</th>\n",
       "      <th>relationship_3</th>\n",
       "      <th>relationship_4</th>\n",
       "      <th>relationship_5</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>State-gov</td>\n",
       "      <td>77516</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>White</td>\n",
       "      <td>1</td>\n",
       "      <td>2174</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>83311</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>White</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>215646</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>1</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>White</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53</td>\n",
       "      <td>Private</td>\n",
       "      <td>234721</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>0</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Black</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>Private</td>\n",
       "      <td>338409</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>Black</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age         workclass  fnlwgt  education  education-num  marital-status  \\\n",
       "0   39         State-gov   77516  Bachelors             13               1   \n",
       "1   50  Self-emp-not-inc   83311  Bachelors             13               0   \n",
       "2   38           Private  215646    HS-grad              9               1   \n",
       "3   53           Private  234721       11th              7               0   \n",
       "4   28           Private  338409  Bachelors             13               0   \n",
       "\n",
       "          occupation   race  sex  capital-gain  capital-loss  hours-per-week  \\\n",
       "0       Adm-clerical  White    1          2174             0              40   \n",
       "1    Exec-managerial  White    1             0             0              13   \n",
       "2  Handlers-cleaners  White    1             0             0              40   \n",
       "3  Handlers-cleaners  Black    1             0             0              40   \n",
       "4     Prof-specialty  Black    0             0             0              40   \n",
       "\n",
       "   native-country  income  relationship_0  relationship_1  relationship_2  \\\n",
       "0               1       0               0               0               0   \n",
       "1               1       0               0               0               1   \n",
       "2               1       0               0               0               0   \n",
       "3               1       0               0               0               1   \n",
       "4               0       0               0               1               0   \n",
       "\n",
       "   relationship_3  relationship_4  relationship_5  \n",
       "0               1               0               0  \n",
       "1               0               0               0  \n",
       "2               1               0               0  \n",
       "3               0               0               0  \n",
       "4               0               0               0  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One-hot encode the relationship codes into relationship_0 ... relationship_5.\n",
    "# dtype=int keeps 0/1 columns on every pandas version (pandas >= 2.0 otherwise emits bool).\n",
    "# The column check keeps the cell re-runnable: get_dummies drops the 'relationship' column.\n",
    "if 'relationship' in df_train.columns:\n",
    "    df_train = pd.get_dummies(df_train, columns=['relationship'], dtype=int)\n",
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1968dce2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "      <th>relationship_0</th>\n",
       "      <th>relationship_1</th>\n",
       "      <th>relationship_2</th>\n",
       "      <th>relationship_3</th>\n",
       "      <th>relationship_4</th>\n",
       "      <th>relationship_5</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>State-gov</td>\n",
       "      <td>77516</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2174</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>83311</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>215646</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>1</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53</td>\n",
       "      <td>Private</td>\n",
       "      <td>234721</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>0</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>Private</td>\n",
       "      <td>338409</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age         workclass  fnlwgt  education  education-num  marital-status  \\\n",
       "0   39         State-gov   77516  Bachelors             13               1   \n",
       "1   50  Self-emp-not-inc   83311  Bachelors             13               0   \n",
       "2   38           Private  215646    HS-grad              9               1   \n",
       "3   53           Private  234721       11th              7               0   \n",
       "4   28           Private  338409  Bachelors             13               0   \n",
       "\n",
       "          occupation  race  sex  capital-gain  capital-loss  hours-per-week  \\\n",
       "0       Adm-clerical     0    1          2174             0              40   \n",
       "1    Exec-managerial     0    1             0             0              13   \n",
       "2  Handlers-cleaners     0    1             0             0              40   \n",
       "3  Handlers-cleaners     1    1             0             0              40   \n",
       "4     Prof-specialty     1    0             0             0              40   \n",
       "\n",
       "   native-country  income  relationship_0  relationship_1  relationship_2  \\\n",
       "0               1       0               0               0               0   \n",
       "1               1       0               0               0               1   \n",
       "2               1       0               0               0               0   \n",
       "3               1       0               0               0               1   \n",
       "4               0       0               0               1               0   \n",
       "\n",
       "   relationship_3  relationship_4  relationship_5  \n",
       "0               1               0               0  \n",
       "1               0               0               0  \n",
       "2               1               0               0  \n",
       "3               0               0               0  \n",
       "4               0               0               0  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Binarize race: 'White' -> 0, every other group -> 1.\n",
    "race_map={'White':0,'Amer-Indian-Eskimo':1,'Asian-Pac-Islander':1,'Black':1,'Other':1}\n",
    "# Skip if already encoded: re-running .map() on the 0/1 codes would turn the column into NaN.\n",
    "if not pd.api.types.is_numeric_dtype(df_train['race']):\n",
    "    df_train['race'] = df_train['race'].map(race_map)\n",
    "# .map() silently yields NaN for categories missing from race_map; fail loudly instead.\n",
    "assert df_train['race'].isin([0, 1]).all(), 'unmapped race values'\n",
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "356450bc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>...</th>\n",
       "      <th>relationship_0</th>\n",
       "      <th>relationship_1</th>\n",
       "      <th>relationship_2</th>\n",
       "      <th>relationship_3</th>\n",
       "      <th>relationship_4</th>\n",
       "      <th>relationship_5</th>\n",
       "      <th>workclass_0</th>\n",
       "      <th>workclass_1</th>\n",
       "      <th>workclass_2</th>\n",
       "      <th>workclass_3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>77516</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2174</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>83311</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>215646</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>1</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53</td>\n",
       "      <td>234721</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>0</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>338409</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 23 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  fnlwgt  education  education-num  marital-status         occupation  \\\n",
       "0   39   77516  Bachelors             13               1       Adm-clerical   \n",
       "1   50   83311  Bachelors             13               0    Exec-managerial   \n",
       "2   38  215646    HS-grad              9               1  Handlers-cleaners   \n",
       "3   53  234721       11th              7               0  Handlers-cleaners   \n",
       "4   28  338409  Bachelors             13               0     Prof-specialty   \n",
       "\n",
       "   race  sex  capital-gain  capital-loss  ...  relationship_0  relationship_1  \\\n",
       "0     0    1          2174             0  ...               0               0   \n",
       "1     0    1             0             0  ...               0               0   \n",
       "2     0    1             0             0  ...               0               0   \n",
       "3     1    1             0             0  ...               0               0   \n",
       "4     1    0             0             0  ...               0               1   \n",
       "\n",
       "   relationship_2  relationship_3  relationship_4  relationship_5  \\\n",
       "0               0               1               0               0   \n",
       "1               1               0               0               0   \n",
       "2               0               1               0               0   \n",
       "3               1               0               0               0   \n",
       "4               0               0               0               0   \n",
       "\n",
       "   workclass_0  workclass_1  workclass_2  workclass_3  \n",
       "0            1            0            0            0  \n",
       "1            0            0            1            0  \n",
       "2            0            1            0            0  \n",
       "3            0            1            0            0  \n",
       "4            0            1            0            0  \n",
       "\n",
       "[5 rows x 23 columns]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Coarse workclass groups; any raw label not listed here falls into 'without_pay'.\n",
    "# NOTE(review): in the Adult data that catch-all presumably also absorbs the '?' missing-value\n",
    "# marker (not just Without-pay/Never-worked) - confirm via df_train['workclass'].unique() and\n",
    "# consider a separate 'unknown' group.\n",
    "WORKCLASS_GROUPS = {\n",
    "    'Federal-gov': 'govt', 'Local-gov': 'govt', 'State-gov': 'govt',\n",
    "    'Private': 'private',\n",
    "    'Self-emp-inc': 'self_employed', 'Self-emp-not-inc': 'self_employed',\n",
    "}\n",
    "WORKCLASS_CODES = {'govt': 0, 'private': 1, 'self_employed': 2, 'without_pay': 3}\n",
    "\n",
    "def workclass_group(label):\n",
    "    \"\"\"Return the coarse group name ('govt', 'private', 'self_employed', 'without_pay') for one raw label.\"\"\"\n",
    "    return WORKCLASS_GROUPS.get(label, 'without_pay')\n",
    "\n",
    "def group_workclass(x):\n",
    "    \"\"\"Row-based variant: return the coarse group for row x's 'workclass' value.\n",
    "\n",
    "    Kept with its original signature in case later cells call df.apply(group_workclass, axis=1).\n",
    "    \"\"\"\n",
    "    return workclass_group(x['workclass'])\n",
    "\n",
    "# Map the column directly instead of apply(axis=1), which builds a Series for every row.\n",
    "df_train['workclass'] = df_train['workclass'].map(workclass_group).map(WORKCLASS_CODES)\n",
    "df_train = pd.get_dummies(df_train, columns=['workclass'])\n",
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "38160644",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>...</th>\n",
       "      <th>occupation_4</th>\n",
       "      <th>occupation_5</th>\n",
       "      <th>occupation_6</th>\n",
       "      <th>occupation_7</th>\n",
       "      <th>occupation_8</th>\n",
       "      <th>occupation_9</th>\n",
       "      <th>occupation_10</th>\n",
       "      <th>occupation_11</th>\n",
       "      <th>occupation_12</th>\n",
       "      <th>occupation_13</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>77516</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2174</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>83311</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>215646</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53</td>\n",
       "      <td>234721</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>338409</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 36 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  fnlwgt  education  education-num  marital-status  race  sex  \\\n",
       "0   39   77516  Bachelors             13               1     0    1   \n",
       "1   50   83311  Bachelors             13               0     0    1   \n",
       "2   38  215646    HS-grad              9               1     0    1   \n",
       "3   53  234721       11th              7               0     1    1   \n",
       "4   28  338409  Bachelors             13               0     1    0   \n",
       "\n",
       "   capital-gain  capital-loss  hours-per-week  ...  occupation_4  \\\n",
       "0          2174             0              40  ...             0   \n",
       "1             0             0              13  ...             0   \n",
       "2             0             0              40  ...             0   \n",
       "3             0             0              40  ...             0   \n",
       "4             0             0              40  ...             0   \n",
       "\n",
       "   occupation_5  occupation_6  occupation_7  occupation_8  occupation_9  \\\n",
       "0             0             0             0             0             0   \n",
       "1             0             0             0             0             0   \n",
       "2             0             0             0             0             0   \n",
       "3             0             0             0             0             0   \n",
       "4             0             0             0             0             0   \n",
       "\n",
       "   occupation_10  occupation_11  occupation_12  occupation_13  \n",
       "0              0              0              0              0  \n",
       "1              0              0              0              0  \n",
       "2              0              0              0              0  \n",
       "3              0              0              0              0  \n",
       "4              0              0              0              0  \n",
       "\n",
       "[5 rows x 36 columns]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Give each occupation an integer code in order of first appearance in df_train, then one-hot encode it.\n",
    "# NOTE(review): the codes depend on df_train's row order - reuse this same occupation_map for any\n",
    "# other split (e.g. df_test) so that occupation_<k> refers to the same occupation everywhere.\n",
    "occupation_map = {occupation: code for code, occupation in enumerate(df_train['occupation'].unique())}\n",
    "df_train['occupation'] = df_train['occupation'].map(occupation_map)\n",
    "df_train = pd.get_dummies(df_train, columns=['occupation'])\n",
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b25cc162",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>...</th>\n",
       "      <th>occupation_4</th>\n",
       "      <th>occupation_5</th>\n",
       "      <th>occupation_6</th>\n",
       "      <th>occupation_7</th>\n",
       "      <th>occupation_8</th>\n",
       "      <th>occupation_9</th>\n",
       "      <th>occupation_10</th>\n",
       "      <th>occupation_11</th>\n",
       "      <th>occupation_12</th>\n",
       "      <th>occupation_13</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>77516</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>83311</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>215646</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53</td>\n",
       "      <td>234721</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>338409</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 36 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  fnlwgt  education  education-num  marital-status  race  sex  \\\n",
       "0   39   77516  Bachelors             13               1     0    1   \n",
       "1   50   83311  Bachelors             13               0     0    1   \n",
       "2   38  215646    HS-grad              9               1     0    1   \n",
       "3   53  234721       11th              7               0     1    1   \n",
       "4   28  338409  Bachelors             13               0     1    0   \n",
       "\n",
       "   capital-gain  capital-loss  hours-per-week  ...  occupation_4  \\\n",
       "0             1             0              40  ...             0   \n",
       "1             0             0              13  ...             0   \n",
       "2             0             0              40  ...             0   \n",
       "3             0             0              40  ...             0   \n",
       "4             0             0              40  ...             0   \n",
       "\n",
       "   occupation_5  occupation_6  occupation_7  occupation_8  occupation_9  \\\n",
       "0             0             0             0             0             0   \n",
       "1             0             0             0             0             0   \n",
       "2             0             0             0             0             0   \n",
       "3             0             0             0             0             0   \n",
       "4             0             0             0             0             0   \n",
       "\n",
       "   occupation_10  occupation_11  occupation_12  occupation_13  \n",
       "0              0              0              0              0  \n",
       "1              0              0              0              0  \n",
       "2              0              0              0              0  \n",
       "3              0              0              0              0  \n",
       "4              0              0              0              0  \n",
       "\n",
       "[5 rows x 36 columns]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Binarize capital-gain into a 'reported any gain' flag: positive amounts become 1, zeros stay 0.\n",
    "df_train.loc[df_train['capital-gain'] > 0, 'capital-gain'] = 1\n",
    "# Surfaces unexpected values (e.g. negatives or NaN) that the mask above leaves untouched.\n",
    "assert df_train['capital-gain'].isin([0, 1]).all(), 'capital-gain has values other than 0/1'\n",
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "906e9e47",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>...</th>\n",
       "      <th>occupation_4</th>\n",
       "      <th>occupation_5</th>\n",
       "      <th>occupation_6</th>\n",
       "      <th>occupation_7</th>\n",
       "      <th>occupation_8</th>\n",
       "      <th>occupation_9</th>\n",
       "      <th>occupation_10</th>\n",
       "      <th>occupation_11</th>\n",
       "      <th>occupation_12</th>\n",
       "      <th>occupation_13</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>77516</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>83311</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>215646</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53</td>\n",
       "      <td>234721</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>338409</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 36 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  fnlwgt  education  education-num  marital-status  race  sex  \\\n",
       "0   39   77516  Bachelors             13               1     0    1   \n",
       "1   50   83311  Bachelors             13               0     0    1   \n",
       "2   38  215646    HS-grad              9               1     0    1   \n",
       "3   53  234721       11th              7               0     1    1   \n",
       "4   28  338409  Bachelors             13               0     1    0   \n",
       "\n",
       "   capital-gain  capital-loss  hours-per-week  ...  occupation_4  \\\n",
       "0             1             0              40  ...             0   \n",
       "1             0             0              13  ...             0   \n",
       "2             0             0              40  ...             0   \n",
       "3             0             0              40  ...             0   \n",
       "4             0             0              40  ...             0   \n",
       "\n",
       "   occupation_5  occupation_6  occupation_7  occupation_8  occupation_9  \\\n",
       "0             0             0             0             0             0   \n",
       "1             0             0             0             0             0   \n",
       "2             0             0             0             0             0   \n",
       "3             0             0             0             0             0   \n",
       "4             0             0             0             0             0   \n",
       "\n",
       "   occupation_10  occupation_11  occupation_12  occupation_13  \n",
       "0              0              0              0              0  \n",
       "1              0              0              0              0  \n",
       "2              0              0              0              0  \n",
       "3              0              0              0              0  \n",
       "4              0              0              0              0  \n",
       "\n",
       "[5 rows x 36 columns]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Binarize capital-loss into a 'reported any loss' flag: positive amounts become 1, zeros stay 0.\n",
    "df_train.loc[df_train['capital-loss'] > 0, 'capital-loss'] = 1\n",
    "# Surfaces unexpected values (e.g. negatives or NaN) that the mask above leaves untouched.\n",
    "assert df_train['capital-loss'].isin([0, 1]).all(), 'capital-loss has values other than 0/1'\n",
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e35cfe51",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>...</th>\n",
       "      <th>occupation_4</th>\n",
       "      <th>occupation_5</th>\n",
       "      <th>occupation_6</th>\n",
       "      <th>occupation_7</th>\n",
       "      <th>occupation_8</th>\n",
       "      <th>occupation_9</th>\n",
       "      <th>occupation_10</th>\n",
       "      <th>occupation_11</th>\n",
       "      <th>occupation_12</th>\n",
       "      <th>occupation_13</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.042796</td>\n",
       "      <td>-1.062722</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>1.128918</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.077734</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.880288</td>\n",
       "      <td>-1.007871</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>1.128918</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>-2.331531</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>-0.033340</td>\n",
       "      <td>0.244693</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>-0.439738</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.077734</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1.108695</td>\n",
       "      <td>0.425240</td>\n",
       "      <td>11th</td>\n",
       "      <td>-1.224066</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.077734</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-0.794697</td>\n",
       "      <td>1.406658</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>1.128918</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.077734</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 36 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        age    fnlwgt  education  education-num  marital-status  race  sex  \\\n",
       "0  0.042796 -1.062722  Bachelors       1.128918               1     0    1   \n",
       "1  0.880288 -1.007871  Bachelors       1.128918               0     0    1   \n",
       "2 -0.033340  0.244693    HS-grad      -0.439738               1     0    1   \n",
       "3  1.108695  0.425240       11th      -1.224066               0     1    1   \n",
       "4 -0.794697  1.406658  Bachelors       1.128918               0     1    0   \n",
       "\n",
       "   capital-gain  capital-loss  hours-per-week  ...  occupation_4  \\\n",
       "0             1             0       -0.077734  ...             0   \n",
       "1             0             0       -2.331531  ...             0   \n",
       "2             0             0       -0.077734  ...             0   \n",
       "3             0             0       -0.077734  ...             0   \n",
       "4             0             0       -0.077734  ...             0   \n",
       "\n",
       "   occupation_5  occupation_6  occupation_7  occupation_8  occupation_9  \\\n",
       "0             0             0             0             0             0   \n",
       "1             0             0             0             0             0   \n",
       "2             0             0             0             0             0   \n",
       "3             0             0             0             0             0   \n",
       "4             0             0             0             0             0   \n",
       "\n",
       "   occupation_10  occupation_11  occupation_12  occupation_13  \n",
       "0              0              0              0              0  \n",
       "1              0              0              0              0  \n",
       "2              0              0              0              0  \n",
       "3              0              0              0              0  \n",
       "4              0              0              0              0  \n",
       "\n",
       "[5 rows x 36 columns]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "continuous_features = ['age', 'fnlwgt', 'education-num', 'hours-per-week']\n",
    "X = df_train[continuous_features]\n",
    "# Standardize each column on its own. Calling np.mean/np.std on a DataFrame delegates to\n",
    "# DataFrame.mean/std(axis=None), which newer pandas versions reduce over the whole frame\n",
    "# (or warn that they will), so use the explicit column-wise methods. ddof=0 keeps the\n",
    "# population standard deviation that np.std computed.\n",
    "df_train[continuous_features] = (X - X.mean()) / X.std(ddof=0)\n",
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "6a886c60",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The text label is redundant with the numeric 'education-num' column kept above.\n",
    "df_train = df_train.drop(columns=['education'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "644f2163",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "SEED = 0  # fixes the row shuffle and the train/test split so re-runs are reproducible\n",
    "\n",
    "def get_naive_dataset(dataset, test_size=0.25, random_state=None):\n",
    "    \"\"\"Split `dataset` into train/test parts with 'race' and 'sex' removed from the features.\n",
    "\n",
    "    'income' is used as the label; 'income', 'race' and 'sex' are all dropped from X.\n",
    "    `test_size` is forwarded to train_test_split. `random_state` seeds both the row\n",
    "    shuffle and the split; None (the default) keeps the original unseeded behavior.\n",
    "    Returns ((x_train, y_train), (x_test, y_test)).\n",
    "    \"\"\"\n",
    "    data_shuffled = dataset.sample(frac=1, random_state=random_state).reset_index(drop=True)\n",
    "    X = data_shuffled.drop(['income', 'race', 'sex'], axis=1)\n",
    "    y = data_shuffled['income']\n",
    "    x_train, x_test, y_train, y_test = train_test_split(\n",
    "        X, y, test_size=test_size, random_state=random_state)\n",
    "    return (x_train, y_train), (x_test, y_test)\n",
    "\n",
    "(x_train, y_train), (x_test, y_test) = get_naive_dataset(df_train, random_state=SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "4938cea8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/shenyu/miniconda3/envs/DLcourse/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import os\n",
    "\n",
    "batch_size = 256\n",
    "\n",
    "\n",
    "class BasicDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Dataset yielding (float32 feature row, relationship class index) pairs.\n",
    "\n",
    "    `data` is expected to be an all-numeric DataFrame (e.g. x_train). The relationship\n",
    "    label is the argmax over the one-hot columns selected by `relationship_cols`.\n",
    "    \"\"\"\n",
    "    def __init__(self, data, relationship_cols=slice(9, 15)):\n",
    "        super().__init__()\n",
    "        self.data = torch.tensor(data.values)\n",
    "        # Cast once up front: calling self.data.float() inside __getitem__ converted and\n",
    "        # copied the whole tensor for every single sample.\n",
    "        self.features = self.data.float()\n",
    "        # Assumes positions 9..14 hold relationship_0..relationship_5 in `data`\n",
    "        # -- TODO confirm against data.columns whenever the preprocessing changes.\n",
    "        self.one_hot_relationship = self.data[:, relationship_cols]\n",
    "        self.relation_tensor = self.one_hot_relationship.argmax(dim=1)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.data.size(0)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.features[idx], self.relation_tensor[idx]\n",
    "\n",
    "\n",
    "def make_dataloader(data, batch_size):\n",
    "    \"\"\"Wrap `data` in a BasicDataset and return a shuffling DataLoader over it.\"\"\"\n",
    "    dataset = BasicDataset(data)\n",
    "    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "training_loader = make_dataloader(x_train, batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "f24a99c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Directory Adult already exists.\n",
      "ReversibleGraphNet(\n",
      "  (module_list): ModuleList(\n",
      "    (0): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (1): PermuteRandom()\n",
      "    (2): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (3): PermuteRandom()\n",
      "    (4): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (5): PermuteRandom()\n",
      "    (6): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (7): PermuteRandom()\n",
      "    (8): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (9): PermuteRandom()\n",
      "    (10): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (11): PermuteRandom()\n",
      "    (12): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (13): PermuteRandom()\n",
      "    (14): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (15): PermuteRandom()\n",
      "    (16): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (17): PermuteRandom()\n",
      "    (18): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (19): PermuteRandom()\n",
      "    (20): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (21): PermuteRandom()\n",
      "    (22): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (23): PermuteRandom()\n",
      "    (24): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (25): PermuteRandom()\n",
      "    (26): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (27): PermuteRandom()\n",
      "    (28): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (29): PermuteRandom()\n",
      "    (30): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (31): PermuteRandom()\n",
      "    (32): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (33): PermuteRandom()\n",
      "    (34): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (35): PermuteRandom()\n",
      "  )\n",
      ")\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 0: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.32batch/s, loss=-1.18]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save_1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.21batch/s, loss=-1.67]\n",
      "epoch 2: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.26batch/s, loss=-1.78]\n",
      "epoch 3: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.22batch/s, loss=-1.85]\n",
      "epoch 4: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.31batch/s, loss=-1.94]\n",
      "epoch 5: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.56batch/s, loss=-1.85]\n",
      "epoch 6: 100%|█████████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.55batch/s, loss=-2]\n",
      "epoch 7: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.62batch/s, loss=-2.08]\n",
      "epoch 8: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.38batch/s, loss=-2.01]\n",
      "epoch 9: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.20batch/s, loss=-2.08]\n",
      "epoch 10: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.28batch/s, loss=-2.17]\n",
      "epoch 11: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.21batch/s, loss=-2.22]\n",
      "epoch 12: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.45batch/s, loss=-2.16]\n",
      "epoch 13: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.33batch/s, loss=-2.25]\n",
      "epoch 14: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.24batch/s, loss=-2.2]\n",
      "epoch 15: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.22batch/s, loss=-2.31]\n",
      "epoch 16: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.38batch/s, loss=-2.26]\n",
      "epoch 17: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.68batch/s, loss=-2.22]\n",
      "epoch 18: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.15batch/s, loss=-2.2]\n",
      "epoch 19: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.60batch/s, loss=-2.15]\n",
      "epoch 20: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 16.89batch/s, loss=-2.15]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save_21\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 21: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.17batch/s, loss=-2.29]\n",
      "epoch 22: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.44batch/s, loss=-2.1]\n",
      "epoch 23: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.49batch/s, loss=-2.26]\n",
      "epoch 24: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.39batch/s, loss=-2.27]\n",
      "epoch 25: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.36batch/s, loss=-2.28]\n",
      "epoch 26: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.16batch/s, loss=-2.4]\n",
      "epoch 27: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.43batch/s, loss=-2.27]\n",
      "epoch 28: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.59batch/s, loss=-2.3]\n",
      "epoch 29: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.25batch/s, loss=-2.29]\n",
      "epoch 30: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.38batch/s, loss=-2.13]\n",
      "epoch 31: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.40batch/s, loss=-2.32]\n",
      "epoch 32: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.69batch/s, loss=-1.08]\n",
      "epoch 33: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.33batch/s, loss=-1.99]\n",
      "epoch 34: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.13batch/s, loss=-1.97]\n",
      "epoch 35: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.42batch/s, loss=-2.31]\n",
      "epoch 36: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.57batch/s, loss=-2.25]\n",
      "epoch 37: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.31batch/s, loss=-2.27]\n",
      "epoch 38: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.38batch/s, loss=-2.24]\n",
      "epoch 39: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.62batch/s, loss=-2.26]\n",
      "epoch 40: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.36batch/s, loss=-2.26]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save_41\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 41: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.06batch/s, loss=-2.24]\n",
      "epoch 42: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 16.81batch/s, loss=-2.23]\n",
      "epoch 43: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.43batch/s, loss=-1.55]\n",
      "epoch 44: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.40batch/s, loss=-2.11]\n",
      "epoch 45: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.42batch/s, loss=-2.21]\n",
      "epoch 46: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.37batch/s, loss=-2.29]\n",
      "epoch 47: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.25batch/s, loss=-2.3]\n",
      "epoch 48: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.16batch/s, loss=-2.36]\n",
      "epoch 49: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.25batch/s, loss=-2.29]\n",
      "epoch 50: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.10batch/s, loss=-2.31]\n",
      "epoch 51: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.36batch/s, loss=-2.31]\n",
      "epoch 52: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.29batch/s, loss=-2.32]\n",
      "epoch 53: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.01batch/s, loss=-2.43]\n",
      "epoch 54: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.46batch/s, loss=-1.76]\n",
      "epoch 55: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.66batch/s, loss=-2.21]\n",
      "epoch 56: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.25batch/s, loss=-2.16]\n",
      "epoch 57: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.42batch/s, loss=-2.3]\n",
      "epoch 58: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.25batch/s, loss=-2.33]\n",
      "epoch 59: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.14batch/s, loss=-2.3]\n",
      "epoch 60: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.23batch/s, loss=-2.27]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save_61\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 61: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.34batch/s, loss=-2.31]\n",
      "epoch 62: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.41batch/s, loss=-1.86]\n",
      "epoch 63: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.45batch/s, loss=-2.22]\n",
      "epoch 64: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.33batch/s, loss=-2.12]\n",
      "epoch 65: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.27batch/s, loss=-2.16]\n",
      "epoch 66: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.31batch/s, loss=-1.23]\n",
      "epoch 67: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.59batch/s, loss=-1.98]\n",
      "epoch 68: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.04batch/s, loss=-2.27]\n",
      "epoch 69: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.15batch/s, loss=-2.2]\n",
      "epoch 70: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.09batch/s, loss=-2.28]\n",
      "epoch 71: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.38batch/s, loss=-2.21]\n",
      "epoch 72: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.62batch/s, loss=-2.37]\n",
      "epoch 73: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.24batch/s, loss=-2.27]\n",
      "epoch 74: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.16batch/s, loss=-2.29]\n",
      "epoch 75: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.14batch/s, loss=-2.28]\n",
      "epoch 76: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.75batch/s, loss=-2.34]\n",
      "epoch 77: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.45batch/s, loss=-2.29]\n",
      "epoch 78: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.51batch/s, loss=-2.28]\n",
      "epoch 79: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.60batch/s, loss=-2.28]\n",
      "epoch 80: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.08batch/s, loss=-2.29]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save_81\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 81: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.49batch/s, loss=-2.35]\n",
      "epoch 82: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.52batch/s, loss=-2.36]\n",
      "epoch 83: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.32batch/s, loss=-2.38]\n",
      "epoch 84: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.27batch/s, loss=-2.36]\n",
      "epoch 85: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.04batch/s, loss=-2.29]\n",
      "epoch 86: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.53batch/s, loss=-2.34]\n",
      "epoch 87: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.30batch/s, loss=-2.29]\n",
      "epoch 88: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.38batch/s, loss=-2.3]\n",
      "epoch 89: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.25batch/s, loss=-2.37]\n",
      "epoch 90: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.37batch/s, loss=-2.35]\n",
      "epoch 91: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.39batch/s, loss=-2.31]\n",
      "epoch 92: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.36batch/s, loss=-2.33]\n",
      "epoch 93: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.22batch/s, loss=-2.37]\n",
      "epoch 94: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.61batch/s, loss=-2.39]\n",
      "epoch 95: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.28batch/s, loss=-2.43]\n",
      "epoch 96: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.24batch/s, loss=-2.32]\n",
      "epoch 97: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.38batch/s, loss=-2.31]\n",
      "epoch 98: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.34batch/s, loss=-2.37]\n",
      "epoch 99: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.21batch/s, loss=-2.3]\n",
      "epoch 100: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.47batch/s, loss=-2.43]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save_101\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 101: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.56batch/s, loss=-2.43]\n",
      "epoch 102: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 16.95batch/s, loss=-2.34]\n",
      "epoch 103: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.12batch/s, loss=-2.41]\n",
      "epoch 104: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.40batch/s, loss=-2.34]\n",
      "epoch 105: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.21batch/s, loss=-2.34]\n",
      "epoch 106: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 16.94batch/s, loss=-2.4]\n",
      "epoch 107: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.09batch/s, loss=-2.43]\n",
      "epoch 108: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.78batch/s, loss=-2.31]\n",
      "epoch 109: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.13batch/s, loss=-2.27]\n",
      "epoch 110: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.56batch/s, loss=-2.28]\n",
      "epoch 111: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.28batch/s, loss=-2.24]\n",
      "epoch 112: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.34batch/s, loss=-2.36]\n",
      "epoch 113: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 16.92batch/s, loss=-2.32]\n",
      "epoch 114: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.04batch/s, loss=-2.33]\n",
      "epoch 115: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.42batch/s, loss=-2.27]\n",
      "epoch 116: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.03batch/s, loss=-2.32]\n",
      "epoch 117: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 16.99batch/s, loss=-2.33]\n",
      "epoch 118: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 16.97batch/s, loss=-2.4]\n",
      "epoch 119: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.15batch/s, loss=-2.35]\n",
      "epoch 120: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.13batch/s, loss=-2.35]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save_121\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 121: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.11batch/s, loss=-2.36]\n",
      "epoch 122: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.39batch/s, loss=-2.45]\n",
      "epoch 123: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.26batch/s, loss=-2.34]\n",
      "epoch 124: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.13batch/s, loss=-2.36]\n",
      "epoch 125: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.34batch/s, loss=-2.38]\n",
      "epoch 126: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.34batch/s, loss=-2.35]\n",
      "epoch 127: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.29batch/s, loss=-2.33]\n",
      "epoch 128: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.23batch/s, loss=-2.36]\n",
      "epoch 129: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 16.92batch/s, loss=-2.4]\n",
      "epoch 130: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.24batch/s, loss=-2.41]\n",
      "epoch 131: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.49batch/s, loss=-2.36]\n",
      "epoch 132: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.69batch/s, loss=-2.43]\n",
      "epoch 133: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.21batch/s, loss=-2.41]\n",
      "epoch 134: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.50batch/s, loss=-2.33]\n",
      "epoch 135: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.03batch/s, loss=-2.29]\n",
      "epoch 136: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.37batch/s, loss=-2.36]\n",
      "epoch 137: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.56batch/s, loss=-2.42]\n",
      "epoch 138: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.11batch/s, loss=-2.34]\n",
      "epoch 139: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.09batch/s, loss=-2.43]\n",
      "epoch 140: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.16batch/s, loss=-2.37]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save_141\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 141: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.30batch/s, loss=-2.41]\n",
      "epoch 142: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.64batch/s, loss=-2.36]\n",
      "epoch 143: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.35batch/s, loss=-2.43]\n",
      "epoch 144: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.12batch/s, loss=-2.38]\n",
      "epoch 145: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.10batch/s, loss=-2.37]\n",
      "epoch 146: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.25batch/s, loss=-2.35]\n",
      "epoch 147: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 16.92batch/s, loss=-2.4]\n",
      "epoch 148: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.39batch/s, loss=-2.44]\n",
      "epoch 149: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.25batch/s, loss=-2.37]\n",
      "epoch 150: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.31batch/s, loss=-2.4]\n",
      "epoch 151: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 16.97batch/s, loss=-2.38]\n",
      "epoch 152: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.29batch/s, loss=-2.43]\n",
      "epoch 153: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.46batch/s, loss=-2.36]\n",
      "epoch 154: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 16.86batch/s, loss=-2.42]\n",
      "epoch 155: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 16.10batch/s, loss=-2.34]\n",
      "epoch 156: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 15.51batch/s, loss=-2.39]\n",
      "epoch 157: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 15.68batch/s, loss=-2.37]\n",
      "epoch 158: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 16.34batch/s, loss=-2.43]\n",
      "epoch 159: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.12batch/s, loss=-2.42]\n",
      "epoch 160: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.29batch/s, loss=-2.38]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save_161\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 161: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.39batch/s, loss=-2.35]\n",
      "epoch 162: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.12batch/s, loss=-2.4]\n",
      "epoch 163: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.05batch/s, loss=-2.36]\n",
      "epoch 164: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 16.46batch/s, loss=-2.37]\n",
      "epoch 165: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 15.58batch/s, loss=-2.38]\n",
      "epoch 166: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 15.53batch/s, loss=-2.35]\n",
      "epoch 167: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 15.28batch/s, loss=-2.44]\n",
      "epoch 168: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 15.37batch/s, loss=-2.37]\n",
      "epoch 169: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 15.82batch/s, loss=-2.39]\n",
      "epoch 170: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 15.92batch/s, loss=-2.34]\n",
      "epoch 171: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 15.62batch/s, loss=-2.44]\n",
      "epoch 172: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 16.03batch/s, loss=-2.36]\n",
      "epoch 173: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 16.20batch/s, loss=-2.37]\n",
      "epoch 174: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.27batch/s, loss=-2.47]\n",
      "epoch 175: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.39batch/s, loss=-2.37]\n",
      "epoch 176: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.69batch/s, loss=-2.32]\n",
      "epoch 177: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.26batch/s, loss=-2.44]\n",
      "epoch 178: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.44batch/s, loss=-2.35]\n",
      "epoch 179: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.24batch/s, loss=-2.45]\n",
      "epoch 180: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.10batch/s, loss=-2.43]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save_181\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 181: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.24batch/s, loss=-2.41]\n",
      "epoch 182: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 16.96batch/s, loss=-2.45]\n",
      "epoch 183: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.32batch/s, loss=-2.4]\n",
      "epoch 184: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.18batch/s, loss=-2.34]\n",
      "epoch 185: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.34batch/s, loss=-2.39]\n",
      "epoch 186: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.41batch/s, loss=-2.32]\n",
      "epoch 187: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.36batch/s, loss=-2.37]\n",
      "epoch 188: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.10batch/s, loss=-2.35]\n",
      "epoch 189: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.16batch/s, loss=-2.36]\n",
      "epoch 190: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.24batch/s, loss=-2.51]\n",
      "epoch 191: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.76batch/s, loss=-2.47]\n",
      "epoch 192: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.36batch/s, loss=-2.39]\n",
      "epoch 193: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.61batch/s, loss=-2.48]\n",
      "epoch 194: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.28batch/s, loss=-2.39]\n",
      "epoch 195: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.06batch/s, loss=-2.36]\n",
      "epoch 196: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.29batch/s, loss=-2.46]\n",
      "epoch 197: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.38batch/s, loss=-2.45]\n",
      "epoch 198: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.30batch/s, loss=-2.4]\n",
      "epoch 199: 100%|█████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.77batch/s, loss=-2.4]\n",
      "epoch 200: 100%|████████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.25batch/s, loss=-2.39]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save_201\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "from time import time\n",
    "import os\n",
    "from collections import OrderedDict\n",
    "import FrEIA.framework as Ff\n",
    "import FrEIA.modules as Fm\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "torch.set_num_threads(5)\n",
    "\n",
    "class Disentangle(nn.Module):\n",
    "    def __init__(self, train_loader, n_epochs, lr, lr_schedule, init_identity=True):\n",
    "        super().__init__()\n",
    "        directory_path = \"Adult\" \n",
    "        if not os.path.exists(directory_path):\n",
    "            os.makedirs(directory_path)\n",
    "            print(f\"Directory {directory_path} created.\")\n",
    "        else:\n",
    "            print(f\"Directory {directory_path} already exists.\")\n",
    "            \n",
    "        self.device = torch.device(\"cpu\")\n",
    "        self.n_epochs = n_epochs\n",
    "        self.lr = lr\n",
    "        self.lr_schedule = lr_schedule\n",
    "        self.init_identity = bool(init_identity)\n",
    "        self.net = construct_net(init_identity=init_identity)\n",
    "        self.n_classes = 6\n",
    "        self.n_dims = 32\n",
    "        self.train_loader = train_loader\n",
    "            \n",
    "    def forward(self, x, rev=False):\n",
    "        x, logdet_J  = self.net(x, rev=rev)\n",
    "        return x, logdet_J \n",
    "    \n",
    "    def train_model(self):\n",
    "        self.net = self.net.to(self.device)\n",
    "        self.net.train()\n",
    "        optimizer = torch.optim.Adam(self.parameters(), self.lr)\n",
    "        sched = torch.optim.lr_scheduler.MultiStepLR(optimizer, self.lr_schedule)\n",
    "        t0 = time()\n",
    "        for epoch in range(self.n_epochs):\n",
    "            self.epoch = epoch\n",
    "            with tqdm(self.train_loader, unit='batch') as tepoch:\n",
    "                for data, target in tepoch:\n",
    "                    if min([sum(target==i).item() for i in range(self.n_classes)]) < 2:\n",
    "                        continue\n",
    "                    optimizer.zero_grad()\n",
    "                    target = target.to(self.device)\n",
    "                    data = data + torch.randn_like(data)*1e-2\n",
    "                    data = data.to(self.device)\n",
    "                    z, logdet_J = self.net(data)         \n",
    "                    sig = torch.stack([z[target==i].std(0, unbiased=False) for i in range(self.n_classes)])\n",
    "                    loss = 0.5 + sig[target].log().mean(1) + 0.5*np.log(2*np.pi)\n",
    "                    loss -= logdet_J / self.n_dims\n",
    "                    loss = loss.mean()\n",
    "                    loss.backward(retain_graph=True)\n",
    "                    optimizer.step()\n",
    "                    tepoch.set_description(f\"epoch {epoch}\")\n",
    "                    tepoch.set_postfix(loss=loss.item())\n",
    "            sched.step()\n",
    "            if self.epoch % 20 == 0:\n",
    "                print(f'save_{str(self.epoch + 1)}')\n",
    "                torch.save(\n",
    "                    self.net.state_dict(), f\"Adult/model_{str(self.epoch + 1).zfill(2)}.pt\"\n",
    "                )\n",
    "\n",
    "\n",
    "def subnet_fc(c_in, c_out, init_identity):\n",
    "    subnet = nn.Sequential(nn.Linear(c_in, 32), nn.ReLU(),\n",
    "                            nn.Linear(32, 32), nn.ReLU(),\n",
    "                            nn.Linear(32,  c_out))\n",
    "    if init_identity:\n",
    "        subnet[-1].weight.data.fill_(0.)\n",
    "        subnet[-1].bias.data.fill_(0.)\n",
    "    return subnet\n",
    "\n",
    "\n",
    "def construct_net(init_identity=True, n_dims=32, n_blocks=18, clamp=2.0, seed=None):\n",
    "    \"\"\"Build the invertible flow: n_blocks x (GIN coupling -> random permutation).\n",
    "\n",
    "    Args:\n",
    "        init_identity: forwarded to ``subnet_fc`` (zero-initialised last layer).\n",
    "        n_dims: input/latent dimensionality (default 32; must match the data).\n",
    "        n_blocks: number of coupling + permutation pairs (default 18).\n",
    "        clamp: ``clamp`` option passed to every ``GINCouplingBlock`` (default 2.0).\n",
    "        seed: if given, permutation seeds are drawn from a private\n",
    "            ``np.random.RandomState(seed)`` so the architecture is reproducible;\n",
    "            if None they come from the global ``np.random`` state, as before.\n",
    "\n",
    "    Returns:\n",
    "        ``Ff.ReversibleGraphNet`` over the constructed nodes.\n",
    "    \"\"\"\n",
    "    rng = np.random if seed is None else np.random.RandomState(seed)\n",
    "\n",
    "    def make_subnet(c_in, c_out):\n",
    "        return subnet_fc(c_in, c_out, init_identity)\n",
    "\n",
    "    nodes = [Ff.InputNode(n_dims, name='input')]\n",
    "    for k in range(n_blocks):\n",
    "        nodes.append(Ff.Node(nodes[-1], Fm.GINCouplingBlock,\n",
    "                             {'subnet_constructor': make_subnet, 'clamp': clamp},\n",
    "                             name=f'coupling_{k}'))\n",
    "        nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom,\n",
    "                             {'seed': rng.randint(2**23)},\n",
    "                             name=f'permute_{k+1}'))\n",
    "\n",
    "    nodes.append(Ff.OutputNode(nodes[-1], name='output'))\n",
    "    return Ff.ReversibleGraphNet(nodes)\n",
    "\n",
    "\n",
    "# Train the flow on the Adult training loader; checkpoints go to Adult/.\n",
    "rep = Disentangle(\n",
    "    train_loader=training_loader,\n",
    "    n_epochs=201,\n",
    "    lr=1e-3,\n",
    "    lr_schedule=[70],  # MultiStepLR milestone: lr is cut 10x from epoch 70 on\n",
    "    init_identity=True,\n",
    ")\n",
    "rep.train_model()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:DLcourse]",
   "language": "python",
   "name": "conda-env-DLcourse-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
