{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-05-24 14:14:43--  https://storage.googleapis.com/kaggle-forum-message-attachments/237294/7771/german_credit_data.csv\n",
      "Resolving storage.googleapis.com (storage.googleapis.com)... 172.217.1.48\n",
      "Connecting to storage.googleapis.com (storage.googleapis.com)|172.217.1.48|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 53393 (52K) [text/csv]\n",
      "Saving to: ‘german_credit_data.csv’\n",
      "\n",
      "german_credit_data. 100%[===================>]  52.14K  --.-KB/s    in 0.04s   \n",
      "\n",
      "2020-05-24 14:14:44 (1.19 MB/s) - ‘german_credit_data.csv’ saved [53393/53393]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# -nc (no-clobber): skip the download when the CSV is already present, so re-running\n",
    "# the notebook does not pile up german_credit_data.csv.1, .2, ... copies next to it.\n",
    "! wget -nc https://storage.googleapis.com/kaggle-forum-message-attachments/237294/7771/german_credit_data.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "df = pd.read_csv(\"german_credit_data.csv\", index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "df = df.fillna(value=\"NA\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Age</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Job</th>\n",
       "      <th>Housing</th>\n",
       "      <th>Saving accounts</th>\n",
       "      <th>Checking account</th>\n",
       "      <th>Credit amount</th>\n",
       "      <th>Duration</th>\n",
       "      <th>Purpose</th>\n",
       "      <th>Risk</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>67</td>\n",
       "      <td>male</td>\n",
       "      <td>2</td>\n",
       "      <td>own</td>\n",
       "      <td>NA</td>\n",
       "      <td>little</td>\n",
       "      <td>1169</td>\n",
       "      <td>6</td>\n",
       "      <td>radio/TV</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>22</td>\n",
       "      <td>female</td>\n",
       "      <td>2</td>\n",
       "      <td>own</td>\n",
       "      <td>little</td>\n",
       "      <td>moderate</td>\n",
       "      <td>5951</td>\n",
       "      <td>48</td>\n",
       "      <td>radio/TV</td>\n",
       "      <td>bad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>49</td>\n",
       "      <td>male</td>\n",
       "      <td>1</td>\n",
       "      <td>own</td>\n",
       "      <td>little</td>\n",
       "      <td>NA</td>\n",
       "      <td>2096</td>\n",
       "      <td>12</td>\n",
       "      <td>education</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>45</td>\n",
       "      <td>male</td>\n",
       "      <td>2</td>\n",
       "      <td>free</td>\n",
       "      <td>little</td>\n",
       "      <td>little</td>\n",
       "      <td>7882</td>\n",
       "      <td>42</td>\n",
       "      <td>furniture/equipment</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>53</td>\n",
       "      <td>male</td>\n",
       "      <td>2</td>\n",
       "      <td>free</td>\n",
       "      <td>little</td>\n",
       "      <td>little</td>\n",
       "      <td>4870</td>\n",
       "      <td>24</td>\n",
       "      <td>car</td>\n",
       "      <td>bad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>995</th>\n",
       "      <td>31</td>\n",
       "      <td>female</td>\n",
       "      <td>1</td>\n",
       "      <td>own</td>\n",
       "      <td>little</td>\n",
       "      <td>NA</td>\n",
       "      <td>1736</td>\n",
       "      <td>12</td>\n",
       "      <td>furniture/equipment</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>996</th>\n",
       "      <td>40</td>\n",
       "      <td>male</td>\n",
       "      <td>3</td>\n",
       "      <td>own</td>\n",
       "      <td>little</td>\n",
       "      <td>little</td>\n",
       "      <td>3857</td>\n",
       "      <td>30</td>\n",
       "      <td>car</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>997</th>\n",
       "      <td>38</td>\n",
       "      <td>male</td>\n",
       "      <td>2</td>\n",
       "      <td>own</td>\n",
       "      <td>little</td>\n",
       "      <td>NA</td>\n",
       "      <td>804</td>\n",
       "      <td>12</td>\n",
       "      <td>radio/TV</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>998</th>\n",
       "      <td>23</td>\n",
       "      <td>male</td>\n",
       "      <td>2</td>\n",
       "      <td>free</td>\n",
       "      <td>little</td>\n",
       "      <td>little</td>\n",
       "      <td>1845</td>\n",
       "      <td>45</td>\n",
       "      <td>radio/TV</td>\n",
       "      <td>bad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>999</th>\n",
       "      <td>27</td>\n",
       "      <td>male</td>\n",
       "      <td>2</td>\n",
       "      <td>own</td>\n",
       "      <td>moderate</td>\n",
       "      <td>moderate</td>\n",
       "      <td>4576</td>\n",
       "      <td>45</td>\n",
       "      <td>car</td>\n",
       "      <td>good</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1000 rows × 10 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     Age     Sex  Job Housing Saving accounts Checking account  Credit amount  \\\n",
       "0     67    male    2     own              NA           little           1169   \n",
       "1     22  female    2     own          little         moderate           5951   \n",
       "2     49    male    1     own          little               NA           2096   \n",
       "3     45    male    2    free          little           little           7882   \n",
       "4     53    male    2    free          little           little           4870   \n",
       "..   ...     ...  ...     ...             ...              ...            ...   \n",
       "995   31  female    1     own          little               NA           1736   \n",
       "996   40    male    3     own          little           little           3857   \n",
       "997   38    male    2     own          little               NA            804   \n",
       "998   23    male    2    free          little           little           1845   \n",
       "999   27    male    2     own        moderate         moderate           4576   \n",
       "\n",
       "     Duration              Purpose  Risk  \n",
       "0           6             radio/TV  good  \n",
       "1          48             radio/TV   bad  \n",
       "2          12            education  good  \n",
       "3          42  furniture/equipment  good  \n",
       "4          24                  car   bad  \n",
       "..        ...                  ...   ...  \n",
       "995        12  furniture/equipment  good  \n",
       "996        30                  car  good  \n",
       "997        12             radio/TV  good  \n",
       "998        45             radio/TV   bad  \n",
       "999        45                  car  good  \n",
       "\n",
       "[1000 rows x 10 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder\n",
    "from sklearn.compose import ColumnTransformer, make_column_transformer\n",
    "# Numeric columns are standardised; every categorical column (including the\n",
    "# 'Risk' label, listed last) is one-hot encoded into dense indicator columns.\n",
    "# NOTE(review): `sparse` was renamed `sparse_output` in scikit-learn 1.2 and later\n",
    "# removed — switch the keyword when the environment is upgraded.\n",
    "numeric_cols = ['Age', 'Credit amount', 'Duration']\n",
    "categorical_cols = ['Sex', 'Job', 'Housing', 'Saving accounts', 'Checking account', 'Purpose', 'Risk']\n",
    "preprocess = make_column_transformer(\n",
    "    (StandardScaler(), numeric_cols),\n",
    "    (OneHotEncoder(sparse=False), categorical_cols),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "mat = preprocess.fit_transform(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.92754658, -0.70047155, -0.73866754,  1.        ,  0.        ,\n",
       "        0.        ,  0.        ,  1.        ,  0.        ,  0.        ,\n",
       "        0.        ,  1.        ,  0.        ,  1.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  1.        ,\n",
       "        0.        ,  0.        ,  1.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  1.        ,\n",
       "        0.        ])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mat[10,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 'Risk' is the last column handed to the one-hot encoder, so its indicator pair\n",
    "# occupies the final two columns of `mat`. Features are everything before that\n",
    "# pair; the target is the very last column.\n",
    "# NOTE(review): presumably the last column is Risk == 'good' (categories sorted\n",
    "# alphabetically) — confirm against the fitted encoder's categories_.\n",
    "X, Y = mat[:, :-2], mat[:, -1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# X is two-dimensional: first axis = rows of the data set, second axis = encoded features.\n",
    "numX, num_feats = X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration for sampling candidate sets from the encoded data.\n",
    "datasize = 500         # number of training candidate sets to draw\n",
    "testsize = 100         # number of test candidate sets to draw\n",
    "cs_size = 25           # rows per candidate set\n",
    "split_on_doc = 0.8     # leading fraction of rows reserved for training\n",
    "ratio_relevant = 0.4   # sampling weight given to rows with Y == 1\n",
    "\n",
    "# Unnormalised per-row sampling weight: ratio_relevant where Y is 1 and\n",
    "# (1 - ratio_relevant) where Y is 0.\n",
    "ratios_col = ratio_relevant * Y + (1 - ratio_relevant) * (1 - Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling between 0 and 800.0 for train\n",
      "Sampling between 800.0 and 1000 for test\n"
     ]
    }
   ],
   "source": [
    "# Draw `datasize` training and `testsize` test candidate sets of `cs_size` rows each.\n",
    "# Rows are sampled with replacement (np.random.choice default), weighted by ratios_col;\n",
    "# training sets use only the leading split_on_doc fraction of rows, test sets the rest.\n",
    "SEED = 0  # fixed so that Restart & Run All reproduces the same candidate sets\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Single integer boundary shared by the index ranges and the weight slices. Passing the\n",
    "# float product straight to np.arange can yield one more index than int() keeps\n",
    "# weights, in which case np.random.choice rejects `p`.\n",
    "n_train = int(numX*split_on_doc)\n",
    "\n",
    "data_X = []\n",
    "data_Y = []\n",
    "test_X = []\n",
    "test_Y = []\n",
    "group_identities_train = []\n",
    "group_identities_test = []\n",
    "\n",
    "print(\"Sampling between 0 and {} for train\".format(numX*split_on_doc))\n",
    "train_rows = np.arange(0, n_train, dtype=int)\n",
    "p = ratios_col[:n_train]\n",
    "p = p / p.sum()  # normalise the weights into a probability vector\n",
    "for i in range(datasize):\n",
    "    cs_indices = np.random.choice(train_rows, size=cs_size, p=p)\n",
    "    cs_X = X[cs_indices]\n",
    "    cs_Y = Y[cs_indices]\n",
    "    data_X.append(cs_X)\n",
    "    data_Y.append(cs_Y)\n",
    "    # NOTE(review): column 4 is recorded as the group attribute here, while the pickle\n",
    "    # export at the end of the notebook uses column 3 — confirm which one is intended.\n",
    "    group_identities_train.append(cs_X[:,4])\n",
    "\n",
    "print(\"Sampling between {} and {} for test\".format(numX*split_on_doc, numX))\n",
    "test_rows = np.arange(n_train, numX, dtype=int)\n",
    "p = ratios_col[n_train:]\n",
    "p = p / p.sum()\n",
    "for i in range(testsize):\n",
    "    cs_indices = np.random.choice(test_rows, size=cs_size, p=p)\n",
    "    test_X.append(X[cs_indices])\n",
    "    test_Y.append(Y[cs_indices])\n",
    "    group_identities_test.append(X[cs_indices,4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl\n",
    "# Persist the per-candidate-set lists as (features, relevance) tuples. The context\n",
    "# managers close each file, so the pickles are flushed to disk before the cell ends.\n",
    "with open(\"../../german/their_german_train_rank.pkl\", \"wb\") as f:\n",
    "    pkl.dump((data_X, data_Y), f)\n",
    "with open(\"../../german/their_german_test_rank.pkl\", \"wb\") as f:\n",
    "    pkl.dump((test_X, test_Y), f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the training candidate sets into one feature matrix and one flat relevance\n",
    "# vector; the rows of each candidate set stay contiguous and in sampling order.\n",
    "X_train = np.concatenate(data_X, axis=0)\n",
    "relevance_train = np.concatenate(data_Y)\n",
    "# Guard against stale kernel state: the lists must match the configured sizes.\n",
    "assert X_train.shape == (cs_size*datasize, data_X[0].shape[1])\n",
    "assert relevance_train.shape == (cs_size*datasize,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the test candidate sets into one feature matrix and one flat relevance\n",
    "# vector; the rows of each candidate set stay contiguous and in sampling order.\n",
    "X_test = np.concatenate(test_X, axis=0)\n",
    "relevance_test = np.concatenate(test_Y)\n",
    "# Guard against stale kernel state; the width is checked against the test data itself.\n",
    "assert X_test.shape == (cs_size*testsize, test_X[0].shape[1])\n",
    "assert relevance_test.shape == (cs_size*testsize,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show a summary rather than echoing the whole list: rendering every sampled array\n",
    "# is slow and floods the notebook output.\n",
    "len(data_X), data_X[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export (features, relevance, column 3 of the features) for each split. The context\n",
    "# managers close each file, so the pickles are flushed to disk before the cell ends.\n",
    "# NOTE(review): column 3 is presumably the group attribute; the sampling cell records\n",
    "# column 4 as the group identity — confirm which one downstream code expects.\n",
    "with open(\"../../german/german_train_rank.pkl\", \"wb\") as f:\n",
    "    pkl.dump((X_train, relevance_train, X_train[:,3]), f)\n",
    "with open(\"../../german/german_test_rank.pkl\", \"wb\") as f:\n",
    "    pkl.dump((X_test, relevance_test, X_test[:,3]), f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
