{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.preprocessing import StandardScaler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "train = pd.read_csv('original_data/train.txt', sep=' ' , header=None)\n",
    "validation = pd.read_csv('original_data/vali.txt', sep=' ' , header=None)\n",
    "test = pd.read_csv('original_data/test.txt', sep=' ' , header=None)\n",
    "\n",
    "# last column is nan\n",
    "train.drop(train.columns[len(train.columns)-1], axis=1, inplace=True)\n",
    "validation.drop(validation.columns[len(validation.columns)-1], axis=1, inplace=True)\n",
    "test.drop(test.columns[len(test.columns)-1], axis=1, inplace=True)\n",
    "\n",
    "train.iloc[:, 1:] = train.iloc[:, 1:].applymap(lambda s: float(s.split(':')[1]))\n",
    "# train.to_csv('data/train_preprocessed.csv')\n",
    "\n",
    "test.iloc[:, 1:] = test.iloc[:, 1:].applymap(lambda s: float(s.split(':')[1]))\n",
    "# test.to_csv('data/test_preprocessed.csv')\n",
    "\n",
    "validation.iloc[:, 1:] = validation.iloc[:, 1:].applymap(lambda s: float(s.split(':')[1]))\n",
    "# validation.to_csv('data/validation_preprocessed.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess(df, name,SS = None, cut_off = None):\n",
    "    # change relevance to binary\n",
    "#     df[0].replace(to_replace=[1,2,3], value = 0, inplace=True)\n",
    "#     df[0].replace(to_replace=[4,5], value = 1, inplace=True)    \n",
    "    \n",
    "    # 0 = relevance, 1 = query_id\n",
    "    avoid = [0,1]\n",
    "    # these features are binary so don't standardize\n",
    "    binary_feats = [97, 98, 99, 100, 101]\n",
    "    # these are the remaining features that we will standardize\n",
    "    cont_feats = [i for i in np.arange(len(list(train))) if i not in binary_feats and i not in avoid]\n",
    "    \n",
    "    # these are the ids for queries\n",
    "\n",
    "    query_ids = set(df[1])\n",
    "    keep_idx = []\n",
    "    # for each query subsample 20 docs with at most 3 relevant\n",
    "    for query_id in query_ids:\n",
    "        # get dataframe indicies of all docs for one query\n",
    "        query_idx = df.index[df[1] == query_id].tolist()\n",
    "        #print(df.loc[query_idx][1], query_id)\n",
    "        # query must have at least 20 docs and must have something relevant in it\n",
    "        if np.sum(df.loc[query_idx][0]) > 0 and len(query_idx)>= 20 and (4 in list(df.loc[query_idx][0]) or 5 in list(df.loc[query_idx][0])):\n",
    "            chosen_idx = np.random.choice(query_idx, 20, replace = False)\n",
    "            # we are only chosing at most 3 relevant docs and there should be at least one relevant doc\n",
    "            while np.sum(df.loc[chosen_idx][0]) == 0 or (4 not in list(df.loc[chosen_idx][0]) and 5 in list(df.loc[chosen_idx][0])):# or np.sum(df.loc[chosen_idx][0]) > 3:\n",
    "                print('sampling', query_id, 'again which has', np.sum(np.sum(df.loc[query_idx][0])), 'relevant docs.')\n",
    "                chosen_idx = np.random.choice(query_idx, 20, replace = False)\n",
    "            keep_idx.extend(list(chosen_idx))\n",
    "#         else:\n",
    "#             if np.sum(np.sum(df.loc[query_idx][0])) == 0:\n",
    "#                 print('qid', query_id, 'has no relevant docs!')\n",
    "#             if len(query_idx) == 20:\n",
    "#                 print('qid', query_id, 'did not have enough docs!')\n",
    "    sampled_df = df.loc[keep_idx]\n",
    "    \n",
    "    if SS is None:\n",
    "        SS = StandardScaler().fit(sampled_df[cont_feats])\n",
    "    else:\n",
    "        print('SS supplied')\n",
    "    sampled_df[cont_feats] = SS.transform(sampled_df[cont_feats])\n",
    "    \n",
    "    # based on https://arxiv.org/pdf/1911.08054.pdf, we create binary groups based on this quality score\n",
    "    # info about the quality score found here: https://www.microsoft.com/en-us/research/project/mslr/\n",
    "    if cut_off is None:\n",
    "        cut_off = np.percentile(sampled_df[134], 40, interpolation='lower')\n",
    "    sampled_df['binary_group'] = sampled_df[134] >= cut_off\n",
    "    # we do not train on 0 = relevance, 1 = query_id, and we throw out 133 = first quality score because we use second quality score\n",
    "    valid_columns = [i for i in np.arange(138) if i not in [0,1,133]]\n",
    "    \n",
    "    X = np.array(sampled_df[valid_columns])\n",
    "    y = np.array(sampled_df[0])\n",
    "    group = np.array(sampled_df['binary_group'])\n",
    "\n",
    "    np.save('data/X_{}.npy'.format(name), X)\n",
    "    np.save('data/y_{}.npy'.format(name), y)\n",
    "    np.save('data/group_{}.npy'.format(name), group)\n",
    "    return SS, cut_off\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sampling 3082.0 again which has 9 relevant docs.\n",
      "sampling 10879.0 again which has 15 relevant docs.\n",
      "sampling 20641.0 again which has 7 relevant docs.\n",
      "sampling 20641.0 again which has 7 relevant docs.\n",
      "sampling 20641.0 again which has 7 relevant docs.\n",
      "sampling 20641.0 again which has 7 relevant docs.\n",
      "sampling 20641.0 again which has 7 relevant docs.\n",
      "sampling 20641.0 again which has 7 relevant docs.\n",
      "sampling 28102.0 again which has 11 relevant docs.\n"
     ]
    }
   ],
   "source": [
    "SS, cut_off = preprocess(train.copy(), 'train', SS = None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sampling 24643.0 again which has 7 relevant docs.\n",
      "sampling 1708.0 again which has 32 relevant docs.\n"
     ]
    }
   ],
   "source": [
    "SS, cut_off = preprocess(test.copy(), 'test', SS = None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SS, cut_off = preprocess(valid.copy(valid), 'test', SS = None)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
