{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-05-24 14:14:43--  https://storage.googleapis.com/kaggle-forum-message-attachments/237294/7771/german_credit_data.csv\n",
      "Resolving storage.googleapis.com (storage.googleapis.com)... 172.217.1.48\n",
      "Connecting to storage.googleapis.com (storage.googleapis.com)|172.217.1.48|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 53393 (52K) [text/csv]\n",
      "Saving to: ‘german_credit_data.csv’\n",
      "\n",
      "german_credit_data. 100%[===================>]  52.14K  --.-KB/s    in 0.04s   \n",
      "\n",
      "2020-05-24 14:14:44 (1.19 MB/s) - ‘german_credit_data.csv’ saved [53393/53393]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Download the German credit CSV into the current working directory.\n",
    "# -nc (no-clobber) skips the download when the file is already present, so\n",
    "# re-running the notebook does not pile up german_credit_data.csv.1, .2, ...\n",
    "# next to the copy that the following cell actually reads.\n",
    "! wget -nc https://storage.googleapis.com/kaggle-forum-message-attachments/237294/7771/german_credit_data.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Read the downloaded CSV; the file's first column becomes the row index.\n",
    "df = pd.read_csv(\n",
    "    \"german_credit_data.csv\",\n",
    "    index_col=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Replace every missing value with the literal string \"NA\" so that a missing\n",
    "# entry is carried forward as an ordinary value instead of NaN.\n",
    "df = df.fillna(\"NA\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Age</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Job</th>\n",
       "      <th>Housing</th>\n",
       "      <th>Saving accounts</th>\n",
       "      <th>Checking account</th>\n",
       "      <th>Credit amount</th>\n",
       "      <th>Duration</th>\n",
       "      <th>Purpose</th>\n",
       "      <th>Risk</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>67</td>\n",
       "      <td>male</td>\n",
       "      <td>2</td>\n",
       "      <td>own</td>\n",
       "      <td>NA</td>\n",
       "      <td>little</td>\n",
       "      <td>1169</td>\n",
       "      <td>6</td>\n",
       "      <td>radio/TV</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>22</td>\n",
       "      <td>female</td>\n",
       "      <td>2</td>\n",
       "      <td>own</td>\n",
       "      <td>little</td>\n",
       "      <td>moderate</td>\n",
       "      <td>5951</td>\n",
       "      <td>48</td>\n",
       "      <td>radio/TV</td>\n",
       "      <td>bad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>49</td>\n",
       "      <td>male</td>\n",
       "      <td>1</td>\n",
       "      <td>own</td>\n",
       "      <td>little</td>\n",
       "      <td>NA</td>\n",
       "      <td>2096</td>\n",
       "      <td>12</td>\n",
       "      <td>education</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>45</td>\n",
       "      <td>male</td>\n",
       "      <td>2</td>\n",
       "      <td>free</td>\n",
       "      <td>little</td>\n",
       "      <td>little</td>\n",
       "      <td>7882</td>\n",
       "      <td>42</td>\n",
       "      <td>furniture/equipment</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>53</td>\n",
       "      <td>male</td>\n",
       "      <td>2</td>\n",
       "      <td>free</td>\n",
       "      <td>little</td>\n",
       "      <td>little</td>\n",
       "      <td>4870</td>\n",
       "      <td>24</td>\n",
       "      <td>car</td>\n",
       "      <td>bad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>995</th>\n",
       "      <td>31</td>\n",
       "      <td>female</td>\n",
       "      <td>1</td>\n",
       "      <td>own</td>\n",
       "      <td>little</td>\n",
       "      <td>NA</td>\n",
       "      <td>1736</td>\n",
       "      <td>12</td>\n",
       "      <td>furniture/equipment</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>996</th>\n",
       "      <td>40</td>\n",
       "      <td>male</td>\n",
       "      <td>3</td>\n",
       "      <td>own</td>\n",
       "      <td>little</td>\n",
       "      <td>little</td>\n",
       "      <td>3857</td>\n",
       "      <td>30</td>\n",
       "      <td>car</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>997</th>\n",
       "      <td>38</td>\n",
       "      <td>male</td>\n",
       "      <td>2</td>\n",
       "      <td>own</td>\n",
       "      <td>little</td>\n",
       "      <td>NA</td>\n",
       "      <td>804</td>\n",
       "      <td>12</td>\n",
       "      <td>radio/TV</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>998</th>\n",
       "      <td>23</td>\n",
       "      <td>male</td>\n",
       "      <td>2</td>\n",
       "      <td>free</td>\n",
       "      <td>little</td>\n",
       "      <td>little</td>\n",
       "      <td>1845</td>\n",
       "      <td>45</td>\n",
       "      <td>radio/TV</td>\n",
       "      <td>bad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>999</th>\n",
       "      <td>27</td>\n",
       "      <td>male</td>\n",
       "      <td>2</td>\n",
       "      <td>own</td>\n",
       "      <td>moderate</td>\n",
       "      <td>moderate</td>\n",
       "      <td>4576</td>\n",
       "      <td>45</td>\n",
       "      <td>car</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1000 rows × 10 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     Age     Sex  Job Housing Saving accounts Checking account  Credit amount  \\\n",
       "0     67    male    2     own              NA           little           1169   \n",
       "1     22  female    2     own          little         moderate           5951   \n",
       "2     49    male    1     own          little               NA           2096   \n",
       "3     45    male    2    free          little           little           7882   \n",
       "4     53    male    2    free          little           little           4870   \n",
       "..   ...     ...  ...     ...             ...              ...            ...   \n",
       "995   31  female    1     own          little               NA           1736   \n",
       "996   40    male    3     own          little           little           3857   \n",
       "997   38    male    2     own          little               NA            804   \n",
       "998   23    male    2    free          little           little           1845   \n",
       "999   27    male    2     own        moderate         moderate           4576   \n",
       "\n",
       "     Duration              Purpose  Risk  \n",
       "0           6             radio/TV  good  \n",
       "1          48             radio/TV   bad  \n",
       "2          12            education  good  \n",
       "3          42  furniture/equipment  good  \n",
       "4          24                  car   bad  \n",
       "..        ...                  ...   ...  \n",
       "995        12  furniture/equipment  good  \n",
       "996        30                  car  good  \n",
       "997        12             radio/TV  good  \n",
       "998        45             radio/TV   bad  \n",
       "999        45                  car  good  \n",
       "\n",
       "[1000 rows x 10 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder\n",
    "from sklearn.compose import ColumnTransformer, make_column_transformer\n",
    "\n",
    "numeric_cols = ['Age', 'Credit amount', 'Duration']\n",
    "categorical_cols = ['Sex', 'Job', 'Housing', 'Saving accounts', 'Checking account', 'Purpose', 'Risk']\n",
    "\n",
    "# scikit-learn 1.2 renamed OneHotEncoder's `sparse` argument to `sparse_output`\n",
    "# and 1.4 removed the old spelling. Try the new name first and fall back to the\n",
    "# old one so the cell runs on both old and current releases.\n",
    "try:\n",
    "    one_hot = OneHotEncoder(sparse_output=False)\n",
    "except TypeError:\n",
    "    one_hot = OneHotEncoder(sparse=False)\n",
    "\n",
    "# Scaled numeric columns come first in the output, followed by the one-hot\n",
    "# blocks in the order listed above. 'Risk' is listed last on purpose: the cell\n",
    "# that builds X and Y slices the label off the end of the matrix.\n",
    "preprocess = make_column_transformer(\n",
    "    (StandardScaler(), numeric_cols),\n",
    "    (one_hot, categorical_cols),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Fit the scaler and the one-hot encoder on the whole frame and encode it in\n",
    "# one step; `mat` holds one encoded row per row of `df`.\n",
    "mat = preprocess.fit_transform(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.92754658, -0.70047155, -0.73866754,  1.        ,  0.        ,\n",
       "        0.        ,  0.        ,  1.        ,  0.        ,  0.        ,\n",
       "        0.        ,  1.        ,  0.        ,  1.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  1.        ,\n",
       "        0.        ,  0.        ,  1.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  1.        ,\n",
       "        0.        ])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Spot-check one encoded row: three scaled numeric values followed by the\n",
    "# 0/1 one-hot indicators.\n",
    "mat[10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 'Risk' is the last column handed to the one-hot encoder, so its two indicator\n",
    "# columns close the matrix: drop both from the features and keep the final one\n",
    "# as the label (presumably Risk == 'good', as categories are sorted - confirm).\n",
    "X, Y = mat[:, :-2], mat[:, -1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows (examples) and columns (encoded features) of the feature matrix.\n",
    "numX, num_feats = X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling configuration used by the candidate-set generation cell.\n",
    "datasize = 500        # number of training candidate sets to draw\n",
    "testsize = 100        # number of test candidate sets to draw\n",
    "cs_size = 25          # rows per candidate set\n",
    "split_on_doc = 0.8    # leading fraction of rows reserved for training\n",
    "ratio_relevant = 0.4  # sampling weight of a row whose label is 1\n",
    "\n",
    "# Per-row, unnormalised sampling weight: ratio_relevant where Y == 1 and\n",
    "# (1 - ratio_relevant) where Y == 0.\n",
    "ratios_col = ratio_relevant * Y + (1 - ratio_relevant) * (1 - Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling between 0 and 800.0 for train\n",
      "Sampling between 800.0 and 1000 for test\n"
     ]
    }
   ],
   "source": [
    "# Build the ranking candidate sets: every set is cs_size rows drawn with\n",
    "# replacement from one side of the train/test split, weighted by ratios_col.\n",
    "\n",
    "# Fixed seed so a fresh Restart & Run All regenerates the same candidate sets.\n",
    "np.random.seed(0)\n",
    "\n",
    "# One integer split point shared by the weight slice and the index range keeps\n",
    "# their lengths equal even when numX*split_on_doc is fractional\n",
    "# (np.random.choice requires the population and p to have the same size).\n",
    "split_idx = int(numX * split_on_doc)\n",
    "\n",
    "\n",
    "def sample_candidate_sets(lo, hi, n_sets):\n",
    "    \"\"\"Draw n_sets candidate sets of cs_size rows from rows lo .. hi-1.\n",
    "\n",
    "    Returns three parallel lists with one entry per candidate set: the feature\n",
    "    rows, the labels, and column 4 of the features (kept as the group identity).\n",
    "    \"\"\"\n",
    "    p = ratios_col[lo:hi]\n",
    "    p = p / p.sum()  # normalise the weights into a probability vector\n",
    "    row_ids = np.arange(lo, hi)\n",
    "    feats, labels, groups = [], [], []\n",
    "    for _ in range(n_sets):\n",
    "        cs_indices = np.random.choice(row_ids, size=cs_size, p=p)\n",
    "        feats.append(X[cs_indices])\n",
    "        labels.append(Y[cs_indices])\n",
    "        groups.append(X[cs_indices, 4])\n",
    "    return feats, labels, groups\n",
    "\n",
    "\n",
    "print(\"Sampling between 0 and {} for train\".format(numX*split_on_doc))\n",
    "data_X, data_Y, group_identities_train = sample_candidate_sets(0, split_idx, datasize)\n",
    "print(\"Sampling between {} and {} for test\".format(numX*split_on_doc, numX))\n",
    "test_X, test_Y, group_identities_test = sample_candidate_sets(split_idx, numX, testsize)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl\n",
    "\n",
    "# Persist each split as an (X_list, Y_list) tuple. The with-blocks close the\n",
    "# file handles, so both pickles are completely written to disk by the time the\n",
    "# cell finishes instead of depending on garbage collection to flush them.\n",
    "with open(\"../../german/their_german_train_rank.pkl\", \"wb\") as train_file:\n",
    "    pkl.dump((data_X, data_Y), train_file)\n",
    "with open(\"../../german/their_german_test_rank.pkl\", \"wb\") as test_file:\n",
    "    pkl.dump((test_X, test_Y), test_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
