{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This code is largely taken from Ashudeep Singh's Fair-PGRank repository; the original notebook is here: https://github.com/ashudeep/Fair-PGRank/blob/master/GermanCredit/German%20Credit%20Data%20Preprocessing.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-09-30 19:53:27--  https://storage.googleapis.com/kaggle-forum-message-attachments/237294/7771/german_credit_data.csv\n",
      "Resolving storage.googleapis.com (storage.googleapis.com)... 172.217.4.80, 172.217.4.48, 172.217.9.80, ...\n",
      "Connecting to storage.googleapis.com (storage.googleapis.com)|172.217.4.80|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 53393 (52K) [text/csv]\n",
      "Saving to: ‘german_credit_data.csv.1’\n",
      "\n",
      "german_credit_data. 100%[===================>]  52.14K  --.-KB/s    in 0.05s   \n",
      "\n",
      "2020-09-30 19:53:27 (1.12 MB/s) - ‘german_credit_data.csv.1’ saved [53393/53393]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# -nc (--no-clobber): skip the download when the CSV is already present, so re-running\n",
    "# this cell does not pile up german_credit_data.csv.1, .2, ... copies that the\n",
    "# read_csv call on \"german_credit_data.csv\" would never look at.\n",
    "! wget -nc https://storage.googleapis.com/kaggle-forum-message-attachments/237294/7771/german_credit_data.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "# Read the downloaded CSV; its first column is used as the row index.\n",
    "df = pd.read_csv(\n",
    "    \"german_credit_data.csv\",\n",
    "    index_col=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Replace every missing value with the literal string \"NA\".\n",
    "df = df.fillna(\"NA\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Age</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Job</th>\n",
       "      <th>Housing</th>\n",
       "      <th>Saving accounts</th>\n",
       "      <th>Checking account</th>\n",
       "      <th>Credit amount</th>\n",
       "      <th>Duration</th>\n",
       "      <th>Purpose</th>\n",
       "      <th>Risk</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>67</td>\n",
       "      <td>male</td>\n",
       "      <td>2</td>\n",
       "      <td>own</td>\n",
       "      <td>NA</td>\n",
       "      <td>little</td>\n",
       "      <td>1169</td>\n",
       "      <td>6</td>\n",
       "      <td>radio/TV</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>22</td>\n",
       "      <td>female</td>\n",
       "      <td>2</td>\n",
       "      <td>own</td>\n",
       "      <td>little</td>\n",
       "      <td>moderate</td>\n",
       "      <td>5951</td>\n",
       "      <td>48</td>\n",
       "      <td>radio/TV</td>\n",
       "      <td>bad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>49</td>\n",
       "      <td>male</td>\n",
       "      <td>1</td>\n",
       "      <td>own</td>\n",
       "      <td>little</td>\n",
       "      <td>NA</td>\n",
       "      <td>2096</td>\n",
       "      <td>12</td>\n",
       "      <td>education</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>45</td>\n",
       "      <td>male</td>\n",
       "      <td>2</td>\n",
       "      <td>free</td>\n",
       "      <td>little</td>\n",
       "      <td>little</td>\n",
       "      <td>7882</td>\n",
       "      <td>42</td>\n",
       "      <td>furniture/equipment</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>53</td>\n",
       "      <td>male</td>\n",
       "      <td>2</td>\n",
       "      <td>free</td>\n",
       "      <td>little</td>\n",
       "      <td>little</td>\n",
       "      <td>4870</td>\n",
       "      <td>24</td>\n",
       "      <td>car</td>\n",
       "      <td>bad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>995</th>\n",
       "      <td>31</td>\n",
       "      <td>female</td>\n",
       "      <td>1</td>\n",
       "      <td>own</td>\n",
       "      <td>little</td>\n",
       "      <td>NA</td>\n",
       "      <td>1736</td>\n",
       "      <td>12</td>\n",
       "      <td>furniture/equipment</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>996</th>\n",
       "      <td>40</td>\n",
       "      <td>male</td>\n",
       "      <td>3</td>\n",
       "      <td>own</td>\n",
       "      <td>little</td>\n",
       "      <td>little</td>\n",
       "      <td>3857</td>\n",
       "      <td>30</td>\n",
       "      <td>car</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>997</th>\n",
       "      <td>38</td>\n",
       "      <td>male</td>\n",
       "      <td>2</td>\n",
       "      <td>own</td>\n",
       "      <td>little</td>\n",
       "      <td>NA</td>\n",
       "      <td>804</td>\n",
       "      <td>12</td>\n",
       "      <td>radio/TV</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>998</th>\n",
       "      <td>23</td>\n",
       "      <td>male</td>\n",
       "      <td>2</td>\n",
       "      <td>free</td>\n",
       "      <td>little</td>\n",
       "      <td>little</td>\n",
       "      <td>1845</td>\n",
       "      <td>45</td>\n",
       "      <td>radio/TV</td>\n",
       "      <td>bad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>999</th>\n",
       "      <td>27</td>\n",
       "      <td>male</td>\n",
       "      <td>2</td>\n",
       "      <td>own</td>\n",
       "      <td>moderate</td>\n",
       "      <td>moderate</td>\n",
       "      <td>4576</td>\n",
       "      <td>45</td>\n",
       "      <td>car</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1000 rows × 10 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     Age     Sex  Job Housing Saving accounts Checking account  Credit amount  \\\n",
       "0     67    male    2     own              NA           little           1169   \n",
       "1     22  female    2     own          little         moderate           5951   \n",
       "2     49    male    1     own          little               NA           2096   \n",
       "3     45    male    2    free          little           little           7882   \n",
       "4     53    male    2    free          little           little           4870   \n",
       "..   ...     ...  ...     ...             ...              ...            ...   \n",
       "995   31  female    1     own          little               NA           1736   \n",
       "996   40    male    3     own          little           little           3857   \n",
       "997   38    male    2     own          little               NA            804   \n",
       "998   23    male    2    free          little           little           1845   \n",
       "999   27    male    2     own        moderate         moderate           4576   \n",
       "\n",
       "     Duration              Purpose  Risk  \n",
       "0           6             radio/TV  good  \n",
       "1          48             radio/TV   bad  \n",
       "2          12            education  good  \n",
       "3          42  furniture/equipment  good  \n",
       "4          24                  car   bad  \n",
       "..        ...                  ...   ...  \n",
       "995        12  furniture/equipment  good  \n",
       "996        30                  car  good  \n",
       "997        12             radio/TV  good  \n",
       "998        45             radio/TV   bad  \n",
       "999        45                  car  good  \n",
       "\n",
       "[1000 rows x 10 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the frame after missing values were replaced by \"NA\".\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Binary age group: 1 when Age >= 25, otherwise 0.\n",
    "# A vectorized comparison replaces the row-by-row apply and gives the same 0/1 integers.\n",
    "df['age_binary'] = (df['Age'] >= 25).astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder\n",
    "from sklearn.compose import ColumnTransformer, make_column_transformer\n",
    "\n",
    "# scikit-learn 1.2 renamed OneHotEncoder's `sparse` argument to `sparse_output` and\n",
    "# 1.4 removed the old name. Try the new spelling first and fall back for older versions.\n",
    "try:\n",
    "    one_hot = OneHotEncoder(sparse_output=False)\n",
    "except TypeError:\n",
    "    one_hot = OneHotEncoder(sparse=False)\n",
    "\n",
    "# Output column order: the three standardized numeric features first, then the one-hot\n",
    "# blocks in the order listed here, so 'Risk' and 'age_binary' end up in the last columns.\n",
    "preprocess = make_column_transformer(\n",
    "    (StandardScaler(), ['Age', 'Credit amount', 'Duration']),\n",
    "    (one_hot, ['Sex', 'Job', 'Housing', 'Saving accounts', 'Checking account', 'Purpose', 'Risk', 'age_binary'])\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Fit the scaler and the one-hot encoder on the whole frame; `mat` has one encoded row per row of df.\n",
    "mat = preprocess.fit_transform(df)\n",
    "# Shuffling is left disabled, so the rows of `mat` stay in the order of the CSV.\n",
    "#np.random.shuffle(mat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 2.76645648, -0.74513141, -1.23647786,  0.        ,  1.        ,\n",
       "        0.        ,  0.        ,  1.        ,  0.        ,  0.        ,\n",
       "        1.        ,  0.        ,  1.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  1.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  1.        ,  0.        ,  0.        ,  0.        ,\n",
       "        1.        ,  0.        ,  1.        ])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the encoded vector of the first row.\n",
    "mat[0,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The last four columns of `mat` are the one-hot blocks of 'Risk' and 'age_binary'.\n",
    "# NOTE(review): with the encoder's default sorted categories, column -3 is Risk == 'good'\n",
    "# and column -1 is age_binary == 1; this matches the first row shown above - confirm.\n",
    "age = mat[:, -1]\n",
    "Y = mat[:, -3]\n",
    "# Everything before those four columns is the feature matrix.\n",
    "X = mat[:, :-4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 2.76645648, -0.74513141, -1.23647786,  0.        ,  1.        ,\n",
       "        0.        ,  0.        ,  1.        ,  0.        ,  0.        ,\n",
       "        1.        ,  0.        ,  1.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  1.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  1.        ,  0.        ,  0.        ])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the first feature row; it is the first row of `mat` without its last four columns.\n",
    "X[0,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# numX: number of rows of X; num_feats: number of feature columns.\n",
    "numX, num_feats = X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasize = 500 # number of queries in training\n",
    "query_size = 10 # number of items per query\n",
    "split_on_doc = 0.8 # fraction of rows used to build training queries; the remaining rows build the test queries\n",
    "testsize = 100 # number of queries in testing\n",
    "ratio_relevant = .4 # total sampling probability given to relevant (Y == 1) rows, i.e. about 4 of the 10 items per query\n",
    "# Unnormalized per-row weights (ratio_relevant where Y == 1, 1 - ratio_relevant where Y == 0).\n",
    "# In this notebook they are only referenced by the commented-out alternative in the sampling cell.\n",
    "ratios_col = Y * ratio_relevant + (1-Y)*(1-ratio_relevant)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_weights(Y, ratio_relevant):\n",
    "    \"\"\"Return per-item sampling probabilities for a binary relevance array.\n",
    "\n",
    "    Items with Y == 1 share a total probability of `ratio_relevant`; all other\n",
    "    items share `1 - ratio_relevant`. For a 0/1 array the result sums to 1.\n",
    "\n",
    "    Args:\n",
    "        Y: numpy array of 0/1 relevance labels.\n",
    "        ratio_relevant: probability mass assigned to the relevant items.\n",
    "\n",
    "    Returns:\n",
    "        Array with the same shape as Y holding one probability per item.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if Y has no relevant or no non-relevant item, because the\n",
    "            per-class weight would require a division by zero.\n",
    "    \"\"\"\n",
    "    num_relevant = np.sum(Y == 1)\n",
    "    num_not_relevant = len(Y) - num_relevant\n",
    "    # Fail with a clear message instead of returning inf/NaN weights.\n",
    "    if num_relevant == 0 or num_not_relevant == 0:\n",
    "        raise ValueError(\n",
    "            \"get_weights needs at least one relevant and one non-relevant item \"\n",
    "            \"(got {} relevant, {} non-relevant)\".format(num_relevant, num_not_relevant)\n",
    "        )\n",
    "    weight_relevant = ratio_relevant / num_relevant\n",
    "    weight_not_relevant = (1 - ratio_relevant) / num_not_relevant\n",
    "    return Y * weight_relevant + (1 - Y) * weight_not_relevant"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling between 0 and 800.0 for train\n",
      "Sampling between 800.0 and 1000 for test\n"
     ]
    }
   ],
   "source": [
    "# Generate a candidate set of `query_size` items for every train and test query.\n",
    "np.random.seed(0)  # fixed seed so that re-running the notebook rebuilds the same queries\n",
    "\n",
    "# Integer boundary: rows [0, split_idx) feed the train queries, rows [split_idx, numX) the test queries.\n",
    "# Computing it once as an int avoids float bounds such as numX*(1-split_on_doc), whose rounding\n",
    "# can make the index range one element longer than the probability vector.\n",
    "split_idx = int(numX * split_on_doc)\n",
    "\n",
    "\n",
    "def _sample_queries(start, stop, num_queries):\n",
    "    \"\"\"Sample `num_queries` candidate sets from rows start..stop-1 of X and Y.\n",
    "\n",
    "    Items are drawn with replacement using the get_weights probabilities of that\n",
    "    row range. A draw without any relevant item (Y == 1) is rejected and repeated.\n",
    "\n",
    "    Returns:\n",
    "        (features, labels, group identities), each with num_queries * query_size rows.\n",
    "    \"\"\"\n",
    "    pool_size = stop - start\n",
    "    p = get_weights(Y[start:stop], ratio_relevant)\n",
    "    feats = np.zeros((num_queries * query_size, num_feats))\n",
    "    labels = np.zeros(num_queries * query_size)\n",
    "    groups = np.zeros(num_queries * query_size)\n",
    "    for i in range(num_queries):\n",
    "        while True:\n",
    "            cs_indices = start + np.random.choice(pool_size, size=query_size, p=p)\n",
    "            if np.sum(Y[cs_indices]) != 0:\n",
    "                break\n",
    "        rows = slice(i * query_size, (i + 1) * query_size)\n",
    "        feats[rows, :] = X[cs_indices, :]\n",
    "        labels[rows] = Y[cs_indices]\n",
    "        # Feature column 4 is stored as the group identity of each item.\n",
    "        groups[rows] = X[cs_indices, 4]\n",
    "    return feats, labels, groups\n",
    "\n",
    "\n",
    "print(\"Sampling between 0 and {} for train\".format(split_idx))\n",
    "data_X, data_Y, group_identities_train = _sample_queries(0, split_idx, datasize)\n",
    "\n",
    "print(\"Sampling between {} and {} for test\".format(split_idx, numX))\n",
    "test_X, test_Y, group_identities_test = _sample_queries(split_idx, numX, testsize)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl\n",
    "\n",
    "# Write each (features, labels, group identities) tuple to its own pickle file.\n",
    "# The with-blocks close and flush the files even if pickling raises.\n",
    "with open(\"german_train_rank.pkl\", \"wb\") as train_file:\n",
    "    pkl.dump((data_X, data_Y, group_identities_train), train_file)\n",
    "with open(\"german_test_rank.pkl\", \"wb\") as test_file:\n",
    "    pkl.dump((test_X, test_Y, group_identities_test), test_file)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
