{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hide\n",
    "from ipynb_path import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "from counterfactual.import_essentials import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def dict2json(dictionary: dict, file_name: str):\n",
    "    with open(file_name, \"w\") as outfile:\n",
    "        json.dump(dictionary, outfile, indent = 4)\n",
    "\n",
    "def load_json(file_name: str):\n",
    "    with open(file_name) as json_file:\n",
    "        return json.load(json_file)\n",
    "\n",
    "def update_json_file(param: dict, file_name: str):\n",
    "    if os.path.exists(file_name):\n",
    "        old_param = load_json(file_name)\n",
    "    else:\n",
    "        old_param = {}\n",
    "    # copy to old_param\n",
    "    for k in param.keys():\n",
    "        old_param[k] = param[k]\n",
    "    dict2json(old_param, file_name)\n",
    "    return old_param"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def simple_transform(df: pd.DataFrame, cat_cols: list, outcome_col: str):\n",
    "    \"\"\"construct features to [[--cont_features--], [--cat_features--], outcome_col]\"\"\"\n",
    "    cols = df.columns.tolist()\n",
    "    \n",
    "    assert outcome_col in cols\n",
    "    \n",
    "    for col in cat_cols:\n",
    "        cols.remove(col)\n",
    "    cols.remove(outcome_col)\n",
    "    cols += cat_cols\n",
    "    cols += [outcome_col]\n",
    "    return df[cols]\n",
    "\n",
    "def simple_init_params(cols: list, cat_cols: list, outcome_col: str, file_name: str):\n",
    "    # load configs of adult for no reasons\n",
    "    param = load_json(\"../counterfactual/configs/adult.json\")\n",
    "    for col in cat_cols:\n",
    "        cols.remove(col)\n",
    "    cols.remove(outcome_col)\n",
    "    # copy to cont_cols\n",
    "    cont_cols = cols\n",
    "    param['continous_cols'] = cont_cols\n",
    "    param['cat_cols'] = cat_cols\n",
    "    return update_json_file(param, file_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dummy Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def bn_func(x1, x2, x3, x4):\n",
    "    def sigmoid(x):\n",
    "        return 1 / (1 + np.exp(-x))\n",
    "    return sigmoid(10.5 * ((x1 * x2) / 8100) + 10 - np.random.normal(1, 0.1, 10000) * x3 + 1e-3 * x4)\n",
    "\n",
    "def x1_to_x3(x1):\n",
    "    return 1/3 * x1 - 5\n",
    "\n",
    "def x1x2_to_x4(x1, x2):\n",
    "    return x1 * np.log(x2 **2) / 10 - 10\n",
    "\n",
    "def bn_gen():\n",
    "    \"\"\"\n",
    "    modify code from: https://github.com/divyat09/cf-feasibility/blob/master/generativecf/scripts/simple-bn-gen.py\n",
    "    \"\"\"\n",
    "    x1 = np.random.normal(50, 15, 10000)\n",
    "    x2 = np.random.normal(35, 17, 10000)\n",
    "    x3 = x1_to_x3(x1) + np.random.normal(0, 1, 10000)\n",
    "    x4 = x1x2_to_x4(x1, x2) + np.random.normal(0, 1, 10000)\n",
    "    y= bn_func(x1, x2, x3, x4)\n",
    "\n",
    "    data = np.zeros((x1.shape[0], 5))\n",
    "    data[:, 0] = x1\n",
    "    data[:, 1] = x2\n",
    "    data[:, 2] = x3\n",
    "    data[:, 3] = x4\n",
    "    data[:, 4] = np.array(y > .5, dtype=np.int)\n",
    "\n",
    "    return pd.DataFrame(data, columns=['x1', 'x2', 'x3', 'x4', 'y'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "      <th>x4</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>46.444967</td>\n",
       "      <td>57.013846</td>\n",
       "      <td>11.928117</td>\n",
       "      <td>25.851832</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>37.128196</td>\n",
       "      <td>59.523073</td>\n",
       "      <td>7.821358</td>\n",
       "      <td>19.228805</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>36.446772</td>\n",
       "      <td>-0.975861</td>\n",
       "      <td>8.394476</td>\n",
       "      <td>-10.424044</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>50.349917</td>\n",
       "      <td>21.483236</td>\n",
       "      <td>11.082811</td>\n",
       "      <td>21.117522</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>45.720811</td>\n",
       "      <td>9.802985</td>\n",
       "      <td>9.591727</td>\n",
       "      <td>10.688967</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9995</th>\n",
       "      <td>34.491868</td>\n",
       "      <td>53.871741</td>\n",
       "      <td>6.297894</td>\n",
       "      <td>17.766362</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9996</th>\n",
       "      <td>43.381417</td>\n",
       "      <td>18.289577</td>\n",
       "      <td>10.738537</td>\n",
       "      <td>15.179331</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9997</th>\n",
       "      <td>43.219984</td>\n",
       "      <td>55.060190</td>\n",
       "      <td>9.571270</td>\n",
       "      <td>25.386525</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9998</th>\n",
       "      <td>71.956096</td>\n",
       "      <td>39.327834</td>\n",
       "      <td>20.292500</td>\n",
       "      <td>42.318122</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9999</th>\n",
       "      <td>68.903752</td>\n",
       "      <td>42.067585</td>\n",
       "      <td>17.239833</td>\n",
       "      <td>41.746232</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>10000 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "             x1         x2         x3         x4    y\n",
       "0     46.444967  57.013846  11.928117  25.851832  1.0\n",
       "1     37.128196  59.523073   7.821358  19.228805  1.0\n",
       "2     36.446772  -0.975861   8.394476 -10.424044  1.0\n",
       "3     50.349917  21.483236  11.082811  21.117522  0.0\n",
       "4     45.720811   9.802985   9.591727  10.688967  1.0\n",
       "...         ...        ...        ...        ...  ...\n",
       "9995  34.491868  53.871741   6.297894  17.766362  1.0\n",
       "9996  43.381417  18.289577  10.738537  15.179331  1.0\n",
       "9997  43.219984  55.060190   9.571270  25.386525  1.0\n",
       "9998  71.956096  39.327834  20.292500  42.318122  0.0\n",
       "9999  68.903752  42.067585  17.239833  41.746232  0.0\n",
       "\n",
       "[10000 rows x 5 columns]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = bn_gen()\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5669.0"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum(data['y'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data = bn_gen()\n",
    "data.to_csv('../data/dummy_data.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adult Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "\n",
    "def load_adult_income_dataset(path=None):\n",
    "    \"\"\"Loads adult income dataset from https://archive.ics.uci.edu/ml/datasets/Adult and prepares the data for data analysis based on https://rpubs.com/H_Zhu/235617\n",
    "    :return adult_data: returns preprocessed adult income dataset.\n",
    "\n",
    "    copy from https://github.com/interpretml/DiCE/blob/master/dice_ml/utils/helpers.py\n",
    "    \"\"\"\n",
    "    if path is None:\n",
    "        raw_data = np.genfromtxt(\n",
    "            'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data',\n",
    "            delimiter=', ',\n",
    "            dtype=str\n",
    "        )\n",
    "    else:\n",
    "        raw_data = np.genfromtxt(\n",
    "            path,\n",
    "            delimiter=', ',\n",
    "            dtype=str\n",
    "        )\n",
    "\n",
    "    #  column names from \"https://archive.ics.uci.edu/ml/datasets/Adult\"\n",
    "    column_names = ['age', 'workclass', 'fnlwgt', 'education', 'educational-num',\n",
    "                    'marital-status', 'occupation', 'relationship', 'race', 'gender',\n",
    "                    'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'income']\n",
    "\n",
    "    adult_data = pd.DataFrame(raw_data, columns=column_names)\n",
    "\n",
    "    # For more details on how the below transformations are made, please refer to https://rpubs.com/H_Zhu/235617\n",
    "    adult_data = adult_data.astype(\n",
    "        {\"age\": np.int64, \"educational-num\": np.int64, \"hours-per-week\": np.int64})\n",
    "\n",
    "    adult_data = adult_data.replace(\n",
    "        {'workclass': {'Without-pay': 'Other/Unknown', 'Never-worked': 'Other/Unknown'}})\n",
    "    adult_data = adult_data.replace({'workclass': {\n",
    "                                    'Federal-gov': 'Government', 'State-gov': 'Government', 'Local-gov': 'Government'}})\n",
    "    adult_data = adult_data.replace(\n",
    "        {'workclass': {'Self-emp-not-inc': 'Self-Employed', 'Self-emp-inc': 'Self-Employed'}})\n",
    "    adult_data = adult_data.replace(\n",
    "        {'workclass': {'Never-worked': 'Self-Employed', 'Without-pay': 'Self-Employed'}})\n",
    "    adult_data = adult_data.replace({'workclass': {'?': 'Other/Unknown'}})\n",
    "\n",
    "    adult_data = adult_data.replace({'occupation': {'Adm-clerical': 'White-Collar', 'Craft-repair': 'Blue-Collar',\n",
    "                                                    'Exec-managerial': 'White-Collar', 'Farming-fishing': 'Blue-Collar',\n",
    "                                                    'Handlers-cleaners': 'Blue-Collar',\n",
    "                                                    'Machine-op-inspct': 'Blue-Collar', 'Other-service': 'Service',\n",
    "                                                    'Priv-house-serv': 'Service',\n",
    "                                                    'Prof-specialty': 'Professional', 'Protective-serv': 'Service',\n",
    "                                                    'Tech-support': 'Service',\n",
    "                                                    'Transport-moving': 'Blue-Collar', 'Unknown': 'Other/Unknown',\n",
    "                                                    'Armed-Forces': 'Other/Unknown', '?': 'Other/Unknown'}})\n",
    "\n",
    "    adult_data = adult_data.replace({'marital-status': {'Married-civ-spouse': 'Married',\n",
    "                                                        'Married-AF-spouse': 'Married', 'Married-spouse-absent': 'Married', 'Never-married': 'Single'}})\n",
    "\n",
    "    adult_data = adult_data.replace({'race': {'Black': 'Other', 'Asian-Pac-Islander': 'Other',\n",
    "                                              'Amer-Indian-Eskimo': 'Other'}})\n",
    "\n",
    "    adult_data = adult_data[['age', 'hours-per-week', 'workclass', 'education', 'marital-status',\n",
    "                             'occupation', 'race', 'gender', 'income']]\n",
    "\n",
    "    adult_data = adult_data.replace({'income': {'<=50K': 0, '>50K': 1}})\n",
    "\n",
    "    adult_data = adult_data.replace({'education': {'Assoc-voc': 'Assoc', 'Assoc-acdm': 'Assoc',\n",
    "                                                   '11th': 'School', '10th': 'School', '7th-8th': 'School', '9th': 'School',\n",
    "                                                   '12th': 'School', '5th-6th': 'School', '1st-4th': 'School', 'Preschool': 'School'}})\n",
    "\n",
    "    adult_data = adult_data.rename(\n",
    "        columns={'marital-status': 'marital_status', 'hours-per-week': 'hours_per_week'})\n",
    "\n",
    "    return adult_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adult = load_adult_income_dataset('../data/adult.data')\n",
    "adult.to_csv('../data/adult.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s_adult = shuffle(adult)\n",
    "s_adult.to_csv('../data/s_adult.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adult['y'] = adult['income'].apply(lambda x: uniform(0.05, 0.2) if x == 0 else uniform(0.8, 0.95))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# OULAD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def load_learning_analytic_data(path='../data/oulad'):\n",
    "    def weighted_score(x):\n",
    "        d = {}\n",
    "        total_weight = sum(x['weight'])\n",
    "        d['weight'] = total_weight\n",
    "        if sum(x['weight']) == 0:\n",
    "            d['weighted_score'] = sum(x['score']) / len(x['score'])\n",
    "        else:\n",
    "            d['weighted_score'] = sum(x['score'] * x['weight']) / sum(x['weight'])\n",
    "        return pd.DataFrame(d, index=[0])\n",
    "\n",
    "    def clicks(x):\n",
    "        types = x['activity_type']\n",
    "        sum_clicks = x['sum_click']\n",
    "    #     for t, c in zip(types, sum_clicks):\n",
    "    #         x[f\"{t}_click\"] = c\n",
    "        return pd.DataFrame({f\"{t}_click\": c for t, c in zip(types, sum_clicks)}, index=[0])\n",
    "\n",
    "    print('loading pandas dataframes...')\n",
    "\n",
    "    assessment = pd.read_csv(f'{path}/assessments.csv')\n",
    "    courses = pd.read_csv(f'{path}/courses.csv')\n",
    "    student_assessment = pd.read_csv(f'{path}/studentAssessment.csv')\n",
    "    student_info = pd.read_csv(f'{path}/studentInfo.csv')\n",
    "    student_regist = pd.read_csv(f'{path}/studentRegistration.csv')\n",
    "    student_vle = pd.read_csv(f'{path}/studentVle.csv')\n",
    "    vle = pd.read_csv(f'{path}/vle.csv')\n",
    "\n",
    "    print('preprocessing assessment...')\n",
    "\n",
    "    # note: only count for submitted assessment, not weighted for unsubmitted ones\n",
    "    assessment_merged = student_assessment.merge(assessment)\n",
    "    assessment_grouped = assessment_merged.groupby(['code_module', 'code_presentation', 'id_student']).apply(weighted_score)\n",
    "    assessment_df = assessment_grouped.reset_index(None).drop(['level_3'], axis=1)\n",
    "\n",
    "    print('preprocessing vle...')\n",
    "\n",
    "    # vle\n",
    "    grouped_vle = student_vle.merge(vle).groupby(['activity_type', 'code_module', 'code_presentation', 'id_student'])\n",
    "    sumed_vle = grouped_vle.sum().drop(['id_site', 'date', 'week_from', 'week_to'], axis=1).reset_index()\n",
    "    grouped_vle = sumed_vle.groupby(['code_module', 'code_presentation', 'id_student']).apply(clicks)\n",
    "    vle_df = grouped_vle.reset_index(None).drop(['level_3'], axis=1)\n",
    "\n",
    "    student_df = student_info.merge(assessment_df, on=['code_module', 'code_presentation', 'id_student'], how='left')\\\n",
    "        .merge(vle_df, on=['code_module', 'code_presentation', 'id_student'], how='left')\n",
    "\n",
    "    return student_df[['num_of_prev_attempts', 'weight', 'weighted_score',\n",
    "                       'forumng_click', 'homepage_click', 'oucontent_click',\n",
    "                       'resource_click', 'subpage_click', 'url_click', 'dataplus_click',\n",
    "                       'glossary_click', 'oucollaborate_click', 'quiz_click',\n",
    "                       'ouelluminate_click', 'sharedsubpage_click', 'questionnaire_click',\n",
    "                       'page_click', 'externalquiz_click', 'ouwiki_click', 'dualpane_click',\n",
    "                       'folder_click', 'repeatactivity_click', 'htmlactivity_click',\n",
    "                                      'code_module', 'gender', 'region',\n",
    "                       'highest_education', 'imd_band', 'age_band','studied_credits',\n",
    "                       'disability', 'final_result']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading pandas dataframes...\n",
      "preprocessing assessment...\n",
      "preprocessing vle...\n",
      "Wall time: 1min 15s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "student_df = load_learning_analytic_data(path='../data/oulad')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "student_df = shuffle(student_df)\n",
    "student_df.to_csv('../data/s_student.csv', index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "student_df = pd.read_csv('../data/s_student.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>num_of_prev_attempts</th>\n",
       "      <th>weight</th>\n",
       "      <th>weighted_score</th>\n",
       "      <th>forumng_click</th>\n",
       "      <th>homepage_click</th>\n",
       "      <th>oucontent_click</th>\n",
       "      <th>resource_click</th>\n",
       "      <th>subpage_click</th>\n",
       "      <th>url_click</th>\n",
       "      <th>dataplus_click</th>\n",
       "      <th>...</th>\n",
       "      <th>htmlactivity_click</th>\n",
       "      <th>code_module</th>\n",
       "      <th>gender</th>\n",
       "      <th>region</th>\n",
       "      <th>highest_education</th>\n",
       "      <th>imd_band</th>\n",
       "      <th>age_band</th>\n",
       "      <th>studied_credits</th>\n",
       "      <th>disability</th>\n",
       "      <th>final_result</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>81.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>AAA</td>\n",
       "      <td>M</td>\n",
       "      <td>East Anglian Region</td>\n",
       "      <td>A Level or Equivalent</td>\n",
       "      <td>70-80%</td>\n",
       "      <td>0-35</td>\n",
       "      <td>120</td>\n",
       "      <td>N</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>82.0</td>\n",
       "      <td>80.731707</td>\n",
       "      <td>558.0</td>\n",
       "      <td>191.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>63.0</td>\n",
       "      <td>79.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>BBB</td>\n",
       "      <td>F</td>\n",
       "      <td>Yorkshire Region</td>\n",
       "      <td>HE Qualification</td>\n",
       "      <td>30-40%</td>\n",
       "      <td>35-55</td>\n",
       "      <td>60</td>\n",
       "      <td>N</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>46.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>FFF</td>\n",
       "      <td>M</td>\n",
       "      <td>South Region</td>\n",
       "      <td>A Level or Equivalent</td>\n",
       "      <td>90-100%</td>\n",
       "      <td>0-35</td>\n",
       "      <td>90</td>\n",
       "      <td>N</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>200.0</td>\n",
       "      <td>31.375000</td>\n",
       "      <td>914.0</td>\n",
       "      <td>730.0</td>\n",
       "      <td>434.0</td>\n",
       "      <td>71.0</td>\n",
       "      <td>142.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>CCC</td>\n",
       "      <td>F</td>\n",
       "      <td>North Western Region</td>\n",
       "      <td>A Level or Equivalent</td>\n",
       "      <td>10-20</td>\n",
       "      <td>0-35</td>\n",
       "      <td>90</td>\n",
       "      <td>N</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>12.5</td>\n",
       "      <td>90.000000</td>\n",
       "      <td>20.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>FFF</td>\n",
       "      <td>F</td>\n",
       "      <td>North Western Region</td>\n",
       "      <td>A Level or Equivalent</td>\n",
       "      <td>60-70%</td>\n",
       "      <td>0-35</td>\n",
       "      <td>120</td>\n",
       "      <td>N</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32588</th>\n",
       "      <td>0</td>\n",
       "      <td>100.0</td>\n",
       "      <td>65.500000</td>\n",
       "      <td>466.0</td>\n",
       "      <td>266.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>36.0</td>\n",
       "      <td>88.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>BBB</td>\n",
       "      <td>F</td>\n",
       "      <td>East Anglian Region</td>\n",
       "      <td>A Level or Equivalent</td>\n",
       "      <td>10-20</td>\n",
       "      <td>0-35</td>\n",
       "      <td>60</td>\n",
       "      <td>N</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32589</th>\n",
       "      <td>2</td>\n",
       "      <td>100.0</td>\n",
       "      <td>64.500000</td>\n",
       "      <td>41.0</td>\n",
       "      <td>57.0</td>\n",
       "      <td>319.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>52.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>FFF</td>\n",
       "      <td>M</td>\n",
       "      <td>Scotland</td>\n",
       "      <td>HE Qualification</td>\n",
       "      <td>10-20</td>\n",
       "      <td>35-55</td>\n",
       "      <td>120</td>\n",
       "      <td>N</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32590</th>\n",
       "      <td>0</td>\n",
       "      <td>70.0</td>\n",
       "      <td>58.071429</td>\n",
       "      <td>36.0</td>\n",
       "      <td>115.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>35.0</td>\n",
       "      <td>56.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>CCC</td>\n",
       "      <td>M</td>\n",
       "      <td>West Midlands Region</td>\n",
       "      <td>A Level or Equivalent</td>\n",
       "      <td>40-50%</td>\n",
       "      <td>0-35</td>\n",
       "      <td>60</td>\n",
       "      <td>N</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32591</th>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>8.0</td>\n",
       "      <td>25.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>BBB</td>\n",
       "      <td>F</td>\n",
       "      <td>South Region</td>\n",
       "      <td>Lower Than A Level</td>\n",
       "      <td>40-50%</td>\n",
       "      <td>0-35</td>\n",
       "      <td>120</td>\n",
       "      <td>Y</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32592</th>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>72.222222</td>\n",
       "      <td>6.0</td>\n",
       "      <td>83.0</td>\n",
       "      <td>102.0</td>\n",
       "      <td>43.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>GGG</td>\n",
       "      <td>F</td>\n",
       "      <td>Yorkshire Region</td>\n",
       "      <td>Lower Than A Level</td>\n",
       "      <td>20-30%</td>\n",
       "      <td>0-35</td>\n",
       "      <td>30</td>\n",
       "      <td>N</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>32593 rows × 32 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       num_of_prev_attempts  weight  weighted_score  forumng_click  \\\n",
       "0                         0     0.0        0.000000            5.0   \n",
       "1                         0    82.0       80.731707          558.0   \n",
       "2                         0     0.0        0.000000           46.0   \n",
       "3                         0   200.0       31.375000          914.0   \n",
       "4                         0    12.5       90.000000           20.0   \n",
       "...                     ...     ...             ...            ...   \n",
       "32588                     0   100.0       65.500000          466.0   \n",
       "32589                     2   100.0       64.500000           41.0   \n",
       "32590                     0    70.0       58.071429           36.0   \n",
       "32591                     1     0.0        0.000000            8.0   \n",
       "32592                     0     0.0       72.222222            6.0   \n",
       "\n",
       "       homepage_click  oucontent_click  resource_click  subpage_click  \\\n",
       "0                16.0             81.0             0.0            8.0   \n",
       "1               191.0             14.0            63.0           79.0   \n",
       "2                 7.0              3.0             0.0            3.0   \n",
       "3               730.0            434.0            71.0          142.0   \n",
       "4                 8.0             17.0             7.0           18.0   \n",
       "...               ...              ...             ...            ...   \n",
       "32588           266.0              0.0            36.0           88.0   \n",
       "32589            57.0            319.0             3.0           52.0   \n",
       "32590           115.0              0.0            35.0           56.0   \n",
       "32591            25.0              0.0             4.0            2.0   \n",
       "32592            83.0            102.0            43.0           19.0   \n",
       "\n",
       "       url_click  dataplus_click  ...  htmlactivity_click  code_module  \\\n",
       "0            2.0             0.0  ...                 0.0          AAA   \n",
       "1           14.0             0.0  ...                 0.0          BBB   \n",
       "2            0.0             0.0  ...                 0.0          FFF   \n",
       "3            8.0             0.0  ...                 0.0          CCC   \n",
       "4            1.0             0.0  ...                 0.0          FFF   \n",
       "...          ...             ...  ...                 ...          ...   \n",
       "32588       14.0             0.0  ...                 0.0          BBB   \n",
       "32589        3.0             1.0  ...                 1.0          FFF   \n",
       "32590        2.0             0.0  ...                 0.0          CCC   \n",
       "32591        2.0             0.0  ...                 0.0          BBB   \n",
       "32592        0.0             0.0  ...                 0.0          GGG   \n",
       "\n",
       "       gender                region      highest_education  imd_band  \\\n",
       "0           M   East Anglian Region  A Level or Equivalent    70-80%   \n",
       "1           F      Yorkshire Region       HE Qualification    30-40%   \n",
       "2           M          South Region  A Level or Equivalent   90-100%   \n",
       "3           F  North Western Region  A Level or Equivalent     10-20   \n",
       "4           F  North Western Region  A Level or Equivalent    60-70%   \n",
       "...       ...                   ...                    ...       ...   \n",
       "32588       F   East Anglian Region  A Level or Equivalent     10-20   \n",
       "32589       M              Scotland       HE Qualification     10-20   \n",
       "32590       M  West Midlands Region  A Level or Equivalent    40-50%   \n",
       "32591       F          South Region     Lower Than A Level    40-50%   \n",
       "32592       F      Yorkshire Region     Lower Than A Level    20-30%   \n",
       "\n",
       "       age_band  studied_credits  disability  final_result  \n",
       "0          0-35              120           N            -1  \n",
       "1         35-55               60           N             1  \n",
       "2          0-35               90           N            -1  \n",
       "3          0-35               90           N            -1  \n",
       "4          0-35              120           N            -1  \n",
       "...         ...              ...         ...           ...  \n",
       "32588      0-35               60           N             1  \n",
       "32589     35-55              120           N             1  \n",
       "32590      0-35               60           N            -1  \n",
       "32591      0-35              120           Y            -1  \n",
       "32592      0-35               30           N             1  \n",
       "\n",
       "[32593 rows x 32 columns]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "student_df['final_result'] = student_df['final_result'].apply(lambda x: -1 if x == 0 else 1)\n",
    "student_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "student_df.to_csv('../data/s_margin_student.csv', index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assessment = pd.read_csv('../data/oulad/assessments.csv')\n",
    "courses = pd.read_csv('../data/oulad/courses.csv')\n",
    "student_assessment = pd.read_csv('../data/oulad/studentAssessment.csv')\n",
    "student_info = pd.read_csv('../data/oulad/studentInfo.csv')\n",
    "student_regist = pd.read_csv('../data/oulad/studentRegistration.csv')\n",
    "student_vle = pd.read_csv('../data/oulad/studentVle.csv')\n",
    "vle = pd.read_csv('../data/oulad/vle.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id_student</th>\n",
       "      <th>is_banked</th>\n",
       "      <th>score</th>\n",
       "      <th>code_module</th>\n",
       "      <th>assessment_type</th>\n",
       "      <th>weight</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>11391</td>\n",
       "      <td>0</td>\n",
       "      <td>78.0</td>\n",
       "      <td>AAA</td>\n",
       "      <td>TMA</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>28400</td>\n",
       "      <td>0</td>\n",
       "      <td>70.0</td>\n",
       "      <td>AAA</td>\n",
       "      <td>TMA</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>31604</td>\n",
       "      <td>0</td>\n",
       "      <td>72.0</td>\n",
       "      <td>AAA</td>\n",
       "      <td>TMA</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>32885</td>\n",
       "      <td>0</td>\n",
       "      <td>69.0</td>\n",
       "      <td>AAA</td>\n",
       "      <td>TMA</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>38053</td>\n",
       "      <td>0</td>\n",
       "      <td>79.0</td>\n",
       "      <td>AAA</td>\n",
       "      <td>TMA</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>173907</th>\n",
       "      <td>527538</td>\n",
       "      <td>0</td>\n",
       "      <td>60.0</td>\n",
       "      <td>GGG</td>\n",
       "      <td>CMA</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>173908</th>\n",
       "      <td>534672</td>\n",
       "      <td>0</td>\n",
       "      <td>100.0</td>\n",
       "      <td>GGG</td>\n",
       "      <td>CMA</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>173909</th>\n",
       "      <td>546286</td>\n",
       "      <td>0</td>\n",
       "      <td>80.0</td>\n",
       "      <td>GGG</td>\n",
       "      <td>CMA</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>173910</th>\n",
       "      <td>546724</td>\n",
       "      <td>0</td>\n",
       "      <td>100.0</td>\n",
       "      <td>GGG</td>\n",
       "      <td>CMA</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>173911</th>\n",
       "      <td>558486</td>\n",
       "      <td>0</td>\n",
       "      <td>80.0</td>\n",
       "      <td>GGG</td>\n",
       "      <td>CMA</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>173912 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        id_student  is_banked  score code_module assessment_type  weight\n",
       "0            11391          0   78.0         AAA             TMA    10.0\n",
       "1            28400          0   70.0         AAA             TMA    10.0\n",
       "2            31604          0   72.0         AAA             TMA    10.0\n",
       "3            32885          0   69.0         AAA             TMA    10.0\n",
       "4            38053          0   79.0         AAA             TMA    10.0\n",
       "...            ...        ...    ...         ...             ...     ...\n",
       "173907      527538          0   60.0         GGG             CMA     0.0\n",
       "173908      534672          0  100.0         GGG             CMA     0.0\n",
       "173909      546286          0   80.0         GGG             CMA     0.0\n",
       "173910      546724          0  100.0         GGG             CMA     0.0\n",
       "173911      558486          0   80.0         GGG             CMA     0.0\n",
       "\n",
       "[173912 rows x 6 columns]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s_assess = student_assessment.merge(assessment)\n",
    "s_assess[['id_student', 'is_banked', 'score', 'code_module', 'assessment_type', 'weight']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wall time: 15.9 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "def weighted_score(x):\n",
    "    d = {}\n",
    "    total_weight = sum(x['weight'])\n",
    "    d['weight'] = total_weight\n",
    "    if sum(x['weight']) == 0:\n",
    "        d['weighted_score'] = sum(x['score']) / len(x['score'])\n",
    "    else:\n",
    "        d['weighted_score'] = sum(x['score'] * x['weight']) / sum(x['weight'])\n",
    "    return pd.DataFrame(d, index=[0])\n",
    "\n",
    "s_assess = s_assess.groupby(['code_module', 'code_presentation', 'id_student']).apply(weighted_score).reset_index(None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wall time: 54 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "def clicks(x):\n",
    "    \n",
    "    types = x['activity_type']\n",
    "    sum_clicks = x['sum_click']\n",
    "#     for t, c in zip(types, sum_clicks):\n",
    "#         x[f\"{t}_click\"] = c\n",
    "    return pd.DataFrame({f\"{t}_click\": c for t, c in zip(types, sum_clicks)}, index=[0])\n",
    "\n",
    "grouped = student_vle.merge(vle).groupby(['activity_type', 'code_module', 'code_presentation', 'id_student'])\n",
    "grouped_vle = grouped.sum().drop(['id_site', 'date', 'week_from', 'week_to'], axis=1).reset_index()\n",
    "grouped_vle = grouped_vle.groupby(['code_module', 'code_presentation', 'id_student']).apply(clicks)\n",
    "grouped_vle = grouped_vle.reset_index(None).drop(['level_3'], axis=1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>code_module</th>\n",
       "      <th>code_presentation</th>\n",
       "      <th>id_student</th>\n",
       "      <th>forumng_click</th>\n",
       "      <th>homepage_click</th>\n",
       "      <th>oucontent_click</th>\n",
       "      <th>resource_click</th>\n",
       "      <th>subpage_click</th>\n",
       "      <th>url_click</th>\n",
       "      <th>dataplus_click</th>\n",
       "      <th>...</th>\n",
       "      <th>ouelluminate_click</th>\n",
       "      <th>sharedsubpage_click</th>\n",
       "      <th>questionnaire_click</th>\n",
       "      <th>page_click</th>\n",
       "      <th>externalquiz_click</th>\n",
       "      <th>ouwiki_click</th>\n",
       "      <th>dualpane_click</th>\n",
       "      <th>folder_click</th>\n",
       "      <th>repeatactivity_click</th>\n",
       "      <th>htmlactivity_click</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>AAA</td>\n",
       "      <td>2013J</td>\n",
       "      <td>11391</td>\n",
       "      <td>193.0</td>\n",
       "      <td>138.0</td>\n",
       "      <td>553.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>32.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>AAA</td>\n",
       "      <td>2013J</td>\n",
       "      <td>28400</td>\n",
       "      <td>417.0</td>\n",
       "      <td>324.0</td>\n",
       "      <td>537.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>87.0</td>\n",
       "      <td>48.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>AAA</td>\n",
       "      <td>2013J</td>\n",
       "      <td>30268</td>\n",
       "      <td>126.0</td>\n",
       "      <td>59.0</td>\n",
       "      <td>66.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>AAA</td>\n",
       "      <td>2013J</td>\n",
       "      <td>31604</td>\n",
       "      <td>634.0</td>\n",
       "      <td>432.0</td>\n",
       "      <td>836.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>144.0</td>\n",
       "      <td>90.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>AAA</td>\n",
       "      <td>2013J</td>\n",
       "      <td>32885</td>\n",
       "      <td>194.0</td>\n",
       "      <td>204.0</td>\n",
       "      <td>494.0</td>\n",
       "      <td>45.0</td>\n",
       "      <td>79.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29223</th>\n",
       "      <td>GGG</td>\n",
       "      <td>2014J</td>\n",
       "      <td>2640965</td>\n",
       "      <td>NaN</td>\n",
       "      <td>22.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29224</th>\n",
       "      <td>GGG</td>\n",
       "      <td>2014J</td>\n",
       "      <td>2645731</td>\n",
       "      <td>65.0</td>\n",
       "      <td>167.0</td>\n",
       "      <td>348.0</td>\n",
       "      <td>109.0</td>\n",
       "      <td>47.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29225</th>\n",
       "      <td>GGG</td>\n",
       "      <td>2014J</td>\n",
       "      <td>2648187</td>\n",
       "      <td>NaN</td>\n",
       "      <td>63.0</td>\n",
       "      <td>79.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>20.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29226</th>\n",
       "      <td>GGG</td>\n",
       "      <td>2014J</td>\n",
       "      <td>2679821</td>\n",
       "      <td>118.0</td>\n",
       "      <td>65.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29227</th>\n",
       "      <td>GGG</td>\n",
       "      <td>2014J</td>\n",
       "      <td>2684003</td>\n",
       "      <td>111.0</td>\n",
       "      <td>120.0</td>\n",
       "      <td>265.0</td>\n",
       "      <td>33.0</td>\n",
       "      <td>27.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>29228 rows × 23 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      code_module code_presentation  id_student  forumng_click  \\\n",
       "0             AAA             2013J       11391          193.0   \n",
       "1             AAA             2013J       28400          417.0   \n",
       "2             AAA             2013J       30268          126.0   \n",
       "3             AAA             2013J       31604          634.0   \n",
       "4             AAA             2013J       32885          194.0   \n",
       "...           ...               ...         ...            ...   \n",
       "29223         GGG             2014J     2640965            NaN   \n",
       "29224         GGG             2014J     2645731           65.0   \n",
       "29225         GGG             2014J     2648187            NaN   \n",
       "29226         GGG             2014J     2679821          118.0   \n",
       "29227         GGG             2014J     2684003          111.0   \n",
       "\n",
       "       homepage_click  oucontent_click  resource_click  subpage_click  \\\n",
       "0               138.0            553.0            13.0           32.0   \n",
       "1               324.0            537.0            12.0           87.0   \n",
       "2                59.0             66.0             4.0           22.0   \n",
       "3               432.0            836.0            19.0          144.0   \n",
       "4               204.0            494.0            45.0           79.0   \n",
       "...               ...              ...             ...            ...   \n",
       "29223            22.0              6.0             4.0            9.0   \n",
       "29224           167.0            348.0           109.0           47.0   \n",
       "29225            63.0             79.0            19.0           20.0   \n",
       "29226            65.0             40.0             9.0           12.0   \n",
       "29227           120.0            265.0            33.0           27.0   \n",
       "\n",
       "       url_click  dataplus_click  ...  ouelluminate_click  \\\n",
       "0            5.0             NaN  ...                 NaN   \n",
       "1           48.0            10.0  ...                 NaN   \n",
       "2            4.0             NaN  ...                 NaN   \n",
       "3           90.0             2.0  ...                 NaN   \n",
       "4           14.0             NaN  ...                 NaN   \n",
       "...          ...             ...  ...                 ...   \n",
       "29223        NaN             NaN  ...                 NaN   \n",
       "29224        NaN             NaN  ...                 NaN   \n",
       "29225        NaN             NaN  ...                 NaN   \n",
       "29226        NaN             NaN  ...                 NaN   \n",
       "29227        NaN             NaN  ...                 NaN   \n",
       "\n",
       "       sharedsubpage_click  questionnaire_click  page_click  \\\n",
       "0                      NaN                  NaN         NaN   \n",
       "1                      NaN                  NaN         NaN   \n",
       "2                      NaN                  NaN         NaN   \n",
       "3                      NaN                  NaN         NaN   \n",
       "4                      NaN                  NaN         NaN   \n",
       "...                    ...                  ...         ...   \n",
       "29223                  NaN                  NaN         NaN   \n",
       "29224                  NaN                  NaN         NaN   \n",
       "29225                  NaN                  NaN         NaN   \n",
       "29226                  NaN                  NaN         NaN   \n",
       "29227                  NaN                  NaN         NaN   \n",
       "\n",
       "       externalquiz_click  ouwiki_click  dualpane_click  folder_click  \\\n",
       "0                     NaN           NaN             NaN           NaN   \n",
       "1                     NaN           NaN             NaN           NaN   \n",
       "2                     NaN           NaN             NaN           NaN   \n",
       "3                     NaN           NaN             NaN           NaN   \n",
       "4                     NaN           NaN             NaN           NaN   \n",
       "...                   ...           ...             ...           ...   \n",
       "29223                 NaN           NaN             NaN           NaN   \n",
       "29224                 NaN           NaN             NaN           NaN   \n",
       "29225                 NaN           NaN             NaN           NaN   \n",
       "29226                 NaN           NaN             NaN           NaN   \n",
       "29227                 NaN           NaN             NaN           NaN   \n",
       "\n",
       "       repeatactivity_click  htmlactivity_click  \n",
       "0                       NaN                 NaN  \n",
       "1                       NaN                 NaN  \n",
       "2                       NaN                 NaN  \n",
       "3                       NaN                 NaN  \n",
       "4                       NaN                 NaN  \n",
       "...                     ...                 ...  \n",
       "29223                   NaN                 NaN  \n",
       "29224                   NaN                 NaN  \n",
       "29225                   NaN                 NaN  \n",
       "29226                   NaN                 NaN  \n",
       "29227                   NaN                 NaN  \n",
       "\n",
       "[29228 rows x 23 columns]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grouped_vle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>code_module</th>\n",
       "      <th>code_presentation</th>\n",
       "      <th>id_student</th>\n",
       "      <th>gender</th>\n",
       "      <th>region</th>\n",
       "      <th>highest_education</th>\n",
       "      <th>imd_band</th>\n",
       "      <th>age_band</th>\n",
       "      <th>num_of_prev_attempts</th>\n",
       "      <th>studied_credits</th>\n",
       "      <th>...</th>\n",
       "      <th>page_click</th>\n",
       "      <th>externalquiz_click</th>\n",
       "      <th>ouwiki_click</th>\n",
       "      <th>dualpane_click</th>\n",
       "      <th>folder_click</th>\n",
       "      <th>repeatactivity_click</th>\n",
       "      <th>htmlactivity_click</th>\n",
       "      <th>level_3</th>\n",
       "      <th>weight</th>\n",
       "      <th>weighted_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>AAA</td>\n",
       "      <td>2013J</td>\n",
       "      <td>11391</td>\n",
       "      <td>M</td>\n",
       "      <td>East Anglian Region</td>\n",
       "      <td>HE Qualification</td>\n",
       "      <td>90-100%</td>\n",
       "      <td>55&lt;=</td>\n",
       "      <td>0.0</td>\n",
       "      <td>240.0</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>100.0</td>\n",
       "      <td>82.400000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>AAA</td>\n",
       "      <td>2013J</td>\n",
       "      <td>28400</td>\n",
       "      <td>F</td>\n",
       "      <td>Scotland</td>\n",
       "      <td>HE Qualification</td>\n",
       "      <td>20-30%</td>\n",
       "      <td>35-55</td>\n",
       "      <td>0.0</td>\n",
       "      <td>60.0</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>100.0</td>\n",
       "      <td>65.400000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>AAA</td>\n",
       "      <td>2013J</td>\n",
       "      <td>31604</td>\n",
       "      <td>F</td>\n",
       "      <td>South East Region</td>\n",
       "      <td>A Level or Equivalent</td>\n",
       "      <td>50-60%</td>\n",
       "      <td>35-55</td>\n",
       "      <td>0.0</td>\n",
       "      <td>60.0</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>100.0</td>\n",
       "      <td>76.300000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>AAA</td>\n",
       "      <td>2013J</td>\n",
       "      <td>32885</td>\n",
       "      <td>F</td>\n",
       "      <td>West Midlands Region</td>\n",
       "      <td>Lower Than A Level</td>\n",
       "      <td>50-60%</td>\n",
       "      <td>0-35</td>\n",
       "      <td>0.0</td>\n",
       "      <td>60.0</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>100.0</td>\n",
       "      <td>55.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>AAA</td>\n",
       "      <td>2013J</td>\n",
       "      <td>38053</td>\n",
       "      <td>M</td>\n",
       "      <td>Wales</td>\n",
       "      <td>A Level or Equivalent</td>\n",
       "      <td>80-90%</td>\n",
       "      <td>35-55</td>\n",
       "      <td>0.0</td>\n",
       "      <td>60.0</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>100.0</td>\n",
       "      <td>66.900000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25838</th>\n",
       "      <td>GGG</td>\n",
       "      <td>2014J</td>\n",
       "      <td>2620947</td>\n",
       "      <td>F</td>\n",
       "      <td>Scotland</td>\n",
       "      <td>A Level or Equivalent</td>\n",
       "      <td>80-90%</td>\n",
       "      <td>0-35</td>\n",
       "      <td>0.0</td>\n",
       "      <td>30.0</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>88.888889</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25839</th>\n",
       "      <td>GGG</td>\n",
       "      <td>2014J</td>\n",
       "      <td>2645731</td>\n",
       "      <td>F</td>\n",
       "      <td>East Anglian Region</td>\n",
       "      <td>Lower Than A Level</td>\n",
       "      <td>40-50%</td>\n",
       "      <td>35-55</td>\n",
       "      <td>0.0</td>\n",
       "      <td>30.0</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>88.111111</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25840</th>\n",
       "      <td>GGG</td>\n",
       "      <td>2014J</td>\n",
       "      <td>2648187</td>\n",
       "      <td>F</td>\n",
       "      <td>South Region</td>\n",
       "      <td>A Level or Equivalent</td>\n",
       "      <td>20-30%</td>\n",
       "      <td>0-35</td>\n",
       "      <td>0.0</td>\n",
       "      <td>30.0</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>76.666667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25841</th>\n",
       "      <td>GGG</td>\n",
       "      <td>2014J</td>\n",
       "      <td>2679821</td>\n",
       "      <td>F</td>\n",
       "      <td>South East Region</td>\n",
       "      <td>Lower Than A Level</td>\n",
       "      <td>90-100%</td>\n",
       "      <td>35-55</td>\n",
       "      <td>0.0</td>\n",
       "      <td>30.0</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>91.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25842</th>\n",
       "      <td>GGG</td>\n",
       "      <td>2014J</td>\n",
       "      <td>2684003</td>\n",
       "      <td>F</td>\n",
       "      <td>Yorkshire Region</td>\n",
       "      <td>HE Qualification</td>\n",
       "      <td>50-60%</td>\n",
       "      <td>35-55</td>\n",
       "      <td>0.0</td>\n",
       "      <td>30.0</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>82.857143</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>25843 rows × 35 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      code_module code_presentation  id_student gender                region  \\\n",
       "0             AAA             2013J       11391      M   East Anglian Region   \n",
       "1             AAA             2013J       28400      F              Scotland   \n",
       "2             AAA             2013J       31604      F     South East Region   \n",
       "3             AAA             2013J       32885      F  West Midlands Region   \n",
       "4             AAA             2013J       38053      M                 Wales   \n",
       "...           ...               ...         ...    ...                   ...   \n",
       "25838         GGG             2014J     2620947      F              Scotland   \n",
       "25839         GGG             2014J     2645731      F   East Anglian Region   \n",
       "25840         GGG             2014J     2648187      F          South Region   \n",
       "25841         GGG             2014J     2679821      F     South East Region   \n",
       "25842         GGG             2014J     2684003      F      Yorkshire Region   \n",
       "\n",
       "           highest_education imd_band age_band  num_of_prev_attempts  \\\n",
       "0           HE Qualification  90-100%     55<=                   0.0   \n",
       "1           HE Qualification   20-30%    35-55                   0.0   \n",
       "2      A Level or Equivalent   50-60%    35-55                   0.0   \n",
       "3         Lower Than A Level   50-60%     0-35                   0.0   \n",
       "4      A Level or Equivalent   80-90%    35-55                   0.0   \n",
       "...                      ...      ...      ...                   ...   \n",
       "25838  A Level or Equivalent   80-90%     0-35                   0.0   \n",
       "25839     Lower Than A Level   40-50%    35-55                   0.0   \n",
       "25840  A Level or Equivalent   20-30%     0-35                   0.0   \n",
       "25841     Lower Than A Level  90-100%    35-55                   0.0   \n",
       "25842       HE Qualification   50-60%    35-55                   0.0   \n",
       "\n",
       "       studied_credits  ... page_click externalquiz_click  ouwiki_click  \\\n",
       "0                240.0  ...        NaN                NaN           NaN   \n",
       "1                 60.0  ...        NaN                NaN           NaN   \n",
       "2                 60.0  ...        NaN                NaN           NaN   \n",
       "3                 60.0  ...        NaN                NaN           NaN   \n",
       "4                 60.0  ...        NaN                NaN           NaN   \n",
       "...                ...  ...        ...                ...           ...   \n",
       "25838             30.0  ...        NaN                NaN           NaN   \n",
       "25839             30.0  ...        NaN                NaN           NaN   \n",
       "25840             30.0  ...        NaN                NaN           NaN   \n",
       "25841             30.0  ...        NaN                NaN           NaN   \n",
       "25842             30.0  ...        NaN                NaN           NaN   \n",
       "\n",
       "       dualpane_click  folder_click  repeatactivity_click  htmlactivity_click  \\\n",
       "0                 NaN           NaN                   NaN                 NaN   \n",
       "1                 NaN           NaN                   NaN                 NaN   \n",
       "2                 NaN           NaN                   NaN                 NaN   \n",
       "3                 NaN           NaN                   NaN                 NaN   \n",
       "4                 NaN           NaN                   NaN                 NaN   \n",
       "...               ...           ...                   ...                 ...   \n",
       "25838             NaN           NaN                   NaN                 NaN   \n",
       "25839             NaN           NaN                   NaN                 NaN   \n",
       "25840             NaN           NaN                   NaN                 NaN   \n",
       "25841             NaN           NaN                   NaN                 NaN   \n",
       "25842             NaN           NaN                   NaN                 NaN   \n",
       "\n",
       "       level_3  weight  weighted_score  \n",
       "0            0   100.0       82.400000  \n",
       "1            0   100.0       65.400000  \n",
       "2            0   100.0       76.300000  \n",
       "3            0   100.0       55.000000  \n",
       "4            0   100.0       66.900000  \n",
       "...        ...     ...             ...  \n",
       "25838        0     0.0       88.888889  \n",
       "25839        0     0.0       88.111111  \n",
       "25840        0     0.0       76.666667  \n",
       "25841        0     0.0       91.500000  \n",
       "25842        0     0.0       82.857143  \n",
       "\n",
       "[25843 rows x 35 columns]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "student_info.merge(grouped_vle, on=['code_module', 'code_presentation', 'id_student'], how='right').merge(s_assess, on=['code_module', 'code_presentation', 'id_student'], how='right')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# HELOC\n",
    "\n",
    "home equity line of credit\n",
    "https://community.fico.com/s/explainable-machine-learning-challenge"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>RiskPerformance</th>\n",
       "      <th>ExternalRiskEstimate</th>\n",
       "      <th>MSinceOldestTradeOpen</th>\n",
       "      <th>MSinceMostRecentTradeOpen</th>\n",
       "      <th>AverageMInFile</th>\n",
       "      <th>NumSatisfactoryTrades</th>\n",
       "      <th>NumTrades60Ever2DerogPubRec</th>\n",
       "      <th>NumTrades90Ever2DerogPubRec</th>\n",
       "      <th>PercentTradesNeverDelq</th>\n",
       "      <th>MSinceMostRecentDelq</th>\n",
       "      <th>...</th>\n",
       "      <th>PercentInstallTrades</th>\n",
       "      <th>MSinceMostRecentInqexcl7days</th>\n",
       "      <th>NumInqLast6M</th>\n",
       "      <th>NumInqLast6Mexcl7days</th>\n",
       "      <th>NetFractionRevolvingBurden</th>\n",
       "      <th>NetFractionInstallBurden</th>\n",
       "      <th>NumRevolvingTradesWBalance</th>\n",
       "      <th>NumInstallTradesWBalance</th>\n",
       "      <th>NumBank2NatlTradesWHighUtilization</th>\n",
       "      <th>PercentTradesWBalance</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Bad</td>\n",
       "      <td>55</td>\n",
       "      <td>144</td>\n",
       "      <td>4</td>\n",
       "      <td>84</td>\n",
       "      <td>20</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>83</td>\n",
       "      <td>2</td>\n",
       "      <td>...</td>\n",
       "      <td>43</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>33</td>\n",
       "      <td>-8</td>\n",
       "      <td>8</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>69</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Bad</td>\n",
       "      <td>61</td>\n",
       "      <td>58</td>\n",
       "      <td>15</td>\n",
       "      <td>41</td>\n",
       "      <td>2</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "      <td>100</td>\n",
       "      <td>-7</td>\n",
       "      <td>...</td>\n",
       "      <td>67</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>-8</td>\n",
       "      <td>0</td>\n",
       "      <td>-8</td>\n",
       "      <td>-8</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Bad</td>\n",
       "      <td>67</td>\n",
       "      <td>66</td>\n",
       "      <td>5</td>\n",
       "      <td>24</td>\n",
       "      <td>9</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>100</td>\n",
       "      <td>-7</td>\n",
       "      <td>...</td>\n",
       "      <td>44</td>\n",
       "      <td>0</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "      <td>53</td>\n",
       "      <td>66</td>\n",
       "      <td>4</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>86</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Bad</td>\n",
       "      <td>66</td>\n",
       "      <td>169</td>\n",
       "      <td>1</td>\n",
       "      <td>73</td>\n",
       "      <td>28</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>93</td>\n",
       "      <td>76</td>\n",
       "      <td>...</td>\n",
       "      <td>57</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>4</td>\n",
       "      <td>72</td>\n",
       "      <td>83</td>\n",
       "      <td>6</td>\n",
       "      <td>4</td>\n",
       "      <td>3</td>\n",
       "      <td>91</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Bad</td>\n",
       "      <td>81</td>\n",
       "      <td>333</td>\n",
       "      <td>27</td>\n",
       "      <td>132</td>\n",
       "      <td>12</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>100</td>\n",
       "      <td>-7</td>\n",
       "      <td>...</td>\n",
       "      <td>25</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>51</td>\n",
       "      <td>89</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>80</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10454</th>\n",
       "      <td>Good</td>\n",
       "      <td>73</td>\n",
       "      <td>131</td>\n",
       "      <td>5</td>\n",
       "      <td>57</td>\n",
       "      <td>21</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>95</td>\n",
       "      <td>80</td>\n",
       "      <td>...</td>\n",
       "      <td>19</td>\n",
       "      <td>7</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>26</td>\n",
       "      <td>-8</td>\n",
       "      <td>5</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10455</th>\n",
       "      <td>Bad</td>\n",
       "      <td>65</td>\n",
       "      <td>147</td>\n",
       "      <td>39</td>\n",
       "      <td>68</td>\n",
       "      <td>11</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>92</td>\n",
       "      <td>28</td>\n",
       "      <td>...</td>\n",
       "      <td>42</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>86</td>\n",
       "      <td>53</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>80</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10456</th>\n",
       "      <td>Bad</td>\n",
       "      <td>74</td>\n",
       "      <td>129</td>\n",
       "      <td>6</td>\n",
       "      <td>64</td>\n",
       "      <td>18</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>100</td>\n",
       "      <td>-7</td>\n",
       "      <td>...</td>\n",
       "      <td>33</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "      <td>6</td>\n",
       "      <td>-8</td>\n",
       "      <td>5</td>\n",
       "      <td>-8</td>\n",
       "      <td>0</td>\n",
       "      <td>56</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10457</th>\n",
       "      <td>Bad</td>\n",
       "      <td>72</td>\n",
       "      <td>234</td>\n",
       "      <td>12</td>\n",
       "      <td>113</td>\n",
       "      <td>42</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>96</td>\n",
       "      <td>35</td>\n",
       "      <td>...</td>\n",
       "      <td>20</td>\n",
       "      <td>6</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>19</td>\n",
       "      <td>-8</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10458</th>\n",
       "      <td>Bad</td>\n",
       "      <td>66</td>\n",
       "      <td>28</td>\n",
       "      <td>1</td>\n",
       "      <td>17</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>100</td>\n",
       "      <td>-7</td>\n",
       "      <td>...</td>\n",
       "      <td>60</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>67</td>\n",
       "      <td>-8</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>10459 rows × 24 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      RiskPerformance  ExternalRiskEstimate  MSinceOldestTradeOpen  \\\n",
       "0                 Bad                    55                    144   \n",
       "1                 Bad                    61                     58   \n",
       "2                 Bad                    67                     66   \n",
       "3                 Bad                    66                    169   \n",
       "4                 Bad                    81                    333   \n",
       "...               ...                   ...                    ...   \n",
       "10454            Good                    73                    131   \n",
       "10455             Bad                    65                    147   \n",
       "10456             Bad                    74                    129   \n",
       "10457             Bad                    72                    234   \n",
       "10458             Bad                    66                     28   \n",
       "\n",
       "       MSinceMostRecentTradeOpen  AverageMInFile  NumSatisfactoryTrades  \\\n",
       "0                              4              84                     20   \n",
       "1                             15              41                      2   \n",
       "2                              5              24                      9   \n",
       "3                              1              73                     28   \n",
       "4                             27             132                     12   \n",
       "...                          ...             ...                    ...   \n",
       "10454                          5              57                     21   \n",
       "10455                         39              68                     11   \n",
       "10456                          6              64                     18   \n",
       "10457                         12             113                     42   \n",
       "10458                          1              17                      4   \n",
       "\n",
       "       NumTrades60Ever2DerogPubRec  NumTrades90Ever2DerogPubRec  \\\n",
       "0                                3                            0   \n",
       "1                                4                            4   \n",
       "2                                0                            0   \n",
       "3                                1                            1   \n",
       "4                                0                            0   \n",
       "...                            ...                          ...   \n",
       "10454                            0                            0   \n",
       "10455                            0                            0   \n",
       "10456                            1                            1   \n",
       "10457                            2                            2   \n",
       "10458                            0                            0   \n",
       "\n",
       "       PercentTradesNeverDelq  MSinceMostRecentDelq  ...  \\\n",
       "0                          83                     2  ...   \n",
       "1                         100                    -7  ...   \n",
       "2                         100                    -7  ...   \n",
       "3                          93                    76  ...   \n",
       "4                         100                    -7  ...   \n",
       "...                       ...                   ...  ...   \n",
       "10454                      95                    80  ...   \n",
       "10455                      92                    28  ...   \n",
       "10456                     100                    -7  ...   \n",
       "10457                      96                    35  ...   \n",
       "10458                     100                    -7  ...   \n",
       "\n",
       "       PercentInstallTrades  MSinceMostRecentInqexcl7days  NumInqLast6M  \\\n",
       "0                        43                             0             0   \n",
       "1                        67                             0             0   \n",
       "2                        44                             0             4   \n",
       "3                        57                             0             5   \n",
       "4                        25                             0             1   \n",
       "...                     ...                           ...           ...   \n",
       "10454                    19                             7             0   \n",
       "10455                    42                             1             1   \n",
       "10456                    33                             3             4   \n",
       "10457                    20                             6             0   \n",
       "10458                    60                             3             3   \n",
       "\n",
       "       NumInqLast6Mexcl7days  NetFractionRevolvingBurden  \\\n",
       "0                          0                          33   \n",
       "1                          0                           0   \n",
       "2                          4                          53   \n",
       "3                          4                          72   \n",
       "4                          1                          51   \n",
       "...                      ...                         ...   \n",
       "10454                      0                          26   \n",
       "10455                      1                          86   \n",
       "10456                      4                           6   \n",
       "10457                      0                          19   \n",
       "10458                      2                          67   \n",
       "\n",
       "       NetFractionInstallBurden  NumRevolvingTradesWBalance  \\\n",
       "0                            -8                           8   \n",
       "1                            -8                           0   \n",
       "2                            66                           4   \n",
       "3                            83                           6   \n",
       "4                            89                           3   \n",
       "...                         ...                         ...   \n",
       "10454                        -8                           5   \n",
       "10455                        53                           2   \n",
       "10456                        -8                           5   \n",
       "10457                        -8                           4   \n",
       "10458                        -8                           2   \n",
       "\n",
       "       NumInstallTradesWBalance  NumBank2NatlTradesWHighUtilization  \\\n",
       "0                             1                                   1   \n",
       "1                            -8                                  -8   \n",
       "2                             2                                   1   \n",
       "3                             4                                   3   \n",
       "4                             1                                   0   \n",
       "...                         ...                                 ...   \n",
       "10454                         2                                   0   \n",
       "10455                         2                                   1   \n",
       "10456                        -8                                   0   \n",
       "10457                         1                                   0   \n",
       "10458                         1                                   0   \n",
       "\n",
       "       PercentTradesWBalance  \n",
       "0                         69  \n",
       "1                          0  \n",
       "2                         86  \n",
       "3                         91  \n",
       "4                         80  \n",
       "...                      ...  \n",
       "10454                    100  \n",
       "10455                     80  \n",
       "10456                     56  \n",
       "10457                     38  \n",
       "10458                    100  \n",
       "\n",
       "[10459 rows x 24 columns]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv(\"../data/heloc_dataset.csv\")\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def map_MaxDelq2PublicRecLast12M(x):\n",
    "    if x == 0:\n",
    "        return \"Derogatory Comment\"\n",
    "    elif x == 1:\n",
    "        return \"120+ Days Delinquent\"\n",
    "    elif x == 2:\n",
    "        return \"90 Days Delinquent\"\n",
    "    elif x == 3:\n",
    "        return \"60 Days Delinquent\"\n",
    "    elif x == 4:\n",
    "        return \"30 Days Delinquent\"\n",
    "    elif x == 7:\n",
    "        return \"Never Delinquent\"\n",
    "    else:\n",
    "        return None\n",
    "\n",
    "def map_MaxDelqEver(x):\n",
    "    if x == 2:\n",
    "        return \"Derogatory Comment\"\n",
    "    elif x == 3:\n",
    "        return \"120+ Days Delinquent\"\n",
    "    elif x == 4:\n",
    "        return \"90 Days Delinquent\"\n",
    "    elif x == 5:\n",
    "        return \"60 Days Delinquent\"\n",
    "    elif x == 6:\n",
    "        return \"30 Days Delinquent\"\n",
    "    elif x == 8:\n",
    "        return \"Never Delinquent\"\n",
    "    else:\n",
    "        return None\n",
    "        \n",
    "df['MaxDelq2PublicRecLast12M'] = df['MaxDelq2PublicRecLast12M'].apply(map_MaxDelq2PublicRecLast12M)\n",
    "df['MaxDelqEver'] = df['MaxDelqEver'].apply(map_MaxDelqEver)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ExternalRiskEstimate</th>\n",
       "      <th>MSinceOldestTradeOpen</th>\n",
       "      <th>MSinceMostRecentTradeOpen</th>\n",
       "      <th>AverageMInFile</th>\n",
       "      <th>NumSatisfactoryTrades</th>\n",
       "      <th>NumTrades60Ever2DerogPubRec</th>\n",
       "      <th>NumTrades90Ever2DerogPubRec</th>\n",
       "      <th>PercentTradesNeverDelq</th>\n",
       "      <th>MSinceMostRecentDelq</th>\n",
       "      <th>NumTotalTrades</th>\n",
       "      <th>...</th>\n",
       "      <th>NumInqLast6Mexcl7days</th>\n",
       "      <th>NetFractionRevolvingBurden</th>\n",
       "      <th>NetFractionInstallBurden</th>\n",
       "      <th>NumRevolvingTradesWBalance</th>\n",
       "      <th>NumInstallTradesWBalance</th>\n",
       "      <th>NumBank2NatlTradesWHighUtilization</th>\n",
       "      <th>PercentTradesWBalance</th>\n",
       "      <th>MaxDelq2PublicRecLast12M</th>\n",
       "      <th>MaxDelqEver</th>\n",
       "      <th>RiskPerformance</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>55</td>\n",
       "      <td>144</td>\n",
       "      <td>4</td>\n",
       "      <td>84</td>\n",
       "      <td>20</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>83</td>\n",
       "      <td>2</td>\n",
       "      <td>23</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>33</td>\n",
       "      <td>-8</td>\n",
       "      <td>8</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>69</td>\n",
       "      <td>60 Days Delinquent</td>\n",
       "      <td>60 Days Delinquent</td>\n",
       "      <td>Bad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>61</td>\n",
       "      <td>58</td>\n",
       "      <td>15</td>\n",
       "      <td>41</td>\n",
       "      <td>2</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "      <td>100</td>\n",
       "      <td>-7</td>\n",
       "      <td>7</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>-8</td>\n",
       "      <td>0</td>\n",
       "      <td>-8</td>\n",
       "      <td>-8</td>\n",
       "      <td>0</td>\n",
       "      <td>Derogatory Comment</td>\n",
       "      <td>Never Delinquent</td>\n",
       "      <td>Bad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>67</td>\n",
       "      <td>66</td>\n",
       "      <td>5</td>\n",
       "      <td>24</td>\n",
       "      <td>9</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>100</td>\n",
       "      <td>-7</td>\n",
       "      <td>9</td>\n",
       "      <td>...</td>\n",
       "      <td>4</td>\n",
       "      <td>53</td>\n",
       "      <td>66</td>\n",
       "      <td>4</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>86</td>\n",
       "      <td>Never Delinquent</td>\n",
       "      <td>Never Delinquent</td>\n",
       "      <td>Bad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>66</td>\n",
       "      <td>169</td>\n",
       "      <td>1</td>\n",
       "      <td>73</td>\n",
       "      <td>28</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>93</td>\n",
       "      <td>76</td>\n",
       "      <td>30</td>\n",
       "      <td>...</td>\n",
       "      <td>4</td>\n",
       "      <td>72</td>\n",
       "      <td>83</td>\n",
       "      <td>6</td>\n",
       "      <td>4</td>\n",
       "      <td>3</td>\n",
       "      <td>91</td>\n",
       "      <td>None</td>\n",
       "      <td>30 Days Delinquent</td>\n",
       "      <td>Bad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>81</td>\n",
       "      <td>333</td>\n",
       "      <td>27</td>\n",
       "      <td>132</td>\n",
       "      <td>12</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>100</td>\n",
       "      <td>-7</td>\n",
       "      <td>12</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>51</td>\n",
       "      <td>89</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>80</td>\n",
       "      <td>Never Delinquent</td>\n",
       "      <td>Never Delinquent</td>\n",
       "      <td>Bad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10454</th>\n",
       "      <td>73</td>\n",
       "      <td>131</td>\n",
       "      <td>5</td>\n",
       "      <td>57</td>\n",
       "      <td>21</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>95</td>\n",
       "      <td>80</td>\n",
       "      <td>21</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>26</td>\n",
       "      <td>-8</td>\n",
       "      <td>5</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>100</td>\n",
       "      <td>None</td>\n",
       "      <td>30 Days Delinquent</td>\n",
       "      <td>Good</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10455</th>\n",
       "      <td>65</td>\n",
       "      <td>147</td>\n",
       "      <td>39</td>\n",
       "      <td>68</td>\n",
       "      <td>11</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>92</td>\n",
       "      <td>28</td>\n",
       "      <td>12</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>86</td>\n",
       "      <td>53</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>80</td>\n",
       "      <td>None</td>\n",
       "      <td>30 Days Delinquent</td>\n",
       "      <td>Bad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10456</th>\n",
       "      <td>74</td>\n",
       "      <td>129</td>\n",
       "      <td>6</td>\n",
       "      <td>64</td>\n",
       "      <td>18</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>100</td>\n",
       "      <td>-7</td>\n",
       "      <td>18</td>\n",
       "      <td>...</td>\n",
       "      <td>4</td>\n",
       "      <td>6</td>\n",
       "      <td>-8</td>\n",
       "      <td>5</td>\n",
       "      <td>-8</td>\n",
       "      <td>0</td>\n",
       "      <td>56</td>\n",
       "      <td>None</td>\n",
       "      <td>Never Delinquent</td>\n",
       "      <td>Bad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10457</th>\n",
       "      <td>72</td>\n",
       "      <td>234</td>\n",
       "      <td>12</td>\n",
       "      <td>113</td>\n",
       "      <td>42</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>96</td>\n",
       "      <td>35</td>\n",
       "      <td>45</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>19</td>\n",
       "      <td>-8</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>38</td>\n",
       "      <td>None</td>\n",
       "      <td>Derogatory Comment</td>\n",
       "      <td>Bad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10458</th>\n",
       "      <td>66</td>\n",
       "      <td>28</td>\n",
       "      <td>1</td>\n",
       "      <td>17</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>100</td>\n",
       "      <td>-7</td>\n",
       "      <td>5</td>\n",
       "      <td>...</td>\n",
       "      <td>2</td>\n",
       "      <td>67</td>\n",
       "      <td>-8</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>100</td>\n",
       "      <td>Never Delinquent</td>\n",
       "      <td>Never Delinquent</td>\n",
       "      <td>Bad</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>10459 rows × 24 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       ExternalRiskEstimate  MSinceOldestTradeOpen  MSinceMostRecentTradeOpen  \\\n",
       "0                        55                    144                          4   \n",
       "1                        61                     58                         15   \n",
       "2                        67                     66                          5   \n",
       "3                        66                    169                          1   \n",
       "4                        81                    333                         27   \n",
       "...                     ...                    ...                        ...   \n",
       "10454                    73                    131                          5   \n",
       "10455                    65                    147                         39   \n",
       "10456                    74                    129                          6   \n",
       "10457                    72                    234                         12   \n",
       "10458                    66                     28                          1   \n",
       "\n",
       "       AverageMInFile  NumSatisfactoryTrades  NumTrades60Ever2DerogPubRec  \\\n",
       "0                  84                     20                            3   \n",
       "1                  41                      2                            4   \n",
       "2                  24                      9                            0   \n",
       "3                  73                     28                            1   \n",
       "4                 132                     12                            0   \n",
       "...               ...                    ...                          ...   \n",
       "10454              57                     21                            0   \n",
       "10455              68                     11                            0   \n",
       "10456              64                     18                            1   \n",
       "10457             113                     42                            2   \n",
       "10458              17                      4                            0   \n",
       "\n",
       "       NumTrades90Ever2DerogPubRec  PercentTradesNeverDelq  \\\n",
       "0                                0                      83   \n",
       "1                                4                     100   \n",
       "2                                0                     100   \n",
       "3                                1                      93   \n",
       "4                                0                     100   \n",
       "...                            ...                     ...   \n",
       "10454                            0                      95   \n",
       "10455                            0                      92   \n",
       "10456                            1                     100   \n",
       "10457                            2                      96   \n",
       "10458                            0                     100   \n",
       "\n",
       "       MSinceMostRecentDelq  NumTotalTrades  ...  NumInqLast6Mexcl7days  \\\n",
       "0                         2              23  ...                      0   \n",
       "1                        -7               7  ...                      0   \n",
       "2                        -7               9  ...                      4   \n",
       "3                        76              30  ...                      4   \n",
       "4                        -7              12  ...                      1   \n",
       "...                     ...             ...  ...                    ...   \n",
       "10454                    80              21  ...                      0   \n",
       "10455                    28              12  ...                      1   \n",
       "10456                    -7              18  ...                      4   \n",
       "10457                    35              45  ...                      0   \n",
       "10458                    -7               5  ...                      2   \n",
       "\n",
       "       NetFractionRevolvingBurden  NetFractionInstallBurden  \\\n",
       "0                              33                        -8   \n",
       "1                               0                        -8   \n",
       "2                              53                        66   \n",
       "3                              72                        83   \n",
       "4                              51                        89   \n",
       "...                           ...                       ...   \n",
       "10454                          26                        -8   \n",
       "10455                          86                        53   \n",
       "10456                           6                        -8   \n",
       "10457                          19                        -8   \n",
       "10458                          67                        -8   \n",
       "\n",
       "       NumRevolvingTradesWBalance  NumInstallTradesWBalance  \\\n",
       "0                               8                         1   \n",
       "1                               0                        -8   \n",
       "2                               4                         2   \n",
       "3                               6                         4   \n",
       "4                               3                         1   \n",
       "...                           ...                       ...   \n",
       "10454                           5                         2   \n",
       "10455                           2                         2   \n",
       "10456                           5                        -8   \n",
       "10457                           4                         1   \n",
       "10458                           2                         1   \n",
       "\n",
       "       NumBank2NatlTradesWHighUtilization  PercentTradesWBalance  \\\n",
       "0                                       1                     69   \n",
       "1                                      -8                      0   \n",
       "2                                       1                     86   \n",
       "3                                       3                     91   \n",
       "4                                       0                     80   \n",
       "...                                   ...                    ...   \n",
       "10454                                   0                    100   \n",
       "10455                                   1                     80   \n",
       "10456                                   0                     56   \n",
       "10457                                   0                     38   \n",
       "10458                                   0                    100   \n",
       "\n",
       "       MaxDelq2PublicRecLast12M         MaxDelqEver  RiskPerformance  \n",
       "0            60 Days Delinquent  60 Days Delinquent              Bad  \n",
       "1            Derogatory Comment    Never Delinquent              Bad  \n",
       "2              Never Delinquent    Never Delinquent              Bad  \n",
       "3                          None  30 Days Delinquent              Bad  \n",
       "4              Never Delinquent    Never Delinquent              Bad  \n",
       "...                         ...                 ...              ...  \n",
       "10454                      None  30 Days Delinquent             Good  \n",
       "10455                      None  30 Days Delinquent              Bad  \n",
       "10456                      None    Never Delinquent              Bad  \n",
       "10457                      None  Derogatory Comment              Bad  \n",
       "10458          Never Delinquent    Never Delinquent              Bad  \n",
       "\n",
       "[10459 rows x 24 columns]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# these columns are moved to the end of the frame, in this order\n",
    "cat_cols_list = ['MaxDelq2PublicRecLast12M', 'MaxDelqEver', 'RiskPerformance']\n",
    "\n",
    "columns = df.columns.tolist()\n",
    "for col in cat_cols_list:\n",
    "    # list.remove raises ValueError when the column is not present\n",
    "    columns.remove(col)\n",
    "columns.extend(cat_cols_list)\n",
    "\n",
    "df = df.loc[:, columns]\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['ExternalRiskEstimate',\n",
       " 'MSinceOldestTradeOpen',\n",
       " 'MSinceMostRecentTradeOpen',\n",
       " 'AverageMInFile',\n",
       " 'NumSatisfactoryTrades',\n",
       " 'NumTrades60Ever2DerogPubRec',\n",
       " 'NumTrades90Ever2DerogPubRec',\n",
       " 'PercentTradesNeverDelq',\n",
       " 'MSinceMostRecentDelq',\n",
       " 'NumTotalTrades',\n",
       " 'NumTradesOpeninLast12M',\n",
       " 'PercentInstallTrades',\n",
       " 'MSinceMostRecentInqexcl7days',\n",
       " 'NumInqLast6M',\n",
       " 'NumInqLast6Mexcl7days',\n",
       " 'NetFractionRevolvingBurden',\n",
       " 'NetFractionInstallBurden',\n",
       " 'NumRevolvingTradesWBalance',\n",
       " 'NumInstallTradesWBalance',\n",
       " 'NumBank2NatlTradesWHighUtilization',\n",
       " 'PercentTradesWBalance',\n",
       " 'MaxDelq2PublicRecLast12M',\n",
       " 'MaxDelqEver',\n",
       " 'RiskPerformance']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# show the current column order as a plain list\n",
    "list(df.columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Default of credit card clients\n",
    "\n",
    "http://archive.ics.uci.edu/ml/datasets/default+of+credit+card+clients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>LIMIT_BAL</th>\n",
       "      <th>SEX</th>\n",
       "      <th>EDUCATION</th>\n",
       "      <th>MARRIAGE</th>\n",
       "      <th>AGE</th>\n",
       "      <th>PAY_0</th>\n",
       "      <th>PAY_2</th>\n",
       "      <th>PAY_3</th>\n",
       "      <th>PAY_4</th>\n",
       "      <th>PAY_5</th>\n",
       "      <th>...</th>\n",
       "      <th>BILL_AMT4</th>\n",
       "      <th>BILL_AMT5</th>\n",
       "      <th>BILL_AMT6</th>\n",
       "      <th>PAY_AMT1</th>\n",
       "      <th>PAY_AMT2</th>\n",
       "      <th>PAY_AMT3</th>\n",
       "      <th>PAY_AMT4</th>\n",
       "      <th>PAY_AMT5</th>\n",
       "      <th>PAY_AMT6</th>\n",
       "      <th>Y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>20000</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>24</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-2</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>689</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>120000</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>26</td>\n",
       "      <td>-1</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>3272</td>\n",
       "      <td>3455</td>\n",
       "      <td>3261</td>\n",
       "      <td>0</td>\n",
       "      <td>1000</td>\n",
       "      <td>1000</td>\n",
       "      <td>1000</td>\n",
       "      <td>0</td>\n",
       "      <td>2000</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>90000</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>34</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>14331</td>\n",
       "      <td>14948</td>\n",
       "      <td>15549</td>\n",
       "      <td>1518</td>\n",
       "      <td>1500</td>\n",
       "      <td>1000</td>\n",
       "      <td>1000</td>\n",
       "      <td>1000</td>\n",
       "      <td>5000</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>50000</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>37</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>28314</td>\n",
       "      <td>28959</td>\n",
       "      <td>29547</td>\n",
       "      <td>2000</td>\n",
       "      <td>2019</td>\n",
       "      <td>1200</td>\n",
       "      <td>1100</td>\n",
       "      <td>1069</td>\n",
       "      <td>1000</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>50000</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>57</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>20940</td>\n",
       "      <td>19146</td>\n",
       "      <td>19131</td>\n",
       "      <td>2000</td>\n",
       "      <td>36681</td>\n",
       "      <td>10000</td>\n",
       "      <td>9000</td>\n",
       "      <td>689</td>\n",
       "      <td>679</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29995</th>\n",
       "      <td>220000</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>39</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>88004</td>\n",
       "      <td>31237</td>\n",
       "      <td>15980</td>\n",
       "      <td>8500</td>\n",
       "      <td>20000</td>\n",
       "      <td>5003</td>\n",
       "      <td>3047</td>\n",
       "      <td>5000</td>\n",
       "      <td>1000</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29996</th>\n",
       "      <td>150000</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>43</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>8979</td>\n",
       "      <td>5190</td>\n",
       "      <td>0</td>\n",
       "      <td>1837</td>\n",
       "      <td>3526</td>\n",
       "      <td>8998</td>\n",
       "      <td>129</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29997</th>\n",
       "      <td>30000</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>37</td>\n",
       "      <td>4</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>20878</td>\n",
       "      <td>20582</td>\n",
       "      <td>19357</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>22000</td>\n",
       "      <td>4200</td>\n",
       "      <td>2000</td>\n",
       "      <td>3100</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29998</th>\n",
       "      <td>80000</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>41</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>52774</td>\n",
       "      <td>11855</td>\n",
       "      <td>48944</td>\n",
       "      <td>85900</td>\n",
       "      <td>3409</td>\n",
       "      <td>1178</td>\n",
       "      <td>1926</td>\n",
       "      <td>52964</td>\n",
       "      <td>1804</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29999</th>\n",
       "      <td>50000</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>46</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>36535</td>\n",
       "      <td>32428</td>\n",
       "      <td>15313</td>\n",
       "      <td>2078</td>\n",
       "      <td>1800</td>\n",
       "      <td>1430</td>\n",
       "      <td>1000</td>\n",
       "      <td>1000</td>\n",
       "      <td>1000</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>30000 rows × 24 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       LIMIT_BAL  SEX  EDUCATION  MARRIAGE  AGE  PAY_0  PAY_2  PAY_3  PAY_4  \\\n",
       "0          20000    2          2         1   24      2      2     -1     -1   \n",
       "1         120000    2          2         2   26     -1      2      0      0   \n",
       "2          90000    2          2         2   34      0      0      0      0   \n",
       "3          50000    2          2         1   37      0      0      0      0   \n",
       "4          50000    1          2         1   57     -1      0     -1      0   \n",
       "...          ...  ...        ...       ...  ...    ...    ...    ...    ...   \n",
       "29995     220000    1          3         1   39      0      0      0      0   \n",
       "29996     150000    1          3         2   43     -1     -1     -1     -1   \n",
       "29997      30000    1          2         2   37      4      3      2     -1   \n",
       "29998      80000    1          3         1   41      1     -1      0      0   \n",
       "29999      50000    1          2         1   46      0      0      0      0   \n",
       "\n",
       "       PAY_5  ...  BILL_AMT4  BILL_AMT5  BILL_AMT6  PAY_AMT1  PAY_AMT2  \\\n",
       "0         -2  ...          0          0          0         0       689   \n",
       "1          0  ...       3272       3455       3261         0      1000   \n",
       "2          0  ...      14331      14948      15549      1518      1500   \n",
       "3          0  ...      28314      28959      29547      2000      2019   \n",
       "4          0  ...      20940      19146      19131      2000     36681   \n",
       "...      ...  ...        ...        ...        ...       ...       ...   \n",
       "29995      0  ...      88004      31237      15980      8500     20000   \n",
       "29996      0  ...       8979       5190          0      1837      3526   \n",
       "29997      0  ...      20878      20582      19357         0         0   \n",
       "29998      0  ...      52774      11855      48944     85900      3409   \n",
       "29999      0  ...      36535      32428      15313      2078      1800   \n",
       "\n",
       "       PAY_AMT3  PAY_AMT4  PAY_AMT5  PAY_AMT6  Y  \n",
       "0             0         0         0         0  1  \n",
       "1          1000      1000         0      2000  1  \n",
       "2          1000      1000      1000      5000  0  \n",
       "3          1200      1100      1069      1000  0  \n",
       "4         10000      9000       689       679  0  \n",
       "...         ...       ...       ...       ... ..  \n",
       "29995      5003      3047      5000      1000  0  \n",
       "29996      8998       129         0         0  0  \n",
       "29997     22000      4200      2000      3100  1  \n",
       "29998      1178      1926     52964      1804  1  \n",
       "29999      1430      1000      1000      1000  1  \n",
       "\n",
       "[30000 rows x 24 columns]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# read the credit card CSV and discard its 'ID' column\n",
    "df = pd.read_csv('../data/extra/credit_card.csv').drop(columns=['ID'])\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_cols = ['SEX', 'EDUCATION', 'MARRIAGE',]\n",
    "outcome_col = 'Y'\n",
    "file_name = \"data/extra/s_credit_cart.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# simple_transform arranges the columns (continuous, categorical, outcome per its docstring)\n",
    "# NOTE(review): no seed is passed to shuffle here, so row order may differ between runs - confirm\n",
    "df = shuffle(simple_transform(df, cat_cols, outcome_col))\n",
    "df.to_csv(f'../{file_name}', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'data_dir': 'data/extra/s_credit_cart.csv',\n",
       " 'lr': 0.01,\n",
       " 'batch_size': 128,\n",
       " 'lambda_1': 1.0,\n",
       " 'lambda_2': 0.01,\n",
       " 'lambda_3': 1.0,\n",
       " 'threshold': 1.0,\n",
       " 'continous_cols': ['LIMIT_BAL',\n",
       "  'AGE',\n",
       "  'PAY_0',\n",
       "  'PAY_2',\n",
       "  'PAY_3',\n",
       "  'PAY_4',\n",
       "  'PAY_5',\n",
       "  'PAY_6',\n",
       "  'BILL_AMT1',\n",
       "  'BILL_AMT2',\n",
       "  'BILL_AMT3',\n",
       "  'BILL_AMT4',\n",
       "  'BILL_AMT5',\n",
       "  'BILL_AMT6',\n",
       "  'PAY_AMT1',\n",
       "  'PAY_AMT2',\n",
       "  'PAY_AMT3',\n",
       "  'PAY_AMT4',\n",
       "  'PAY_AMT5',\n",
       "  'PAY_AMT6'],\n",
       " 'discret_cols': ['SEX', 'EDUCATION', 'MARRIAGE'],\n",
       " 'encoder_dims': [29, 50, 10],\n",
       " 'decoder_dims': [10, 10],\n",
       " 'explainer_dims': [10, 50],\n",
       " 'loss_1': 'mse',\n",
       " 'loss_2': 'mse',\n",
       " 'loss_3': 'mse'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# use the adult config as a template; the dataset-specific entries are overwritten below\n",
    "param = load_json(\"../counterfactual/configs/adult.json\")\n",
    "# whatever is left after removing the categorical and outcome columns is continuous\n",
    "cols = df.columns.tolist()\n",
    "for col in [*cat_cols, outcome_col]:\n",
    "    cols.remove(col)\n",
    "cont_cols = cols\n",
    "param.update({\n",
    "    'data_dir': file_name,\n",
    "    'continous_cols': cont_cols,\n",
    "    'discret_cols': cat_cols,\n",
    "})\n",
    "update_json_file(param, \"../counterfactual/configs/extra/s_credit_cart.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# German Credit Dataset\n",
    "\n",
    "http://archive.ics.uci.edu/ml/datasets/statlog+(german+credit+data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read the space-delimited file, keeping every field as a string\n",
    "data = np.genfromtxt('../data/extra/german.data', delimiter=' ', dtype=str)\n",
    "# names given to the columns of `data` when the DataFrame is built\n",
    "cols = [\n",
    "    'status', 'duration', 'history', 'purpose', 'credit_amount',\n",
    "    'savings', 'employment', 'installment_rate', 'sex', 'other_debtors',\n",
    "    'residence', 'property', 'age', 'other_plan', 'housing', 'credits', 'job',\n",
    "    'num_people_liable', 'telephone', 'foreign_worker', 'Y',\n",
    "]\n",
    "# subset of `cols` treated as categorical\n",
    "cat_cols = [\n",
    "    'status', 'history', 'purpose', 'savings', 'employment', 'sex', 'other_debtors',\n",
    "    'property', 'other_plan', 'housing', 'job', 'telephone', 'foreign_worker',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(data, columns=cols)\n",
    "# df = df.infer_objects()\n",
    "# cast the outcome to int and shift it down by one (presumably 1/2 -> 0/1; verify against the data)\n",
    "df['Y'] = df['Y'].astype(int) - 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outcome_col = 'Y'\n",
    "file_name = \"data/extra/s_german_credit.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# simple_transform arranges the columns (continuous, categorical, outcome per its docstring);\n",
    "# outcome_col is passed instead of a repeated 'Y' literal so the name is defined in one place\n",
    "# NOTE(review): no seed is passed to shuffle here, so row order may differ between runs - confirm\n",
    "df = shuffle(simple_transform(df, cat_cols=cat_cols, outcome_col=outcome_col))\n",
    "df.to_csv(f'../{file_name}', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'data_dir': 'data/extra/s_german_credit.csv',\n",
       " 'lr': 0.01,\n",
       " 'batch_size': 128,\n",
       " 'lambda_1': 1.0,\n",
       " 'lambda_2': 0.01,\n",
       " 'lambda_3': 1.0,\n",
       " 'threshold': 1.0,\n",
       " 'continous_cols': ['duration',\n",
       "  'credit_amount',\n",
       "  'installment_rate',\n",
       "  'residence',\n",
       "  'age',\n",
       "  'credits',\n",
       "  'num_people_liable'],\n",
       " 'discret_cols': ['status',\n",
       "  'history',\n",
       "  'purpose',\n",
       "  'savings',\n",
       "  'employment',\n",
       "  'sex',\n",
       "  'other_debtors',\n",
       "  'property',\n",
       "  'other_plan',\n",
       "  'housing',\n",
       "  'job',\n",
       "  'telephone',\n",
       "  'foreign_worker'],\n",
       " 'encoder_dims': [29, 50, 10],\n",
       " 'decoder_dims': [10, 10],\n",
       " 'explainer_dims': [10, 50],\n",
       " 'loss_1': 'mse',\n",
       " 'loss_2': 'mse',\n",
       " 'loss_3': 'mse'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# use the adult config as a template; the dataset-specific entries are overwritten below\n",
    "param = load_json(\"../counterfactual/configs/adult.json\")\n",
    "# whatever is left after removing the categorical and outcome columns is continuous\n",
    "cols = df.columns.tolist()\n",
    "for col in [*cat_cols, outcome_col]:\n",
    "    cols.remove(col)\n",
    "cont_cols = cols\n",
    "param.update({\n",
    "    'data_dir': file_name,\n",
    "    'continous_cols': cont_cols,\n",
    "    'discret_cols': cat_cols,\n",
    "})\n",
    "update_json_file(param, \"../counterfactual/configs/extra/s_german_credit.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Student Performance\n",
    "\n",
    "http://archive.ics.uci.edu/ml/datasets/Student+Performance\n",
    "https://dash-xai.herokuapp.com/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>G2</th>\n",
       "      <th>G1</th>\n",
       "      <th>failures</th>\n",
       "      <th>higher</th>\n",
       "      <th>age</th>\n",
       "      <th>school</th>\n",
       "      <th>goout</th>\n",
       "      <th>Mjob</th>\n",
       "      <th>Fjob</th>\n",
       "      <th>health</th>\n",
       "      <th>freetime</th>\n",
       "      <th>absences</th>\n",
       "      <th>Walc</th>\n",
       "      <th>famrel</th>\n",
       "      <th>Medu</th>\n",
       "      <th>Fedu</th>\n",
       "      <th>G3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>D</td>\n",
       "      <td>F</td>\n",
       "      <td>0</td>\n",
       "      <td>yes</td>\n",
       "      <td>18</td>\n",
       "      <td>GP</td>\n",
       "      <td>4</td>\n",
       "      <td>at_home</td>\n",
       "      <td>teacher</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "      <td>D</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>D</td>\n",
       "      <td>F</td>\n",
       "      <td>0</td>\n",
       "      <td>yes</td>\n",
       "      <td>17</td>\n",
       "      <td>GP</td>\n",
       "      <td>3</td>\n",
       "      <td>at_home</td>\n",
       "      <td>other</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>D</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>C</td>\n",
       "      <td>C</td>\n",
       "      <td>0</td>\n",
       "      <td>yes</td>\n",
       "      <td>15</td>\n",
       "      <td>GP</td>\n",
       "      <td>2</td>\n",
       "      <td>at_home</td>\n",
       "      <td>other</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>6</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>C</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>B</td>\n",
       "      <td>B</td>\n",
       "      <td>0</td>\n",
       "      <td>yes</td>\n",
       "      <td>15</td>\n",
       "      <td>GP</td>\n",
       "      <td>2</td>\n",
       "      <td>health</td>\n",
       "      <td>services</td>\n",
       "      <td>5</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>2</td>\n",
       "      <td>B</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>C</td>\n",
       "      <td>D</td>\n",
       "      <td>0</td>\n",
       "      <td>yes</td>\n",
       "      <td>16</td>\n",
       "      <td>GP</td>\n",
       "      <td>2</td>\n",
       "      <td>other</td>\n",
       "      <td>other</td>\n",
       "      <td>5</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>4</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>C</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>644</th>\n",
       "      <td>D</td>\n",
       "      <td>D</td>\n",
       "      <td>1</td>\n",
       "      <td>yes</td>\n",
       "      <td>19</td>\n",
       "      <td>MS</td>\n",
       "      <td>2</td>\n",
       "      <td>services</td>\n",
       "      <td>other</td>\n",
       "      <td>5</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "      <td>2</td>\n",
       "      <td>5</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>D</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>645</th>\n",
       "      <td>B</td>\n",
       "      <td>B</td>\n",
       "      <td>0</td>\n",
       "      <td>yes</td>\n",
       "      <td>18</td>\n",
       "      <td>MS</td>\n",
       "      <td>4</td>\n",
       "      <td>teacher</td>\n",
       "      <td>services</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>A</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>646</th>\n",
       "      <td>C</td>\n",
       "      <td>D</td>\n",
       "      <td>0</td>\n",
       "      <td>yes</td>\n",
       "      <td>18</td>\n",
       "      <td>MS</td>\n",
       "      <td>1</td>\n",
       "      <td>other</td>\n",
       "      <td>other</td>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>6</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>647</th>\n",
       "      <td>D</td>\n",
       "      <td>D</td>\n",
       "      <td>0</td>\n",
       "      <td>yes</td>\n",
       "      <td>17</td>\n",
       "      <td>MS</td>\n",
       "      <td>5</td>\n",
       "      <td>services</td>\n",
       "      <td>services</td>\n",
       "      <td>2</td>\n",
       "      <td>4</td>\n",
       "      <td>6</td>\n",
       "      <td>4</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>D</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>648</th>\n",
       "      <td>D</td>\n",
       "      <td>D</td>\n",
       "      <td>0</td>\n",
       "      <td>yes</td>\n",
       "      <td>18</td>\n",
       "      <td>MS</td>\n",
       "      <td>1</td>\n",
       "      <td>services</td>\n",
       "      <td>other</td>\n",
       "      <td>5</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>D</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>649 rows × 17 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    G2 G1  failures higher  age school  goout      Mjob      Fjob  health  \\\n",
       "0    D  F         0    yes   18     GP      4   at_home   teacher       3   \n",
       "1    D  F         0    yes   17     GP      3   at_home     other       3   \n",
       "2    C  C         0    yes   15     GP      2   at_home     other       3   \n",
       "3    B  B         0    yes   15     GP      2    health  services       5   \n",
       "4    C  D         0    yes   16     GP      2     other     other       5   \n",
       "..  .. ..       ...    ...  ...    ...    ...       ...       ...     ...   \n",
       "644  D  D         1    yes   19     MS      2  services     other       5   \n",
       "645  B  B         0    yes   18     MS      4   teacher  services       1   \n",
       "646  C  D         0    yes   18     MS      1     other     other       5   \n",
       "647  D  D         0    yes   17     MS      5  services  services       2   \n",
       "648  D  D         0    yes   18     MS      1  services     other       5   \n",
       "\n",
       "     freetime  absences  Walc  famrel  Medu  Fedu G3  \n",
       "0           3         4     1       4     4     4  D  \n",
       "1           3         2     1       5     1     1  D  \n",
       "2           3         6     3       4     1     1  C  \n",
       "3           2         0     1       3     4     2  B  \n",
       "4           3         0     2       4     3     3  C  \n",
       "..        ...       ...   ...     ...   ...   ... ..  \n",
       "644         4         4     2       5     2     3  D  \n",
       "645         3         4     1       4     3     1  A  \n",
       "646         1         6     1       1     1     1  F  \n",
       "647         4         6     4       2     3     1  D  \n",
       "648         4         4     4       4     3     2  D  \n",
       "\n",
       "[649 rows x 17 columns]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('../data/extra/student_performance.csv')\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_cols = ['G2', 'G1', 'higher', 'school', 'goout', 'Mjob', 'Fjob', 'health', \n",
    "            'freetime', 'absences', 'Walc', 'famrel', 'Medu', 'Fedu']\n",
    "outcome_col = 'G3'\n",
    "file_name = \"data/extra/s_student_performance.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# binarize the outcome: grades 'A' and 'B' become 1, every other value becomes 0\n",
    "df['G3'] = df['G3'].isin(['A', 'B']).astype('int64')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# simple_transform arranges the columns (continuous, categorical, outcome per its docstring)\n",
    "# NOTE(review): no seed is passed to shuffle here, so row order may differ between runs - confirm\n",
    "df = shuffle(simple_transform(df, cat_cols, outcome_col))\n",
    "df.to_csv(f'../{file_name}', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'data_dir': 'data/extra/s_student_performance.csv',\n",
       " 'lr': 0.01,\n",
       " 'batch_size': 128,\n",
       " 'lambda_1': 1.0,\n",
       " 'lambda_2': 0.01,\n",
       " 'lambda_3': 1.0,\n",
       " 'threshold': 1.0,\n",
       " 'continous_cols': ['failures', 'age'],\n",
       " 'discret_cols': ['G2',\n",
       "  'G1',\n",
       "  'higher',\n",
       "  'school',\n",
       "  'goout',\n",
       "  'Mjob',\n",
       "  'Fjob',\n",
       "  'health',\n",
       "  'freetime',\n",
       "  'absences',\n",
       "  'Walc',\n",
       "  'famrel',\n",
       "  'Medu',\n",
       "  'Fedu'],\n",
       " 'encoder_dims': [29, 50, 10],\n",
       " 'decoder_dims': [10, 10],\n",
       " 'explainer_dims': [10, 50],\n",
       " 'loss_1': 'mse',\n",
       " 'loss_2': 'mse',\n",
       " 'loss_3': 'mse'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# use the adult config as a template; the dataset-specific entries are overwritten below\n",
    "param = load_json(\"../counterfactual/configs/adult.json\")\n",
    "# whatever is left after removing the categorical and outcome columns is continuous\n",
    "cols = df.columns.tolist()\n",
    "for col in [*cat_cols, outcome_col]:\n",
    "    cols.remove(col)\n",
    "cont_cols = cols\n",
    "param.update({\n",
    "    'data_dir': file_name,\n",
    "    'continous_cols': cont_cols,\n",
    "    'discret_cols': cat_cols,\n",
    "})\n",
    "update_json_file(param, \"../counterfactual/configs/extra/s_student_performance.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Heart\n",
    "\n",
    "https://www.kaggle.com/andrewmvd/heart-failure-clinical-data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>anaemia</th>\n",
       "      <th>creatinine_phosphokinase</th>\n",
       "      <th>diabetes</th>\n",
       "      <th>ejection_fraction</th>\n",
       "      <th>high_blood_pressure</th>\n",
       "      <th>platelets</th>\n",
       "      <th>serum_creatinine</th>\n",
       "      <th>serum_sodium</th>\n",
       "      <th>sex</th>\n",
       "      <th>smoking</th>\n",
       "      <th>time</th>\n",
       "      <th>DEATH_EVENT</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>75.0</td>\n",
       "      <td>0</td>\n",
       "      <td>582</td>\n",
       "      <td>0</td>\n",
       "      <td>20</td>\n",
       "      <td>1</td>\n",
       "      <td>265000.00</td>\n",
       "      <td>1.9</td>\n",
       "      <td>130</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>55.0</td>\n",
       "      <td>0</td>\n",
       "      <td>7861</td>\n",
       "      <td>0</td>\n",
       "      <td>38</td>\n",
       "      <td>0</td>\n",
       "      <td>263358.03</td>\n",
       "      <td>1.1</td>\n",
       "      <td>136</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>6</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>65.0</td>\n",
       "      <td>0</td>\n",
       "      <td>146</td>\n",
       "      <td>0</td>\n",
       "      <td>20</td>\n",
       "      <td>0</td>\n",
       "      <td>162000.00</td>\n",
       "      <td>1.3</td>\n",
       "      <td>129</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>7</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>50.0</td>\n",
       "      <td>1</td>\n",
       "      <td>111</td>\n",
       "      <td>0</td>\n",
       "      <td>20</td>\n",
       "      <td>0</td>\n",
       "      <td>210000.00</td>\n",
       "      <td>1.9</td>\n",
       "      <td>137</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>7</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>65.0</td>\n",
       "      <td>1</td>\n",
       "      <td>160</td>\n",
       "      <td>1</td>\n",
       "      <td>20</td>\n",
       "      <td>0</td>\n",
       "      <td>327000.00</td>\n",
       "      <td>2.7</td>\n",
       "      <td>116</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>8</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>294</th>\n",
       "      <td>62.0</td>\n",
       "      <td>0</td>\n",
       "      <td>61</td>\n",
       "      <td>1</td>\n",
       "      <td>38</td>\n",
       "      <td>1</td>\n",
       "      <td>155000.00</td>\n",
       "      <td>1.1</td>\n",
       "      <td>143</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>270</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>295</th>\n",
       "      <td>55.0</td>\n",
       "      <td>0</td>\n",
       "      <td>1820</td>\n",
       "      <td>0</td>\n",
       "      <td>38</td>\n",
       "      <td>0</td>\n",
       "      <td>270000.00</td>\n",
       "      <td>1.2</td>\n",
       "      <td>139</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>271</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>296</th>\n",
       "      <td>45.0</td>\n",
       "      <td>0</td>\n",
       "      <td>2060</td>\n",
       "      <td>1</td>\n",
       "      <td>60</td>\n",
       "      <td>0</td>\n",
       "      <td>742000.00</td>\n",
       "      <td>0.8</td>\n",
       "      <td>138</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>278</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>297</th>\n",
       "      <td>45.0</td>\n",
       "      <td>0</td>\n",
       "      <td>2413</td>\n",
       "      <td>0</td>\n",
       "      <td>38</td>\n",
       "      <td>0</td>\n",
       "      <td>140000.00</td>\n",
       "      <td>1.4</td>\n",
       "      <td>140</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>280</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>298</th>\n",
       "      <td>50.0</td>\n",
       "      <td>0</td>\n",
       "      <td>196</td>\n",
       "      <td>0</td>\n",
       "      <td>45</td>\n",
       "      <td>0</td>\n",
       "      <td>395000.00</td>\n",
       "      <td>1.6</td>\n",
       "      <td>136</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>285</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>299 rows × 13 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      age  anaemia  creatinine_phosphokinase  diabetes  ejection_fraction  \\\n",
       "0    75.0        0                       582         0                 20   \n",
       "1    55.0        0                      7861         0                 38   \n",
       "2    65.0        0                       146         0                 20   \n",
       "3    50.0        1                       111         0                 20   \n",
       "4    65.0        1                       160         1                 20   \n",
       "..    ...      ...                       ...       ...                ...   \n",
       "294  62.0        0                        61         1                 38   \n",
       "295  55.0        0                      1820         0                 38   \n",
       "296  45.0        0                      2060         1                 60   \n",
       "297  45.0        0                      2413         0                 38   \n",
       "298  50.0        0                       196         0                 45   \n",
       "\n",
       "     high_blood_pressure  platelets  serum_creatinine  serum_sodium  sex  \\\n",
       "0                      1  265000.00               1.9           130    1   \n",
       "1                      0  263358.03               1.1           136    1   \n",
       "2                      0  162000.00               1.3           129    1   \n",
       "3                      0  210000.00               1.9           137    1   \n",
       "4                      0  327000.00               2.7           116    0   \n",
       "..                   ...        ...               ...           ...  ...   \n",
       "294                    1  155000.00               1.1           143    1   \n",
       "295                    0  270000.00               1.2           139    0   \n",
       "296                    0  742000.00               0.8           138    0   \n",
       "297                    0  140000.00               1.4           140    1   \n",
       "298                    0  395000.00               1.6           136    1   \n",
       "\n",
       "     smoking  time  DEATH_EVENT  \n",
       "0          0     4            1  \n",
       "1          0     6            1  \n",
       "2          1     7            1  \n",
       "3          0     7            1  \n",
       "4          0     8            1  \n",
       "..       ...   ...          ...  \n",
       "294        1   270            0  \n",
       "295        0   271            0  \n",
       "296        0   278            0  \n",
       "297        1   280            0  \n",
       "298        1   285            0  \n",
       "\n",
       "[299 rows x 13 columns]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('../data/extra/heart.csv')\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_cols = ['anaemia', 'diabetes', 'high_blood_pressure', 'sex', 'smoking']\n",
    "outcome_col = 'DEATH_EVENT'\n",
    "file_name = \"data/extra/s_heart.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# simple_transform arranges the columns (continuous, categorical, outcome per its docstring)\n",
    "# NOTE(review): no seed is passed to shuffle here, so row order may differ between runs - confirm\n",
    "df = shuffle(simple_transform(df, cat_cols=cat_cols, outcome_col=outcome_col))\n",
    "df.to_csv(f'../{file_name}', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'data_dir': 'data/extra/s_heart.csv',\n",
       " 'lr': 0.01,\n",
       " 'batch_size': 128,\n",
       " 'lambda_1': 1.0,\n",
       " 'lambda_2': 0.01,\n",
       " 'lambda_3': 1.0,\n",
       " 'threshold': 1.0,\n",
       " 'continous_cols': ['age',\n",
       "  'creatinine_phosphokinase',\n",
       "  'ejection_fraction',\n",
       "  'platelets',\n",
       "  'serum_creatinine',\n",
       "  'serum_sodium',\n",
       "  'time'],\n",
       " 'discret_cols': ['anaemia',\n",
       "  'diabetes',\n",
       "  'high_blood_pressure',\n",
       "  'sex',\n",
       "  'smoking'],\n",
       " 'encoder_dims': [29, 50, 10],\n",
       " 'decoder_dims': [10, 10],\n",
       " 'explainer_dims': [10, 50],\n",
       " 'loss_1': 'mse',\n",
       " 'loss_2': 'mse',\n",
       " 'loss_3': 'mse'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# use the adult config as a template: only data_dir and the two column lists are overridden below,\n",
    "# every other key (lr, encoder_dims, ...) is carried over unchanged\n",
    "# NOTE(review): encoder_dims is inherited from adult - confirm it matches this dataset's input width\n",
    "param = load_json(\"../counterfactual/configs/adult.json\")\n",
    "cols = df.columns.tolist()\n",
    "for col in cat_cols:\n",
    "    cols.remove(col)\n",
    "cols.remove(outcome_col)\n",
    "# whatever is left is treated as continuous (cont_cols is the same list object as cols, not a copy)\n",
    "cont_cols = cols\n",
    "param['data_dir'] = file_name\n",
    "param['continous_cols'] = cont_cols\n",
    "param['discret_cols'] = cat_cols\n",
    "# merge into the dataset's config file (created if missing) and display the merged dict\n",
    "update_json_file(param, \"../counterfactual/configs/extra/heart.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Titanic\n",
    "\n",
    "https://www.kaggle.com/ak1352/titanic-cl?select=train_cl.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Survived</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Age</th>\n",
       "      <th>Fare</th>\n",
       "      <th>family</th>\n",
       "      <th>AgeGroup_1</th>\n",
       "      <th>AgeGroup_2</th>\n",
       "      <th>AgeGroup_3</th>\n",
       "      <th>AgeGroup_4</th>\n",
       "      <th>SibSp_0</th>\n",
       "      <th>...</th>\n",
       "      <th>Pclass_2</th>\n",
       "      <th>Pclass_3</th>\n",
       "      <th>Parch_0</th>\n",
       "      <th>Parch_1</th>\n",
       "      <th>Parch_2</th>\n",
       "      <th>Parch_3</th>\n",
       "      <th>Parch_4</th>\n",
       "      <th>Parch_5</th>\n",
       "      <th>Parch_6</th>\n",
       "      <th>isAlone</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>22.000000</td>\n",
       "      <td>7.2500</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>38.000000</td>\n",
       "      <td>71.2833</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>26.000000</td>\n",
       "      <td>7.9250</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>35.000000</td>\n",
       "      <td>53.1000</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>35.000000</td>\n",
       "      <td>8.0500</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>886</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>27.000000</td>\n",
       "      <td>13.0000</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>887</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>19.000000</td>\n",
       "      <td>30.0000</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>888</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>29.699118</td>\n",
       "      <td>23.4500</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>889</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>26.000000</td>\n",
       "      <td>30.0000</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>890</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>32.000000</td>\n",
       "      <td>7.7500</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>891 rows × 27 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     Survived  Sex        Age     Fare  family  AgeGroup_1  AgeGroup_2  \\\n",
       "0           0    0  22.000000   7.2500       2           0           0   \n",
       "1           1    1  38.000000  71.2833       2           0           0   \n",
       "2           1    1  26.000000   7.9250       1           0           0   \n",
       "3           1    1  35.000000  53.1000       2           0           0   \n",
       "4           0    0  35.000000   8.0500       1           0           0   \n",
       "..        ...  ...        ...      ...     ...         ...         ...   \n",
       "886         0    0  27.000000  13.0000       1           0           0   \n",
       "887         1    1  19.000000  30.0000       1           0           0   \n",
       "888         0    1  29.699118  23.4500       4           0           0   \n",
       "889         1    0  26.000000  30.0000       1           0           0   \n",
       "890         0    0  32.000000   7.7500       1           0           0   \n",
       "\n",
       "     AgeGroup_3  AgeGroup_4  SibSp_0  ...  Pclass_2  Pclass_3  Parch_0  \\\n",
       "0             0           1        0  ...         0         1        1   \n",
       "1             1           0        0  ...         0         0        1   \n",
       "2             1           0        1  ...         0         1        1   \n",
       "3             1           0        0  ...         0         0        1   \n",
       "4             0           1        1  ...         0         1        1   \n",
       "..          ...         ...      ...  ...       ...       ...      ...   \n",
       "886           0           1        1  ...         1         0        1   \n",
       "887           1           0        1  ...         0         0        1   \n",
       "888           1           0        0  ...         0         1        0   \n",
       "889           0           1        1  ...         0         0        1   \n",
       "890           0           1        1  ...         0         1        1   \n",
       "\n",
       "     Parch_1  Parch_2  Parch_3  Parch_4  Parch_5  Parch_6  isAlone  \n",
       "0          0        0        0        0        0        0        0  \n",
       "1          0        0        0        0        0        0        0  \n",
       "2          0        0        0        0        0        0        1  \n",
       "3          0        0        0        0        0        0        0  \n",
       "4          0        0        0        0        0        0        1  \n",
       "..       ...      ...      ...      ...      ...      ...      ...  \n",
       "886        0        0        0        0        0        0        1  \n",
       "887        0        0        0        0        0        0        1  \n",
       "888        0        1        0        0        0        0        0  \n",
       "889        0        0        0        0        0        0        1  \n",
       "890        0        0        0        0        0        0        1  \n",
       "\n",
       "[891 rows x 27 columns]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# load the Titanic CSV, drop the Embarked_0 column, and display the result\n",
    "df = pd.read_csv('../data/extra/titanic.csv')\n",
    "# NOTE(review): the reason for dropping Embarked_0 is not recorded in this part of the notebook\n",
    "df = df.drop(['Embarked_0'], axis=1)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['Survived', 'Sex', 'Age', 'Fare', 'family', 'AgeGroup_1', 'AgeGroup_2',\n",
       "       'AgeGroup_3', 'AgeGroup_4', 'SibSp_0', 'SibSp_1', 'SibSp_2', 'SibSp_3',\n",
       "       'SibSp_4', 'SibSp_5', 'SibSp_8', 'Pclass_1', 'Pclass_2', 'Pclass_3',\n",
       "       'Parch_0', 'Parch_1', 'Parch_2', 'Parch_3', 'Parch_4', 'Parch_5',\n",
       "       'Parch_6', 'isAlone'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# show the column names that remain after the drop above\n",
    "df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# every column except Age, Fare and the outcome is handled as categorical\n",
    "cat_cols = ['Sex', 'family', 'AgeGroup_1', 'AgeGroup_2',\n",
    "       'AgeGroup_3', 'AgeGroup_4', 'SibSp_0', 'SibSp_1', 'SibSp_2', 'SibSp_3',\n",
    "       'SibSp_4', 'SibSp_5', 'SibSp_8', 'Pclass_1', 'Pclass_2', 'Pclass_3',\n",
    "       'Parch_0', 'Parch_1', 'Parch_2', 'Parch_3', 'Parch_4', 'Parch_5',\n",
    "       'Parch_6', 'isAlone']\n",
    "# label column\n",
    "outcome_col = 'Survived'\n",
    "# written without the leading '../': the next cell prepends it when saving, and the config stores this string as data_dir\n",
    "file_name = \"data/extra/s_titanic.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reorder columns to [continuous..., categorical..., outcome] (per simple_transform's docstring)\n",
    "df = simple_transform(df, cat_cols=cat_cols, outcome_col=outcome_col)\n",
    "# NOTE(review): no seed is passed to shuffle, so the saved row order presumably changes between runs - confirm\n",
    "df = shuffle(df)\n",
    "# write the processed frame without the index column\n",
    "df.to_csv('../' + file_name, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'data_dir': 'data/extra/s_titanic.csv',\n",
       " 'lr': 0.01,\n",
       " 'batch_size': 128,\n",
       " 'lambda_1': 1.0,\n",
       " 'lambda_2': 0.01,\n",
       " 'lambda_3': 1.0,\n",
       " 'threshold': 1.0,\n",
       " 'continous_cols': ['Age', 'Fare'],\n",
       " 'discret_cols': ['Sex',\n",
       "  'family',\n",
       "  'AgeGroup_1',\n",
       "  'AgeGroup_2',\n",
       "  'AgeGroup_3',\n",
       "  'AgeGroup_4',\n",
       "  'SibSp_0',\n",
       "  'SibSp_1',\n",
       "  'SibSp_2',\n",
       "  'SibSp_3',\n",
       "  'SibSp_4',\n",
       "  'SibSp_5',\n",
       "  'SibSp_8',\n",
       "  'Pclass_1',\n",
       "  'Pclass_2',\n",
       "  'Pclass_3',\n",
       "  'Parch_0',\n",
       "  'Parch_1',\n",
       "  'Parch_2',\n",
       "  'Parch_3',\n",
       "  'Parch_4',\n",
       "  'Parch_5',\n",
       "  'Parch_6',\n",
       "  'isAlone'],\n",
       " 'encoder_dims': [29, 50, 10],\n",
       " 'decoder_dims': [10, 10],\n",
       " 'explainer_dims': [10, 50],\n",
       " 'loss_1': 'mse',\n",
       " 'loss_2': 'mse',\n",
       " 'loss_3': 'mse'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# use the adult config as a template: only data_dir and the two column lists are overridden below,\n",
    "# every other key (lr, encoder_dims, ...) is carried over unchanged\n",
    "# NOTE(review): encoder_dims is inherited from adult - confirm it matches this dataset's input width\n",
    "param = load_json(\"../counterfactual/configs/adult.json\")\n",
    "cols = df.columns.tolist()\n",
    "for col in cat_cols:\n",
    "    cols.remove(col)\n",
    "cols.remove(outcome_col)\n",
    "# whatever is left is treated as continuous (cont_cols is the same list object as cols, not a copy)\n",
    "cont_cols = cols\n",
    "param['data_dir'] = file_name\n",
    "param['continous_cols'] = cont_cols\n",
    "param['discret_cols'] = cat_cols\n",
    "# merge into the dataset's config file (created if missing) and display the merged dict\n",
    "update_json_file(param, \"../counterfactual/configs/extra/titanic.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# breast-cancer-wisconsin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read the comma-separated file with every field kept as a string (dtype=str)\n",
    "data = np.genfromtxt('../data/wdbc.data', delimiter=',',\n",
    "            dtype=str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# column names for the raw array: ID, Diagnosis, then X_1 ... X_30\n",
    "cols = ['ID', 'Diagnosis']\n",
    "for i in range(1, 31):\n",
    "    cols.append(f'X_{i}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# no column of this dataset is handled as categorical\n",
    "cat_cols = []\n",
    "# label column\n",
    "outcome_col = 'Diagnosis'\n",
    "# written without the leading '../': the next cell prepends it when saving, and the config stores this string as data_dir\n",
    "file_name = \"data/extra/s_breast_cancer.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the frame from the raw string array and drop the ID column\n",
    "df = pd.DataFrame(data, columns=cols)\n",
    "df = df.drop(['ID'], axis=1)\n",
    "# NOTE(review): the X_* columns still hold strings here because data was read with dtype=str;\n",
    "# the conversion on the next line is disabled\n",
    "# df = df.infer_objects()\n",
    "# encode the label: 'B' -> 1, any other value -> 0\n",
    "df['Diagnosis'] = df['Diagnosis'].apply(lambda x: int(x == 'B'))\n",
    "\n",
    "# reorder columns to [continuous..., categorical..., outcome] (per simple_transform's docstring)\n",
    "df = simple_transform(df, cat_cols=cat_cols, outcome_col=outcome_col)\n",
    "# NOTE(review): no seed is passed to shuffle, so the saved row order presumably changes between runs - confirm\n",
    "df = shuffle(df)\n",
    "# write the processed frame without the index column\n",
    "df.to_csv('../' + file_name, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'data_dir': 'data/extra/s_breast_cancer.csv',\n",
       " 'lr': 0.01,\n",
       " 'batch_size': 128,\n",
       " 'lambda_1': 1.0,\n",
       " 'lambda_2': 0.01,\n",
       " 'lambda_3': 1.0,\n",
       " 'threshold': 1.0,\n",
       " 'continous_cols': ['X_1',\n",
       "  'X_2',\n",
       "  'X_3',\n",
       "  'X_4',\n",
       "  'X_5',\n",
       "  'X_6',\n",
       "  'X_7',\n",
       "  'X_8',\n",
       "  'X_9',\n",
       "  'X_10',\n",
       "  'X_11',\n",
       "  'X_12',\n",
       "  'X_13',\n",
       "  'X_14',\n",
       "  'X_15',\n",
       "  'X_16',\n",
       "  'X_17',\n",
       "  'X_18',\n",
       "  'X_19',\n",
       "  'X_20',\n",
       "  'X_21',\n",
       "  'X_22',\n",
       "  'X_23',\n",
       "  'X_24',\n",
       "  'X_25',\n",
       "  'X_26',\n",
       "  'X_27',\n",
       "  'X_28',\n",
       "  'X_29',\n",
       "  'X_30'],\n",
       " 'discret_cols': [],\n",
       " 'encoder_dims': [29, 50, 10],\n",
       " 'decoder_dims': [10, 10],\n",
       " 'explainer_dims': [10, 50],\n",
       " 'loss_1': 'mse',\n",
       " 'loss_2': 'mse',\n",
       " 'loss_3': 'mse'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# use the adult config as a template: only data_dir and the two column lists are overridden below,\n",
    "# every other key (lr, encoder_dims, ...) is carried over unchanged\n",
    "# NOTE(review): encoder_dims is inherited from adult - confirm it matches this dataset's input width\n",
    "param = load_json(\"../counterfactual/configs/adult.json\")\n",
    "cols = df.columns.tolist()\n",
    "for col in cat_cols:\n",
    "    cols.remove(col)\n",
    "cols.remove(outcome_col)\n",
    "# whatever is left is treated as continuous (cont_cols is the same list object as cols, not a copy)\n",
    "cont_cols = cols\n",
    "param['data_dir'] = file_name\n",
    "param['continous_cols'] = cont_cols\n",
    "param['discret_cols'] = cat_cols\n",
    "# merge into the dataset's config file (created if missing) and display the merged dict\n",
    "update_json_file(param, \"../counterfactual/configs/extra/breast_cancer.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the cell that saved the CSV already called simple_transform with an empty cat_cols and\n",
    "# outcome_col='Diagnosis'; this second call looks redundant - confirm before removing\n",
    "df = simple_transform(df, [], 'Diagnosis')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['842302', 'M', '17.99', '10.38', '122.8', '1001', '0.1184',\n",
       "       '0.2776', '0.3001', '0.1471', '0.2419', '0.07871', '1.095',\n",
       "       '0.9053', '8.589', '153.4', '0.006399', '0.04904', '0.05373',\n",
       "       '0.01587', '0.03003', '0.006193', '25.38', '17.33', '184.6',\n",
       "       '2019', '0.1622', '0.6656', '0.7119', '0.2654', '0.4601', '0.1189'],\n",
       "      dtype='<U9')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# peek at the first raw record; every field is a string\n",
    "data[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Describe Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (dataset name, loaded config dict) pairs that are passed to describe()\n",
    "configs = [\n",
    "    (\"adult\", load_json(\"../counterfactual/configs/adult.json\")),\n",
    "    (\"home\", load_json(\"../counterfactual/configs/home.json\")),\n",
    "    (\"student\", load_json(\"../counterfactual/configs/student.json\")),\n",
    "    (\"breast_cancer\", load_json(\"../counterfactual/configs/extra/breast_cancer.json\")),\n",
    "    (\"credit_card\", load_json(\"../counterfactual/configs/extra/credit_card.json\")),\n",
    "    (\"german_credit\", load_json(\"../counterfactual/configs/extra/german_credit.json\")),\n",
    "    (\"heart\", load_json(\"../counterfactual/configs/extra/heart.json\")),\n",
    "    (\"student_performance\", load_json(\"../counterfactual/configs/extra/student_performance.json\")),\n",
    "    (\"titanic\", load_json(\"../counterfactual/configs/extra/titanic.json\")),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def describe(configs: list, output_path: str = \"../results/data_describe.csv\"):\n",
    "    \"\"\"Summarise the row count and the continuous/categorical column counts per dataset.\n",
    "\n",
    "    Args:\n",
    "        configs: list of (data_name, config) pairs. Each config must provide\n",
    "            'data_dir' (CSV path, read as '../<data_dir>'), 'discret_cols' and\n",
    "            'continous_cols' (lists of column names).\n",
    "        output_path: where the summary table is written as CSV. Pass None to skip\n",
    "            the write. The default is the path this function always wrote to.\n",
    "\n",
    "    Returns:\n",
    "        dict with the keys 'size', '# of Cont' and '# of Cat', each mapping\n",
    "        data_name to the corresponding count.\n",
    "    \"\"\"\n",
    "    r = {\"size\": {}, \"# of Cont\": {}, \"# of Cat\": {}}\n",
    "    for data_name, config in configs:\n",
    "        # the CSV is loaded only to count its rows\n",
    "        r['size'][data_name] = len(pd.read_csv(f\"../{config['data_dir']}\"))\n",
    "        r['# of Cont'][data_name] = len(config['continous_cols'])\n",
    "        r['# of Cat'][data_name] = len(config['discret_cols'])\n",
    "\n",
    "    # writing is optional so the summary can be computed without touching the results folder\n",
    "    if output_path is not None:\n",
    "        pd.DataFrame.from_dict(r).to_csv(output_path)\n",
    "    return r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'size': {'adult': 32561,\n",
       "  'home': 10459,\n",
       "  'student': 32593,\n",
       "  'breast_cancer': 569,\n",
       "  'credit_card': 30000,\n",
       "  'german_credit': 1000,\n",
       "  'heart': 299,\n",
       "  'student_performance': 649,\n",
       "  'titanic': 891},\n",
       " '# of Cont': {'adult': 2,\n",
       "  'home': 21,\n",
       "  'student': 23,\n",
       "  'breast_cancer': 30,\n",
       "  'credit_card': 20,\n",
       "  'german_credit': 7,\n",
       "  'heart': 7,\n",
       "  'student_performance': 2,\n",
       "  'titanic': 2},\n",
       " '# of Cat': {'adult': 6,\n",
       "  'home': 2,\n",
       "  'student': 8,\n",
       "  'breast_cancer': 0,\n",
       "  'credit_card': 3,\n",
       "  'german_credit': 13,\n",
       "  'heart': 5,\n",
       "  'student_performance': 14,\n",
       "  'titanic': 24}}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# build and display the summary; the call also writes ../results/data_describe.csv\n",
    "describe(configs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pytorch Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "\n",
    "class NumpyDataset(TensorDataset):\n",
    "    \"\"\"TensorDataset whose tensors are built from arrays converted to float tensors.\n",
    "\n",
    "    Arrays whose last dimension has length 0 are skipped. The last kept tensor is\n",
    "    what `target()` returns and the preceding ones are what `features()` returns.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *arrs, ):\n",
    "        super().__init__()\n",
    "        # small patch: skip continuous or discrete arrays without content\n",
    "        kept = []\n",
    "        for arr in arrs:\n",
    "            if arr.shape[-1] == 0:\n",
    "                continue\n",
    "            kept.append(torch.tensor(arr).float())\n",
    "        self.tensors = kept\n",
    "        # every kept tensor must agree on the size of its first dimension\n",
    "        assert all(t.size(0) == self.tensors[0].size(0) for t in self.tensors)\n",
    "\n",
    "    def data_loader(self, batch_size=128, shuffle=True, num_workers=4):\n",
    "        \"\"\"Wrap this dataset in a DataLoader configured with the given settings.\"\"\"\n",
    "        loader_kwargs = dict(batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)\n",
    "        return DataLoader(self, **loader_kwargs)\n",
    "\n",
    "    def features(self, test=False):\n",
    "        \"\"\"Return all tensors but the last as a tuple, or every tensor when `test` is truthy.\"\"\"\n",
    "        if test:\n",
    "            return tuple(self.tensors)\n",
    "        return tuple(self.tensors[:-1])\n",
    "\n",
    "    def target(self, test=False):\n",
    "        \"\"\"Return the last tensor, or None when `test` is truthy.\"\"\"\n",
    "        if test:\n",
    "            return None\n",
    "        return self.tensors[-1]\n",
    "\n",
    "class PandasDataset(NumpyDataset):\n",
    "    \"\"\"NumpyDataset built from a DataFrame: the last column is the target, the rest are features.\"\"\"\n",
    "\n",
    "    def __init__(self, df: pd.DataFrame):\n",
    "        # select by position rather than by label: with a duplicated column name,\n",
    "        # label-based selection of the last column returns a DataFrame instead of a Series\n",
    "        X = df.iloc[:, :-1].to_numpy()\n",
    "        y = df.iloc[:, -1].to_numpy()\n",
    "        super().__init__(X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity check: a NumpyDataset built from (x, y) should match a PandasDataset built from the same two columns\n",
    "# the inputs are unseeded random draws; the assertions only compare the two datasets with each other\n",
    "x = np.random.normal(50, 15, 100)\n",
    "y = np.random.normal(50, 15, 100)\n",
    "df_test = pd.DataFrame({'x': x, 'y': y})\n",
    "arrs = np.column_stack((x, y))\n",
    "np_dataset = NumpyDataset(x, y)\n",
    "pd_dataset = PandasDataset(df_test)\n",
    "\n",
    "# the stacked arrays and the frame hold the same values, and both datasets have the same length\n",
    "assert (arrs == df_test.to_numpy()).all()\n",
    "assert len(np_dataset) == len(pd_dataset)\n",
    "\n",
    "# compare the two datasets item by item\n",
    "for i in range(len(np_dataset)):\n",
    "    assert np_dataset[i] == pd_dataset[i]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# same comparison as the first assertion in the previous cell, displayed as a value\n",
    "(np.column_stack((x, y)) == df_test.to_numpy()).all()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
