{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import json\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "import pickle\n",
    "df = pd.read_csv('./adult.csv').astype('float')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.rename(columns = {'target':'y'}, inplace = True)\n",
    "df.drop(['fnlwgt', 'capitalgain'], axis=1, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>education</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capitalloss</th>\n",
       "      <th>hoursperweek</th>\n",
       "      <th>native-country</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45217</th>\n",
       "      <td>33.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45218</th>\n",
       "      <td>39.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>36.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45219</th>\n",
       "      <td>38.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>50.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45220</th>\n",
       "      <td>44.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45221</th>\n",
       "      <td>35.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>60.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>45222 rows × 12 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        age  workclass  education  marital-status  occupation  relationship  \\\n",
       "0      39.0        0.0       13.0             0.0         0.0           0.0   \n",
       "1      50.0        1.0       13.0             1.0         1.0           1.0   \n",
       "2      38.0        2.0        8.0             2.0         2.0           0.0   \n",
       "3      53.0        2.0        6.0             1.0         2.0           1.0   \n",
       "4      28.0        2.0       13.0             1.0         3.0           2.0   \n",
       "...     ...        ...        ...             ...         ...           ...   \n",
       "45217  33.0        2.0       13.0             0.0         3.0           3.0   \n",
       "45218  39.0        2.0       13.0             2.0         3.0           0.0   \n",
       "45219  38.0        2.0       13.0             1.0         3.0           1.0   \n",
       "45220  44.0        2.0       13.0             2.0         0.0           3.0   \n",
       "45221  35.0        5.0       13.0             1.0         1.0           1.0   \n",
       "\n",
       "       race  sex  capitalloss  hoursperweek  native-country    y  \n",
       "0       0.0  0.0          0.0          40.0             0.0  0.0  \n",
       "1       0.0  0.0          0.0          13.0             0.0  0.0  \n",
       "2       0.0  0.0          0.0          40.0             0.0  0.0  \n",
       "3       1.0  0.0          0.0          40.0             0.0  0.0  \n",
       "4       1.0  1.0          0.0          40.0             1.0  0.0  \n",
       "...     ...  ...          ...           ...             ...  ...  \n",
       "45217   0.0  0.0          0.0          40.0             0.0  0.0  \n",
       "45218   0.0  1.0          0.0          36.0             0.0  0.0  \n",
       "45219   0.0  0.0          0.0          50.0             0.0  0.0  \n",
       "45220   2.0  0.0          0.0          40.0             0.0  0.0  \n",
       "45221   0.0  0.0          0.0          60.0             0.0  1.0  \n",
       "\n",
       "[45222 rows x 12 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "train, test = train_test_split(df, test_size=0.2, random_state=0)\n",
    "train.to_csv('adult_data_train.csv', index=False)\n",
    "valid, test = train_test_split(test, test_size=0.5, random_state=0)\n",
    "valid.to_csv('adult_data_valid.csv', index=False)\n",
    "test.to_csv('adult_data_test.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
    "from sklearn.compose import ColumnTransformer\n",
    "from sklearn.pipeline import Pipeline\n",
    "\n",
    "categorical = ['workclass', 'marital-status', 'occupation',\t'relationship',\t'race',\t'sex', 'native-country']\n",
    "numerical = list(train.drop('y', axis=1).columns.difference(categorical))\n",
    "\n",
    "numeric_transformer = Pipeline(steps=[\n",
    "    ('scaler', StandardScaler())])\n",
    "categorical_transformer = Pipeline(steps=[\n",
    "    ('onehot', OneHotEncoder(handle_unknown='error'))])\n",
    "\n",
    "transformations = ColumnTransformer(\n",
    "    transformers=[\n",
    "        ('num', numeric_transformer, numerical),\n",
    "        ('cat', categorical_transformer, categorical)])\n",
    "\n",
    "transformations.fit(train.drop('y', axis=1))\n",
    "all_onehot_norm = transformations.transform(df.drop('y', axis=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<45222x86 sparse matrix of type '<class 'numpy.float64'>'\n",
       "\twith 497442 stored elements in Compressed Sparse Row format>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_onehot_norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "onehot_columns = []\n",
    "n_classes = []\n",
    "for col in numerical:\n",
    "    onehot_columns.append(col)\n",
    "for col in categorical:\n",
    "    n_class = len(np.unique(train[col]))\n",
    "    n_classes.append(n_class)\n",
    "    onehot_columns.extend([col + '_' + str(i) for i in range(n_class)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_all_onehot_norm = pd.DataFrame.sparse.from_spmatrix(all_onehot_norm, columns=onehot_columns)\n",
    "df_all_onehot_norm['y'] = df['y']\n",
    "df_all_onehot_norm.to_csv('adult_all_onehot_norm.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[7, 7, 14, 6, 5, 2, 41]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n_classes"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.7 ('base': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "13f752eec1c22bf877ed13574aae0c5036489a901cc325afaf3d1c548bba661a"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
