{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:31.944221Z",
     "start_time": "2023-05-14T19:59:31.869652Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.insert(0, \"../utils\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.291538Z",
     "start_time": "2023-05-14T19:59:31.876778Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "import sklearn.datasets as skds\n",
    "from sklearn.preprocessing import QuantileTransformer, KBinsDiscretizer, OrdinalEncoder, LabelEncoder\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from transformation import BSplineTransformer, spline_transform_dataset\n",
    "from trainers import FFMTrainer\n",
    "import math\n",
    "import optuna\n",
    "import optuna.samplers\n",
    "from typing import Callable\n",
    "from sklearn.model_selection import train_test_split\n",
    "import torch\n",
    "from torch.utils.data import TensorDataset\n",
    "from tqdm import trange"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.291828Z",
     "start_time": "2023-05-14T19:59:35.238271Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "# Use the first CUDA device when one is available; otherwise run on the CPU.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.291910Z",
     "start_time": "2023-05-14T19:59:35.244997Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# One named seed for every global RNG this notebook relies on (PyTorch and NumPy).\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "# Keep this call last: it returns None, so the cell produces no output.\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.445506Z",
     "start_time": "2023-05-14T19:59:35.256872Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# Column names are supplied explicitly. \"Martial Status\" (sic) is kept unchanged because the\n",
    "# column lists defined later in this notebook refer to that exact name.\n",
    "raw_df = pd.read_csv(\n",
    "    \"../data/adult-all.txt\",\n",
    "    names=[\"Age\", \"Workclass\", \"fnlwgt\", \"Education\", \"Education-Num\", \"Martial Status\",\n",
    "           \"Occupation\", \"Relationship\", \"Race\", \"Sex\", \"Capital Gain\", \"Capital Loss\",\n",
    "           \"Hours per week\", \"Country\", \"Target\"],\n",
    "    # dtypes are keyed by column name rather than by position, so every type stays attached to\n",
    "    # its column even if the list of names above is edited or reordered.\n",
    "    dtype={\"Age\": int, \"Workclass\": str, \"fnlwgt\": int, \"Education\": str, \"Education-Num\": int,\n",
    "           \"Martial Status\": str, \"Occupation\": str, \"Relationship\": str, \"Race\": str, \"Sex\": str,\n",
    "           \"Capital Gain\": int, \"Capital Loss\": int, \"Hours per week\": int, \"Country\": str,\n",
    "           \"Target\": str},\n",
    "    na_values=\"?\",  # a field consisting of \"?\" is read as NaN\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.465430Z",
     "start_time": "2023-05-14T19:59:35.384297Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Age</th>\n",
       "      <th>Workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>Education</th>\n",
       "      <th>Education-Num</th>\n",
       "      <th>Martial Status</th>\n",
       "      <th>Occupation</th>\n",
       "      <th>Relationship</th>\n",
       "      <th>Race</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Capital Gain</th>\n",
       "      <th>Capital Loss</th>\n",
       "      <th>Hours per week</th>\n",
       "      <th>Country</th>\n",
       "      <th>Target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>7762</th>\n",
       "      <td>18</td>\n",
       "      <td>Private</td>\n",
       "      <td>423024</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Other-service</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>20</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23881</th>\n",
       "      <td>17</td>\n",
       "      <td>Private</td>\n",
       "      <td>178953</td>\n",
       "      <td>12th</td>\n",
       "      <td>8</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Sales</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>20</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30507</th>\n",
       "      <td>25</td>\n",
       "      <td>Local-gov</td>\n",
       "      <td>348986</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Other-relative</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28911</th>\n",
       "      <td>20</td>\n",
       "      <td>Private</td>\n",
       "      <td>218215</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Sales</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>30</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19484</th>\n",
       "      <td>47</td>\n",
       "      <td>Private</td>\n",
       "      <td>244025</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Unmarried</td>\n",
       "      <td>Amer-Indian-Eskimo</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>56</td>\n",
       "      <td>Puerto-Rico</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>43031</th>\n",
       "      <td>33</td>\n",
       "      <td>Private</td>\n",
       "      <td>399531</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Craft-repair</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28188</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>200220</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Craft-repair</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12761</th>\n",
       "      <td>21</td>\n",
       "      <td>Private</td>\n",
       "      <td>329530</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Craft-repair</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>Mexico</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40834</th>\n",
       "      <td>43</td>\n",
       "      <td>Private</td>\n",
       "      <td>282155</td>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>12</td>\n",
       "      <td>Divorced</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>4650</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27875</th>\n",
       "      <td>55</td>\n",
       "      <td>Private</td>\n",
       "      <td>202220</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Other-service</td>\n",
       "      <td>Wife</td>\n",
       "      <td>Black</td>\n",
       "      <td>Female</td>\n",
       "      <td>2407</td>\n",
       "      <td>0</td>\n",
       "      <td>35</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1276</th>\n",
       "      <td>46</td>\n",
       "      <td>Private</td>\n",
       "      <td>129007</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Sales</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>1977</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&gt;50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22608</th>\n",
       "      <td>34</td>\n",
       "      <td>Private</td>\n",
       "      <td>261799</td>\n",
       "      <td>Assoc-voc</td>\n",
       "      <td>11</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>45</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&gt;50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36230</th>\n",
       "      <td>40</td>\n",
       "      <td>NaN</td>\n",
       "      <td>246862</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Widowed</td>\n",
       "      <td>NaN</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>8</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13398</th>\n",
       "      <td>50</td>\n",
       "      <td>Private</td>\n",
       "      <td>173754</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Craft-repair</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>43536</th>\n",
       "      <td>29</td>\n",
       "      <td>Private</td>\n",
       "      <td>176727</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Craft-repair</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>38</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18627</th>\n",
       "      <td>62</td>\n",
       "      <td>Private</td>\n",
       "      <td>24515</td>\n",
       "      <td>9th</td>\n",
       "      <td>5</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38424</th>\n",
       "      <td>56</td>\n",
       "      <td>Private</td>\n",
       "      <td>158776</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Sales</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>55</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35505</th>\n",
       "      <td>40</td>\n",
       "      <td>Federal-gov</td>\n",
       "      <td>90737</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>1887</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&gt;50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2372</th>\n",
       "      <td>74</td>\n",
       "      <td>NaN</td>\n",
       "      <td>340939</td>\n",
       "      <td>9th</td>\n",
       "      <td>5</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>NaN</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>3471</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3375</th>\n",
       "      <td>21</td>\n",
       "      <td>Private</td>\n",
       "      <td>305874</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Craft-repair</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>54</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       Age    Workclass  fnlwgt     Education  Education-Num  \\\n",
       "7762    18      Private  423024       HS-grad              9   \n",
       "23881   17      Private  178953          12th              8   \n",
       "30507   25    Local-gov  348986       HS-grad              9   \n",
       "28911   20      Private  218215  Some-college             10   \n",
       "19484   47      Private  244025       HS-grad              9   \n",
       "43031   33      Private  399531     Bachelors             13   \n",
       "28188   38      Private  200220       HS-grad              9   \n",
       "12761   21      Private  329530          11th              7   \n",
       "40834   43      Private  282155    Assoc-acdm             12   \n",
       "27875   55      Private  202220       HS-grad              9   \n",
       "1276    46      Private  129007     Bachelors             13   \n",
       "22608   34      Private  261799     Assoc-voc             11   \n",
       "36230   40          NaN  246862     Bachelors             13   \n",
       "13398   50      Private  173754       HS-grad              9   \n",
       "43536   29      Private  176727  Some-college             10   \n",
       "18627   62      Private   24515           9th              5   \n",
       "38424   56      Private  158776          11th              7   \n",
       "35505   40  Federal-gov   90737  Some-college             10   \n",
       "2372    74          NaN  340939           9th              5   \n",
       "3375    21      Private  305874  Some-college             10   \n",
       "\n",
       "           Martial Status         Occupation    Relationship  \\\n",
       "7762        Never-married      Other-service   Not-in-family   \n",
       "23881       Never-married              Sales       Own-child   \n",
       "30507       Never-married  Handlers-cleaners  Other-relative   \n",
       "28911       Never-married              Sales       Own-child   \n",
       "19484       Never-married  Machine-op-inspct       Unmarried   \n",
       "43031  Married-civ-spouse       Craft-repair         Husband   \n",
       "28188  Married-civ-spouse       Craft-repair         Husband   \n",
       "12761  Married-civ-spouse       Craft-repair         Husband   \n",
       "40834            Divorced     Prof-specialty   Not-in-family   \n",
       "27875  Married-civ-spouse      Other-service            Wife   \n",
       "1276   Married-civ-spouse              Sales         Husband   \n",
       "22608  Married-civ-spouse       Adm-clerical         Husband   \n",
       "36230             Widowed                NaN   Not-in-family   \n",
       "13398  Married-civ-spouse       Craft-repair         Husband   \n",
       "43536       Never-married       Craft-repair   Not-in-family   \n",
       "18627  Married-civ-spouse    Exec-managerial         Husband   \n",
       "38424  Married-civ-spouse              Sales         Husband   \n",
       "35505  Married-civ-spouse       Adm-clerical         Husband   \n",
       "2372   Married-civ-spouse                NaN         Husband   \n",
       "3375   Married-civ-spouse       Craft-repair         Husband   \n",
       "\n",
       "                     Race     Sex  Capital Gain  Capital Loss  Hours per week  \\\n",
       "7762                White    Male             0             0              20   \n",
       "23881               White  Female             0             0              20   \n",
       "30507               Black    Male             0             0              40   \n",
       "28911               White  Female             0             0              30   \n",
       "19484  Amer-Indian-Eskimo    Male             0             0              56   \n",
       "43031               Black    Male             0             0              40   \n",
       "28188               White    Male             0             0              40   \n",
       "12761               White    Male             0             0              40   \n",
       "40834               White  Female          4650             0              40   \n",
       "27875               Black  Female          2407             0              35   \n",
       "1276                White    Male             0          1977              40   \n",
       "22608               Black    Male             0             0              45   \n",
       "36230               White  Female             0             0               8   \n",
       "13398               White    Male             0             0              40   \n",
       "43536               White    Male             0             0              38   \n",
       "18627               White    Male             0             0              40   \n",
       "38424               White    Male             0             0              55   \n",
       "35505               White    Male             0          1887              40   \n",
       "2372                White    Male          3471             0              40   \n",
       "3375                White    Male             0             0              54   \n",
       "\n",
       "             Country Target  \n",
       "7762   United-States  <=50K  \n",
       "23881  United-States  <=50K  \n",
       "30507  United-States  <=50K  \n",
       "28911  United-States  <=50K  \n",
       "19484    Puerto-Rico  <=50K  \n",
       "43031  United-States  <=50K  \n",
       "28188  United-States  <=50K  \n",
       "12761         Mexico  <=50K  \n",
       "40834  United-States  <=50K  \n",
       "27875  United-States  <=50K  \n",
       "1276   United-States   >50K  \n",
       "22608  United-States   >50K  \n",
       "36230  United-States  <=50K  \n",
       "13398  United-States  <=50K  \n",
       "43536  United-States  <=50K  \n",
       "18627  United-States  <=50K  \n",
       "38424  United-States  <=50K  \n",
       "35505  United-States   >50K  \n",
       "2372   United-States  <=50K  \n",
       "3375   United-States  <=50K  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Eyeball 20 random rows; with no random_state given, pandas draws from NumPy's global RNG.\n",
    "raw_df.sample(n=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.465923Z",
     "start_time": "2023-05-14T19:59:35.408593Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['Age', 'Workclass', 'fnlwgt', 'Education', 'Education-Num',\n",
       "       'Martial Status', 'Occupation', 'Relationship', 'Race', 'Sex',\n",
       "       'Capital Gain', 'Capital Loss', 'Hours per week', 'Country', 'Target'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.466125Z",
     "start_time": "2023-05-14T19:59:35.413880Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# Feature columns grouped by how they are encoded further down; \"Target\" is the label column\n",
    "# and therefore appears in neither list.\n",
    "categorical_columns = [\n",
    "    'Workclass',\n",
    "    'Education',\n",
    "    'Martial Status',\n",
    "    'Occupation',\n",
    "    'Relationship',\n",
    "    'Race',\n",
    "    'Sex',\n",
    "    'Country',\n",
    "]\n",
    "numerical_columns = [\n",
    "    'Age',\n",
    "    'fnlwgt',\n",
    "    'Education-Num',\n",
    "    'Capital Gain',\n",
    "    'Capital Loss',\n",
    "    'Hours per week',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.738692Z",
     "start_time": "2023-05-14T19:59:35.421123Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# Missing categorical entries become an explicit per-column category named 'NA_<column>'.\n",
    "na_dict = {col: f'NA_{col}' for col in categorical_columns}\n",
    "cat_ordinal = raw_df.fillna(value=na_dict)\n",
    "\n",
    "# Replace the category strings with ordinal codes and the target labels with integer class ids.\n",
    "# Both encoders are fitted on the full data set, i.e. before the train/test split.\n",
    "encoded_categories = OrdinalEncoder().fit_transform(cat_ordinal[categorical_columns])\n",
    "cat_ordinal[categorical_columns] = encoded_categories\n",
    "cat_ordinal[\"Target\"] = LabelEncoder().fit_transform(cat_ordinal[\"Target\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.772301Z",
     "start_time": "2023-05-14T19:59:35.617641Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# Hold out 20% of the rows as the test partition; random_state pins the shuffle.\n",
    "train, test = train_test_split(\n",
    "    cat_ordinal,\n",
    "    test_size=0.2,\n",
    "    random_state=42,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.772529Z",
     "start_time": "2023-05-14T19:59:35.642783Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# Slice each partition into categorical features, numerical features and the label column.\n",
    "tr_cat, tr_num, tr_target = train[categorical_columns], train[numerical_columns], train[\"Target\"]\n",
    "te_cat, te_num, te_target = test[categorical_columns], test[numerical_columns], test[\"Target\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.773052Z",
     "start_time": "2023-05-14T19:59:35.647130Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# Quantile-transform every numerical column to a uniform distribution. The transformer is\n",
    "# fitted on the training partition only and then applied to the test partition.\n",
    "# For the two capital columns, non-positive entries are excluded from the transform and replaced\n",
    "# by the sentinel -1; the sentinel is recorded per column index in `special_values`.\n",
    "special_values = {}\n",
    "tr_num_qs, te_num_qs = [], []\n",
    "for col_idx, col in enumerate(tr_num.columns):\n",
    "    # astype() returns a fresh array, so the in-place writes below never touch the data frames.\n",
    "    tr_col = tr_num[col].to_numpy().astype(np.float32)\n",
    "    te_col = te_num[col].to_numpy().astype(np.float32)\n",
    "\n",
    "    if col in ('Capital Loss', 'Capital Gain'):\n",
    "        regular_tr_mask, regular_te_mask = tr_col > 0, te_col > 0\n",
    "        tr_col[~regular_tr_mask] = -1.\n",
    "        te_col[~regular_te_mask] = -1.\n",
    "        special_values[col_idx] = [-1.]\n",
    "    else:\n",
    "        # No special values in this column: every row is transformed.\n",
    "        regular_tr_mask = np.full(tr_col.shape, True)\n",
    "        regular_te_mask = np.full(te_col.shape, True)\n",
    "\n",
    "    # `subsample` is set to the number of rows that are actually fitted.\n",
    "    transformer = QuantileTransformer(subsample=np.sum(regular_tr_mask), output_distribution='uniform')\n",
    "    tr_col[regular_tr_mask] = transformer.fit_transform(tr_col[regular_tr_mask].reshape(-1, 1)).ravel()\n",
    "    te_col[regular_te_mask] = transformer.transform(te_col[regular_te_mask].reshape(-1, 1)).ravel()\n",
    "\n",
    "    tr_num_qs.append(tr_col)\n",
    "    te_num_qs.append(te_col)\n",
    "\n",
    "# One column per numerical feature, in the order of `tr_num.columns`.\n",
    "tr_num_qs = np.stack(tr_num_qs, axis=1)\n",
    "te_num_qs = np.stack(te_num_qs, axis=1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.773125Z",
     "start_time": "2023-05-14T19:59:35.712144Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# Every categorical field gets its own contiguous range of embedding ids: a field's offset is\n",
    "# the total number of distinct values in all fields listed before it.\n",
    "num_cat_fields = tr_cat.shape[1]\n",
    "cat_offsets = np.cumsum([0] + [cat_ordinal[col].nunique() for col in categorical_columns])\n",
    "cat_offsets, num_cat_embeddings = cat_offsets[:-1], cat_offsets[-1]\n",
    "\n",
    "# Training partition. Broadcasting adds the per-field offset to every row.\n",
    "tr_cat_indices = tr_cat.values + cat_offsets\n",
    "tr_cat_weights = np.ones_like(tr_cat_indices, dtype=np.float32)\n",
    "tr_cat_offsets = np.tile(np.arange(num_cat_fields, dtype=np.int32), (len(tr_cat), 1))\n",
    "tr_cat_fields = tr_cat_offsets  # same array object: offsets and field ids coincide here\n",
    "\n",
    "# Test partition, built the same way.\n",
    "te_cat_indices = te_cat.values + cat_offsets\n",
    "te_cat_weights = np.ones_like(te_cat_indices, dtype=np.float32)\n",
    "te_cat_offsets = np.tile(np.arange(num_cat_fields, dtype=np.int32), (len(te_cat), 1))\n",
    "te_cat_fields = te_cat_offsets\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.773208Z",
     "start_time": "2023-05-14T19:59:35.735974Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "def train_spline_ffm(embedding_dim: int, step_size: float, batch_size: int, num_knots: int, num_epochs: int,\n",
    "                     callback: Callable[[int, float], None]=None):\n",
    "    \"\"\"Train an FFM (via ``FFMTrainer``) on categorical plus spline-encoded numerical features.\n",
    "\n",
    "    The quantile-transformed numerical columns (``tr_num_qs`` / ``te_num_qs``) are encoded with\n",
    "    ``spline_transform_dataset`` and appended after the categorical features. Their embedding\n",
    "    indices are shifted by ``num_cat_embeddings`` and their offsets / field ids by\n",
    "    ``num_cat_fields`` so that they follow the categorical ranges.\n",
    "\n",
    "    Args:\n",
    "        embedding_dim, step_size, batch_size, num_epochs: forwarded unchanged to ``FFMTrainer``.\n",
    "        num_knots: first argument of ``BSplineTransformer`` (the second is fixed to 3).\n",
    "        callback: optional ``(epoch, loss)`` hook, forwarded unchanged to ``FFMTrainer``.\n",
    "\n",
    "    Returns:\n",
    "        Whatever ``FFMTrainer.train`` returns for the train and test datasets built here.\n",
    "    \"\"\"\n",
    "    def assemble(cat_parts, num_parts, target):\n",
    "        # Categorical block first, then the numerical block shifted past the categorical ranges.\n",
    "        c_indices, c_weights, c_offsets, c_fields = cat_parts\n",
    "        n_indices, n_weights, n_offsets, n_fields = num_parts\n",
    "        indices = np.concatenate([c_indices, n_indices + num_cat_embeddings], axis=1)\n",
    "        weights = np.concatenate([c_weights, n_weights], axis=1)\n",
    "        offsets = np.concatenate([c_offsets, n_offsets + num_cat_fields], axis=1)\n",
    "        fields = np.concatenate([c_fields, n_fields + num_cat_fields], axis=1)\n",
    "        return TensorDataset(\n",
    "            torch.tensor(indices, dtype=torch.int64),\n",
    "            torch.tensor(weights, dtype=torch.float32),\n",
    "            torch.tensor(offsets, dtype=torch.int64),\n",
    "            torch.tensor(fields, dtype=torch.int64),\n",
    "            torch.tensor(target.values, dtype=torch.float32))\n",
    "\n",
    "    # Encode the training partition first, then the test partition, with the same transformer.\n",
    "    spline = BSplineTransformer(num_knots, 3)\n",
    "    tr_num_parts = tuple(spline_transform_dataset(tr_num_qs, spline, special_values=special_values))\n",
    "    te_num_parts = tuple(spline_transform_dataset(te_num_qs, spline, special_values=special_values))\n",
    "\n",
    "    # Numerical embedding count: largest spline index seen in either partition, plus one.\n",
    "    num_numerical_embeddings = int(max(np.max(tr_num_parts[0]), np.max(te_num_parts[0])) + 1)\n",
    "    num_fields = tr_num_qs.shape[1] + num_cat_fields\n",
    "    num_embeddings = num_numerical_embeddings + num_cat_embeddings\n",
    "\n",
    "    train_ds = assemble((tr_cat_indices, tr_cat_weights, tr_cat_offsets, tr_cat_fields), tr_num_parts, tr_target)\n",
    "    test_ds = assemble((te_cat_indices, te_cat_weights, te_cat_offsets, te_cat_fields), te_num_parts, te_target)\n",
    "\n",
    "    trainer = FFMTrainer(embedding_dim, step_size, batch_size, num_epochs, callback)\n",
    "    return trainer.train(num_fields, num_embeddings, train_ds, test_ds, torch.nn.BCEWithLogitsLoss(), device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.773274Z",
     "start_time": "2023-05-14T19:59:35.744532Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "def train_spline_objective(trial: optuna.Trial):\n",
    "    \"\"\"Optuna objective: sample hyper-parameters from ``trial`` and train the spline FFM with them.\n",
    "\n",
    "    The callback handed to ``train_spline_ffm`` reports each ``(epoch, loss)`` pair it receives to\n",
    "    the trial and raises ``optuna.TrialPruned`` as soon as the trial says it should be pruned.\n",
    "\n",
    "    Returns:\n",
    "        The value returned by ``train_spline_ffm``.\n",
    "    \"\"\"\n",
    "    def report_and_maybe_prune(epoch: int, loss: float):\n",
    "        trial.report(loss, epoch)\n",
    "        if trial.should_prune():\n",
    "            raise optuna.TrialPruned()\n",
    "\n",
    "    # A dict literal is evaluated top to bottom, so the suggest_* calls run in the order listed.\n",
    "    hyperparams = {\n",
    "        'embedding_dim': trial.suggest_int('embedding_dim', 1, 10),\n",
    "        'step_size': trial.suggest_float('step_size', 1e-2, 0.5, log=True),\n",
    "        'batch_size': trial.suggest_int('batch_size', 2, 32),\n",
    "        'num_knots': trial.suggest_int('num_knots', 3, 48),\n",
    "        'num_epochs': trial.suggest_int('num_epochs', 5, 15),\n",
    "    }\n",
    "    return train_spline_ffm(**hyperparams, callback=report_and_maybe_prune)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-05-14T19:59:35.750585Z"
    },
    "collapsed": true,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[I 2023-05-16 19:04:42,736]\u001b[0m A new study created in memory with name: splines\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:06:55,632]\u001b[0m Trial 0 finished with value: 0.3308318555355072 and parameters: {'embedding_dim': 4, 'step_size': 0.4123206532618726, 'batch_size': 24, 'num_knots': 30, 'num_epochs': 6}. Best is trial 0 with value: 0.3308318555355072.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:11:10,824]\u001b[0m Trial 1 finished with value: 0.2987254559993744 and parameters: {'embedding_dim': 2, 'step_size': 0.012551115172973842, 'batch_size': 28, 'num_knots': 30, 'num_epochs': 12}. Best is trial 1 with value: 0.2987254559993744.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:13:41,166]\u001b[0m Trial 2 finished with value: 0.308167040348053 and parameters: {'embedding_dim': 1, 'step_size': 0.44447541666908114, 'batch_size': 27, 'num_knots': 12, 'num_epochs': 7}. Best is trial 1 with value: 0.2987254559993744.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:18:26,030]\u001b[0m Trial 3 finished with value: 0.2933942675590515 and parameters: {'embedding_dim': 2, 'step_size': 0.0328774741399112, 'batch_size': 18, 'num_knots': 22, 'num_epochs': 8}. Best is trial 3 with value: 0.2933942675590515.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:26:54,210]\u001b[0m Trial 4 finished with value: 0.2895911633968353 and parameters: {'embedding_dim': 7, 'step_size': 0.017258215396625, 'batch_size': 11, 'num_knots': 19, 'num_epochs': 10}. Best is trial 4 with value: 0.2895911633968353.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:29:16,986]\u001b[0m Trial 5 finished with value: 0.28712841868400574 and parameters: {'embedding_dim': 8, 'step_size': 0.021839352923182977, 'batch_size': 17, 'num_knots': 30, 'num_epochs': 5}. Best is trial 5 with value: 0.28712841868400574.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:01:38,830]\u001b[0m Trial 6 finished with value: 0.28596022725105286 and parameters: {'embedding_dim': 7, 'step_size': 0.019485671251272575, 'batch_size': 4, 'num_knots': 46, 'num_epochs': 15}. Best is trial 6 with value: 0.28596022725105286.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:18:48,998]\u001b[0m Trial 7 finished with value: 0.29043394327163696 and parameters: {'embedding_dim': 9, 'step_size': 0.032925293631105246, 'batch_size': 5, 'num_knots': 34, 'num_epochs': 9}. Best is trial 6 with value: 0.28596022725105286.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:41:45,796]\u001b[0m Trial 8 finished with value: 0.28961437940597534 and parameters: {'embedding_dim': 2, 'step_size': 0.06938901412739397, 'batch_size': 3, 'num_knots': 44, 'num_epochs': 7}. Best is trial 6 with value: 0.28596022725105286.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:42:23,897]\u001b[0m Trial 9 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:43:21,803]\u001b[0m Trial 10 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:54:52,689]\u001b[0m Trial 11 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:55:24,798]\u001b[0m Trial 12 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:00:02,801]\u001b[0m Trial 13 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:01:24,320]\u001b[0m Trial 14 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:08:49,223]\u001b[0m Trial 15 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:42:43,899]\u001b[0m Trial 16 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:46:03,522]\u001b[0m Trial 17 finished with value: 0.28969407081604004 and parameters: {'embedding_dim': 6, 'step_size': 0.021721521376611954, 'batch_size': 14, 'num_knots': 24, 'num_epochs': 5}. Best is trial 6 with value: 0.28596022725105286.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:50:04,614]\u001b[0m Trial 18 finished with value: 0.28562605381011963 and parameters: {'embedding_dim': 9, 'step_size': 0.01404915501132402, 'batch_size': 21, 'num_knots': 34, 'num_epochs': 9}. Best is trial 18 with value: 0.28562605381011963.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:50:30,101]\u001b[0m Trial 19 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:51:59,745]\u001b[0m Trial 20 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:56:23,019]\u001b[0m Trial 21 finished with value: 0.28827688097953796 and parameters: {'embedding_dim': 8, 'step_size': 0.022702957394650548, 'batch_size': 21, 'num_knots': 34, 'num_epochs': 10}. Best is trial 18 with value: 0.28562605381011963.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:56:45,753]\u001b[0m Trial 22 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:57:18,215]\u001b[0m Trial 23 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:57:30,485]\u001b[0m Trial 24 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:01:21,461]\u001b[0m Trial 25 finished with value: 0.28463757038116455 and parameters: {'embedding_dim': 10, 'step_size': 0.015507328009848734, 'batch_size': 19, 'num_knots': 42, 'num_epochs': 8}. Best is trial 25 with value: 0.28463757038116455.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:01:50,284]\u001b[0m Trial 26 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:08:48,600]\u001b[0m Trial 27 finished with value: 0.28615275025367737 and parameters: {'embedding_dim': 9, 'step_size': 0.014103190106531065, 'batch_size': 13, 'num_knots': 48, 'num_epochs': 10}. Best is trial 25 with value: 0.28463757038116455.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:12:45,217]\u001b[0m Trial 28 finished with value: 0.2895828187465668 and parameters: {'embedding_dim': 10, 'step_size': 0.02582821902928931, 'batch_size': 24, 'num_knots': 37, 'num_epochs': 11}. Best is trial 25 with value: 0.28463757038116455.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:15:06,625]\u001b[0m Trial 29 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:17:01,477]\u001b[0m Trial 30 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:17:42,523]\u001b[0m Trial 31 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:25:47,914]\u001b[0m Trial 32 finished with value: 0.28570377826690674 and parameters: {'embedding_dim': 10, 'step_size': 0.016148405038761453, 'batch_size': 10, 'num_knots': 41, 'num_epochs': 9}. Best is trial 25 with value: 0.28463757038116455.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:34:32,024]\u001b[0m Trial 33 finished with value: 0.28472790122032166 and parameters: {'embedding_dim': 10, 'step_size': 0.017898973177799225, 'batch_size': 10, 'num_knots': 41, 'num_epochs': 9}. Best is trial 25 with value: 0.28463757038116455.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:35:27,217]\u001b[0m Trial 34 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:45:38,667]\u001b[0m Trial 35 finished with value: 0.2841481864452362 and parameters: {'embedding_dim': 10, 'step_size': 0.01628886658284197, 'batch_size': 7, 'num_knots': 41, 'num_epochs': 8}. Best is trial 35 with value: 0.2841481864452362.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:49:26,856]\u001b[0m Trial 36 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:49:50,087]\u001b[0m Trial 37 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:50:09,745]\u001b[0m Trial 38 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:54:31,498]\u001b[0m Trial 39 finished with value: 0.2877384126186371 and parameters: {'embedding_dim': 10, 'step_size': 0.029506218527871258, 'batch_size': 17, 'num_knots': 36, 'num_epochs': 8}. Best is trial 35 with value: 0.2841481864452362.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:56:01,040]\u001b[0m Trial 40 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:56:55,935]\u001b[0m Trial 41 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:07:52,600]\u001b[0m Trial 42 finished with value: 0.284860223531723 and parameters: {'embedding_dim': 10, 'step_size': 0.015781427803654453, 'batch_size': 8, 'num_knots': 44, 'num_epochs': 10}. Best is trial 35 with value: 0.2841481864452362.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:10:03,298]\u001b[0m Trial 43 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:12:15,603]\u001b[0m Trial 44 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:12:44,465]\u001b[0m Trial 45 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:13:28,673]\u001b[0m Trial 46 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:35:02,584]\u001b[0m Trial 47 finished with value: 0.2866511940956116 and parameters: {'embedding_dim': 9, 'step_size': 0.019474627680383216, 'batch_size': 4, 'num_knots': 46, 'num_epochs': 10}. Best is trial 35 with value: 0.2841481864452362.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:35:35,885]\u001b[0m Trial 48 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:48:37,070]\u001b[0m Trial 49 finished with value: 0.285453736782074 and parameters: {'embedding_dim': 10, 'step_size': 0.016255622278133826, 'batch_size': 8, 'num_knots': 38, 'num_epochs': 12}. Best is trial 35 with value: 0.2841481864452362.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:51:53,340]\u001b[0m Trial 50 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:53:20,770]\u001b[0m Trial 51 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:45:03,902]\u001b[0m Trial 52 finished with value: 0.28706908226013184 and parameters: {'embedding_dim': 10, 'step_size': 0.018306112716872682, 'batch_size': 2, 'num_knots': 38, 'num_epochs': 12}. Best is trial 35 with value: 0.2841481864452362.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:45:53,635]\u001b[0m Trial 53 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:58:42,623]\u001b[0m Trial 54 finished with value: 0.28566741943359375 and parameters: {'embedding_dim': 10, 'step_size': 0.02147460300997374, 'batch_size': 7, 'num_knots': 38, 'num_epochs': 10}. Best is trial 35 with value: 0.2841481864452362.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:59:38,518]\u001b[0m Trial 55 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:01:27,214]\u001b[0m Trial 56 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:05:33,988]\u001b[0m Trial 57 finished with value: 0.28643038868904114 and parameters: {'embedding_dim': 9, 'step_size': 0.02361589004001423, 'batch_size': 18, 'num_knots': 46, 'num_epochs': 8}. Best is trial 35 with value: 0.2841481864452362.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:06:10,450]\u001b[0m Trial 58 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:07:03,684]\u001b[0m Trial 59 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:07:28,041]\u001b[0m Trial 60 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:09:59,660]\u001b[0m Trial 61 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:11:06,941]\u001b[0m Trial 62 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:12:28,059]\u001b[0m Trial 63 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:14:42,563]\u001b[0m Trial 64 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:15:20,984]\u001b[0m Trial 65 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:15:40,772]\u001b[0m Trial 66 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:16:32,698]\u001b[0m Trial 67 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:17:15,265]\u001b[0m Trial 68 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:17:47,724]\u001b[0m Trial 69 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:18:59,351]\u001b[0m Trial 70 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:20:21,071]\u001b[0m Trial 71 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:21:05,797]\u001b[0m Trial 72 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:21:44,119]\u001b[0m Trial 73 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:22:41,889]\u001b[0m Trial 74 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:23:15,626]\u001b[0m Trial 75 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:23:56,832]\u001b[0m Trial 76 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:25:02,969]\u001b[0m Trial 77 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:25:39,164]\u001b[0m Trial 78 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:26:04,023]\u001b[0m Trial 79 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:26:22,837]\u001b[0m Trial 80 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:28:40,401]\u001b[0m Trial 81 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:30:21,125]\u001b[0m Trial 82 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:31:13,173]\u001b[0m Trial 83 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:33:48,348]\u001b[0m Trial 84 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:48:24,829]\u001b[0m Trial 85 finished with value: 0.2861608564853668 and parameters: {'embedding_dim': 10, 'step_size': 0.017533967555932624, 'batch_size': 6, 'num_knots': 44, 'num_epochs': 13}. Best is trial 35 with value: 0.2841481864452362.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:48:46,518]\u001b[0m Trial 86 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:49:08,314]\u001b[0m Trial 87 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:55:55,865]\u001b[0m Trial 88 finished with value: 0.2855015993118286 and parameters: {'embedding_dim': 10, 'step_size': 0.016438184664674376, 'batch_size': 14, 'num_knots': 39, 'num_epochs': 14}. Best is trial 35 with value: 0.2841481864452362.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:56:24,886]\u001b[0m Trial 89 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:56:51,441]\u001b[0m Trial 90 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:58:58,730]\u001b[0m Trial 91 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:59:41,106]\u001b[0m Trial 92 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 02:00:32,523]\u001b[0m Trial 93 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 02:01:10,092]\u001b[0m Trial 94 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 02:01:40,310]\u001b[0m Trial 95 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 02:19:21,485]\u001b[0m Trial 96 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 02:19:51,793]\u001b[0m Trial 97 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 02:21:00,049]\u001b[0m Trial 98 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 02:22:32,640]\u001b[0m Trial 99 pruned. \u001b[0m\n"
     ]
    }
   ],
   "source": [
    "study = optuna.create_study(study_name='splines',\n",
    "                            direction='minimize',\n",
    "                            sampler=optuna.samplers.TPESampler(seed=42))\n",
    "study.optimize(train_spline_objective, n_trials=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": false,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test loss: 0.2841481864452362\n",
      "Best hyperparameters: {'embedding_dim': 10, 'step_size': 0.01628886658284197, 'batch_size': 7, 'num_knots': 41, 'num_epochs': 8}\n"
     ]
    }
   ],
   "source": [
    "trial = study.best_trial\n",
    "\n",
    "print('Test loss: {}'.format(trial.value))\n",
    "print(\"Best hyperparameters: {}\".format(trial.params))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'embedding_dim': 10,\n",
       " 'step_size': 0.01628886658284197,\n",
       " 'batch_size': 7,\n",
       " 'num_knots': 41,\n",
       " 'num_epochs': 8}"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "study.best_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.28711041808128357"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_spline_ffm(**study.best_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [2:32:39<00:00, 457.99s/it]  \n"
     ]
    }
   ],
   "source": [
    "spline_losses = []\n",
    "for i in trange(20):\n",
    "    loss = train_spline_ffm(**study.best_params)\n",
    "    spline_losses.append(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.28441962599754333,\n",
       " 0.28561073541641235,\n",
       " 0.28526991605758667,\n",
       " 0.2852311432361603,\n",
       " 0.2871018052101135,\n",
       " 0.28534361720085144,\n",
       " 0.2867548167705536,\n",
       " 0.2859247922897339,\n",
       " 0.28690609335899353,\n",
       " 0.28653082251548767,\n",
       " 0.28614985942840576,\n",
       " 0.2878064811229706,\n",
       " 0.286538690328598,\n",
       " 0.2865554690361023,\n",
       " 0.28602224588394165,\n",
       " 0.2861751914024353,\n",
       " 0.2853567898273468,\n",
       " 0.28668129444122314,\n",
       " 0.2849292755126953,\n",
       " 0.2857675850391388]"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "spline_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.2860538125038147, 0.0008072314680490749)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(spline_losses), np.std(spline_losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "def train_bin_ffm(embedding_dim: int, step_size: float, batch_size: int,\n",
    "                  num_bins: int, bin_strategy: str, num_epochs: int,\n",
    "                  callback: Callable[[int, float], None]=None):\n",
    "    num_numerical_fields = tr_num_qs.shape[1]\n",
    "    num_numerical_embeddings = num_numerical_fields * num_bins\n",
    "    numerical_offsets = np.arange(0, num_numerical_fields) * num_bins\n",
    "\n",
    "    discretizer = KBinsDiscretizer(num_bins, encode='ordinal', strategy=bin_strategy, random_state=42)\n",
    "    discretizer.fit(tr_num)\n",
    "\n",
    "    tr_num_indices = discretizer.transform(tr_num)\n",
    "    tr_num_indices += np.tile(numerical_offsets, (tr_num.shape[0], 1))\n",
    "    tr_num_weights = np.ones_like(tr_num_indices)\n",
    "    tr_num_fields = np.tile(np.arange(0, num_numerical_fields), (tr_num.shape[0], 1))\n",
    "    tr_num_offsets = tr_num_fields.copy()\n",
    "\n",
    "    te_num_indices = discretizer.transform(te_num)\n",
    "    te_num_indices += np.tile(numerical_offsets, (te_num.shape[0], 1))\n",
    "    te_num_weights = np.ones_like(te_num_indices)\n",
    "    te_num_fields = np.tile(np.arange(0, num_numerical_fields), (te_num.shape[0], 1))\n",
    "    te_num_offsets = te_num_fields.copy()\n",
    "\n",
    "\n",
    "    num_fields = num_numerical_fields + num_cat_fields\n",
    "    num_embeddings = num_numerical_embeddings + num_cat_embeddings\n",
    "\n",
    "    tr_indices = np.concatenate([tr_cat_indices, tr_num_indices + num_cat_embeddings], axis=1)\n",
    "    tr_weights = np.concatenate([tr_cat_weights, tr_num_weights], axis=1)\n",
    "    tr_offsets = np.concatenate([tr_cat_offsets, tr_num_offsets + num_cat_fields], axis=1)\n",
    "    tr_fields = np.concatenate([tr_cat_fields, tr_num_fields + num_cat_fields], axis=1)\n",
    "\n",
    "    te_indices = np.concatenate([te_cat_indices, te_num_indices + num_cat_embeddings], axis=1)\n",
    "    te_weights = np.concatenate([te_cat_weights, te_num_weights], axis=1)\n",
    "    te_offsets = np.concatenate([te_cat_offsets, te_num_offsets + num_cat_fields], axis=1)\n",
    "    te_fields = np.concatenate([te_cat_fields, te_num_fields + num_cat_fields], axis=1)\n",
    "\n",
    "    train_ds = TensorDataset(\n",
    "        torch.tensor(tr_indices, dtype=torch.int64),\n",
    "        torch.tensor(tr_weights, dtype=torch.float32),\n",
    "        torch.tensor(tr_offsets, dtype=torch.int64),\n",
    "        torch.tensor(tr_fields, dtype=torch.int64),\n",
    "        torch.tensor(tr_target.values, dtype=torch.float32))\n",
    "\n",
    "    test_ds = TensorDataset(\n",
    "        torch.tensor(te_indices, dtype=torch.int64),\n",
    "        torch.tensor(te_weights, dtype=torch.float32),\n",
    "        torch.tensor(te_offsets, dtype=torch.int64),\n",
    "        torch.tensor(te_fields, dtype=torch.int64),\n",
    "        torch.tensor(te_target.values, dtype=torch.float32))\n",
    "\n",
    "    trainer = FFMTrainer(embedding_dim, step_size, batch_size, num_epochs, callback)\n",
    "    return trainer.train(num_fields, num_embeddings, train_ds, test_ds, torch.nn.BCEWithLogitsLoss(), device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": false,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "def test_bins_objective(trial: optuna.Trial):\n",
    "    embedding_dim = trial.suggest_int('embedding_dim', 1, 10)\n",
    "    step_size = trial.suggest_float('step_size', 1e-2, 0.5, log=True)\n",
    "    batch_size = trial.suggest_int('batch_size', 2, 32)\n",
    "    num_bins = trial.suggest_int('num_bins', 2, 100)\n",
    "    bin_strategy = trial.suggest_categorical('bin_strategy', ['uniform', 'quantile'])\n",
    "    num_epochs = trial.suggest_int('num_epochs', 5, 15)\n",
    "\n",
    "    def callback(epoch: int, loss: float):\n",
    "        trial.report(loss, epoch)\n",
    "        if trial.should_prune():\n",
    "            raise optuna.TrialPruned()\n",
    "\n",
    "    return train_bin_ffm(embedding_dim, step_size, batch_size, num_bins, bin_strategy, num_epochs,\n",
    "                         callback=callback)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": false,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[I 2023-05-17 02:30:20,329]\u001b[0m A new study created in memory with name: bins\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 02:31:48,083]\u001b[0m Trial 0 finished with value: 0.36048880219459534 and parameters: {'embedding_dim': 4, 'step_size': 0.4123206532618726, 'batch_size': 24, 'num_bins': 61, 'bin_strategy': 'uniform', 'num_epochs': 5}. Best is trial 0 with value: 0.36048880219459534.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 02:33:58,528]\u001b[0m Trial 1 finished with value: 0.35097000002861023 and parameters: {'embedding_dim': 9, 'step_size': 0.10502105436744279, 'batch_size': 23, 'num_bins': 4, 'bin_strategy': 'uniform', 'num_epochs': 7}. Best is trial 1 with value: 0.35097000002861023.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 02:40:40,921]\u001b[0m Trial 2 finished with value: 0.3105776607990265 and parameters: {'embedding_dim': 2, 'step_size': 0.020492680115417352, 'batch_size': 11, 'num_bins': 53, 'bin_strategy': 'uniform', 'num_epochs': 11}. Best is trial 2 with value: 0.3105776607990265.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 02:45:58,507]\u001b[0m Trial 3 finished with value: 0.30633899569511414 and parameters: {'embedding_dim': 2, 'step_size': 0.03135775732257745, 'batch_size': 13, 'num_bins': 47, 'bin_strategy': 'uniform', 'num_epochs': 10}. Best is trial 3 with value: 0.30633899569511414.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 2 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 3 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 4 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 5 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 02:51:06,519]\u001b[0m Trial 4 finished with value: 0.32146045565605164 and parameters: {'embedding_dim': 6, 'step_size': 0.011992724522955167, 'batch_size': 20, 'num_bins': 18, 'bin_strategy': 'quantile', 'num_epochs': 15}. Best is trial 3 with value: 0.30633899569511414.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 03:04:29,678]\u001b[0m Trial 5 finished with value: 0.3089120388031006 and parameters: {'embedding_dim': 9, 'step_size': 0.032925293631105246, 'batch_size': 5, 'num_bins': 69, 'bin_strategy': 'uniform', 'num_epochs': 10}. Best is trial 3 with value: 0.30633899569511414.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 0 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 2 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 3 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 4 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 5 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 03:08:36,422]\u001b[0m Trial 6 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 03:09:57,126]\u001b[0m Trial 7 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 03:12:11,602]\u001b[0m Trial 8 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 03:18:03,071]\u001b[0m Trial 9 finished with value: 0.3057221472263336 and parameters: {'embedding_dim': 4, 'step_size': 0.030012301808980443, 'batch_size': 18, 'num_bins': 15, 'bin_strategy': 'uniform', 'num_epochs': 15}. Best is trial 9 with value: 0.3057221472263336.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 3 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 4 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 5 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 03:18:16,684]\u001b[0m Trial 10 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 03:22:39,833]\u001b[0m Trial 11 finished with value: 0.30919402837753296 and parameters: {'embedding_dim': 4, 'step_size': 0.04642647300540488, 'batch_size': 14, 'num_bins': 30, 'bin_strategy': 'uniform', 'num_epochs': 9}. Best is trial 9 with value: 0.3057221472263336.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 03:26:09,676]\u001b[0m Trial 12 finished with value: 0.3141358494758606 and parameters: {'embedding_dim': 4, 'step_size': 0.08128484055802063, 'batch_size': 16, 'num_bins': 36, 'bin_strategy': 'uniform', 'num_epochs': 8}. Best is trial 9 with value: 0.3057221472263336.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 03:26:54,554]\u001b[0m Trial 13 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 2 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 3 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 4 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 5 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 03:28:24,068]\u001b[0m Trial 14 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 03:34:47,198]\u001b[0m Trial 15 finished with value: 0.3047816753387451 and parameters: {'embedding_dim': 3, 'step_size': 0.044922957549390394, 'batch_size': 14, 'num_bins': 46, 'bin_strategy': 'uniform', 'num_epochs': 13}. Best is trial 15 with value: 0.3047816753387451.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 03:36:42,706]\u001b[0m Trial 16 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 03:42:06,253]\u001b[0m Trial 17 finished with value: 0.32851821184158325 and parameters: {'embedding_dim': 7, 'step_size': 0.051868823187409284, 'batch_size': 17, 'num_bins': 80, 'bin_strategy': 'uniform', 'num_epochs': 13}. Best is trial 15 with value: 0.3047816753387451.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 0 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 2 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 3 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 4 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 5 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 03:43:04,642]\u001b[0m Trial 18 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 03:43:24,763]\u001b[0m Trial 19 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 03:48:35,377]\u001b[0m Trial 20 finished with value: 0.3056541383266449 and parameters: {'embedding_dim': 10, 'step_size': 0.020693482079285105, 'batch_size': 16, 'num_bins': 38, 'bin_strategy': 'uniform', 'num_epochs': 12}. Best is trial 15 with value: 0.3047816753387451.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 03:53:51,363]\u001b[0m Trial 21 finished with value: 0.3041217029094696 and parameters: {'embedding_dim': 10, 'step_size': 0.017182589531927382, 'batch_size': 16, 'num_bins': 40, 'bin_strategy': 'uniform', 'num_epochs': 12}. Best is trial 21 with value: 0.3041217029094696.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 03:59:26,777]\u001b[0m Trial 22 finished with value: 0.30423784255981445 and parameters: {'embedding_dim': 10, 'step_size': 0.015889555648010128, 'batch_size': 15, 'num_bins': 44, 'bin_strategy': 'uniform', 'num_epochs': 12}. Best is trial 21 with value: 0.3041217029094696.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 03:59:58,972]\u001b[0m Trial 23 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:00:28,524]\u001b[0m Trial 24 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:01:15,985]\u001b[0m Trial 25 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 0 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 2 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 3 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 4 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 5 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 04:01:46,624]\u001b[0m Trial 26 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:08:06,495]\u001b[0m Trial 27 finished with value: 0.30776745080947876 and parameters: {'embedding_dim': 10, 'step_size': 0.02266745511267857, 'batch_size': 15, 'num_bins': 74, 'bin_strategy': 'uniform', 'num_epochs': 14}. Best is trial 21 with value: 0.3041217029094696.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:11:50,365]\u001b[0m Trial 28 finished with value: 0.30118101835250854 and parameters: {'embedding_dim': 8, 'step_size': 0.013689434707013629, 'batch_size': 23, 'num_bins': 56, 'bin_strategy': 'uniform', 'num_epochs': 12}. Best is trial 28 with value: 0.30118101835250854.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:14:59,191]\u001b[0m Trial 29 finished with value: 0.29582858085632324 and parameters: {'embedding_dim': 8, 'step_size': 0.013772781946733741, 'batch_size': 25, 'num_bins': 58, 'bin_strategy': 'uniform', 'num_epochs': 11}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:15:16,142]\u001b[0m Trial 30 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:18:41,663]\u001b[0m Trial 31 finished with value: 0.3001174330711365 and parameters: {'embedding_dim': 9, 'step_size': 0.01600497838064572, 'batch_size': 23, 'num_bins': 58, 'bin_strategy': 'uniform', 'num_epochs': 11}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:18:59,106]\u001b[0m Trial 32 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:19:10,543]\u001b[0m Trial 33 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:19:27,642]\u001b[0m Trial 34 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:22:01,930]\u001b[0m Trial 35 finished with value: 0.30076929926872253 and parameters: {'embedding_dim': 9, 'step_size': 0.017679085410782114, 'batch_size': 28, 'num_bins': 50, 'bin_strategy': 'uniform', 'num_epochs': 10}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:24:15,945]\u001b[0m Trial 36 finished with value: 0.3045254051685333 and parameters: {'embedding_dim': 9, 'step_size': 0.024278676446518384, 'batch_size': 29, 'num_bins': 52, 'bin_strategy': 'uniform', 'num_epochs': 10}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 0 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 2 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 3 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 4 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 5 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 04:24:30,824]\u001b[0m Trial 37 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:27:24,617]\u001b[0m Trial 38 finished with value: 0.3024236261844635 and parameters: {'embedding_dim': 8, 'step_size': 0.018602665611405316, 'batch_size': 25, 'num_bins': 71, 'bin_strategy': 'uniform', 'num_epochs': 10}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:27:48,650]\u001b[0m Trial 39 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:28:26,064]\u001b[0m Trial 40 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:31:18,026]\u001b[0m Trial 41 finished with value: 0.2989811897277832 and parameters: {'embedding_dim': 8, 'step_size': 0.016781120570308682, 'batch_size': 25, 'num_bins': 72, 'bin_strategy': 'uniform', 'num_epochs': 10}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:31:32,313]\u001b[0m Trial 42 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:34:19,356]\u001b[0m Trial 43 finished with value: 0.3026280105113983 and parameters: {'embedding_dim': 9, 'step_size': 0.025610956679079467, 'batch_size': 26, 'num_bins': 81, 'bin_strategy': 'uniform', 'num_epochs': 10}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:34:52,190]\u001b[0m Trial 44 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:37:51,600]\u001b[0m Trial 45 finished with value: 0.30039864778518677 and parameters: {'embedding_dim': 9, 'step_size': 0.013926851210708779, 'batch_size': 24, 'num_bins': 69, 'bin_strategy': 'uniform', 'num_epochs': 10}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 0 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 2 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 3 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 4 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 5 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 04:38:09,223]\u001b[0m Trial 46 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:38:25,507]\u001b[0m Trial 47 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:38:39,778]\u001b[0m Trial 48 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:38:56,972]\u001b[0m Trial 49 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 0 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 2 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 3 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 4 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 5 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 04:39:17,432]\u001b[0m Trial 50 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:39:55,277]\u001b[0m Trial 51 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:40:41,806]\u001b[0m Trial 52 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:41:04,478]\u001b[0m Trial 53 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:41:24,107]\u001b[0m Trial 54 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:41:41,254]\u001b[0m Trial 55 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:41:58,017]\u001b[0m Trial 56 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:42:10,630]\u001b[0m Trial 57 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:42:26,297]\u001b[0m Trial 58 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:42:46,767]\u001b[0m Trial 59 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:46:59,616]\u001b[0m Trial 60 finished with value: 0.30149054527282715 and parameters: {'embedding_dim': 9, 'step_size': 0.016542823348484607, 'batch_size': 18, 'num_bins': 100, 'bin_strategy': 'uniform', 'num_epochs': 11}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:51:16,197]\u001b[0m Trial 61 finished with value: 0.298368364572525 and parameters: {'embedding_dim': 9, 'step_size': 0.016604846417793938, 'batch_size': 18, 'num_bins': 92, 'bin_strategy': 'uniform', 'num_epochs': 11}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:51:34,113]\u001b[0m Trial 62 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:51:53,620]\u001b[0m Trial 63 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:37:06,903]\u001b[0m Trial 64 finished with value: 0.30193400382995605 and parameters: {'embedding_dim': 9, 'step_size': 0.016685807356309056, 'batch_size': 2, 'num_bins': 65, 'bin_strategy': 'uniform', 'num_epochs': 13}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:37:25,002]\u001b[0m Trial 65 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 0 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 2 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 3 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 4 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 5 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 05:37:53,258]\u001b[0m Trial 66 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:38:13,413]\u001b[0m Trial 67 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:38:35,738]\u001b[0m Trial 68 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:39:00,269]\u001b[0m Trial 69 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:42:54,379]\u001b[0m Trial 70 finished with value: 0.29890763759613037 and parameters: {'embedding_dim': 10, 'step_size': 0.016526218941787352, 'batch_size': 23, 'num_bins': 72, 'bin_strategy': 'uniform', 'num_epochs': 10}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:43:17,916]\u001b[0m Trial 71 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:43:42,518]\u001b[0m Trial 72 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:44:05,204]\u001b[0m Trial 73 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:48:21,735]\u001b[0m Trial 74 finished with value: 0.3004951477050781 and parameters: {'embedding_dim': 9, 'step_size': 0.014788329393552348, 'batch_size': 21, 'num_bins': 59, 'bin_strategy': 'uniform', 'num_epochs': 10}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:49:22,809]\u001b[0m Trial 75 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:53:11,710]\u001b[0m Trial 76 finished with value: 0.30084165930747986 and parameters: {'embedding_dim': 9, 'step_size': 0.015510325380640089, 'batch_size': 21, 'num_bins': 71, 'bin_strategy': 'uniform', 'num_epochs': 9}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:55:50,704]\u001b[0m Trial 77 finished with value: 0.3046785593032837 and parameters: {'embedding_dim': 10, 'step_size': 0.024487933212022406, 'batch_size': 26, 'num_bins': 67, 'bin_strategy': 'uniform', 'num_epochs': 10}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 0 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 2 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 3 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 4 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 5 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 05:56:15,031]\u001b[0m Trial 78 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:56:32,604]\u001b[0m Trial 79 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:56:52,941]\u001b[0m Trial 80 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:59:25,960]\u001b[0m Trial 81 finished with value: 0.29877328872680664 and parameters: {'embedding_dim': 9, 'step_size': 0.015665965782853387, 'batch_size': 21, 'num_bins': 70, 'bin_strategy': 'uniform', 'num_epochs': 8}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:59:59,871]\u001b[0m Trial 82 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:00:19,755]\u001b[0m Trial 83 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:00:42,196]\u001b[0m Trial 84 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:01:01,542]\u001b[0m Trial 85 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:03:51,558]\u001b[0m Trial 86 finished with value: 0.30616626143455505 and parameters: {'embedding_dim': 10, 'step_size': 0.02268770472091906, 'batch_size': 27, 'num_bins': 67, 'bin_strategy': 'uniform', 'num_epochs': 10}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:04:15,445]\u001b[0m Trial 87 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:04:34,359]\u001b[0m Trial 88 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:08:41,776]\u001b[0m Trial 89 finished with value: 0.30008405447006226 and parameters: {'embedding_dim': 9, 'step_size': 0.017192731186566303, 'batch_size': 20, 'num_bins': 83, 'bin_strategy': 'uniform', 'num_epochs': 9}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 0 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 2 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 3 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 4 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 5 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 06:09:13,735]\u001b[0m Trial 90 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:09:44,235]\u001b[0m Trial 91 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:10:05,698]\u001b[0m Trial 92 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:11:53,132]\u001b[0m Trial 93 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:12:18,070]\u001b[0m Trial 94 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:16:39,839]\u001b[0m Trial 95 finished with value: 0.2962709665298462 and parameters: {'embedding_dim': 10, 'step_size': 0.01851901416896681, 'batch_size': 21, 'num_bins': 97, 'bin_strategy': 'uniform', 'num_epochs': 10}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:19:44,587]\u001b[0m Trial 96 finished with value: 0.29874420166015625 and parameters: {'embedding_dim': 10, 'step_size': 0.016062342363613872, 'batch_size': 21, 'num_bins': 100, 'bin_strategy': 'uniform', 'num_epochs': 7}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:22:57,478]\u001b[0m Trial 97 finished with value: 0.303432822227478 and parameters: {'embedding_dim': 10, 'step_size': 0.02685173836114753, 'batch_size': 20, 'num_bins': 100, 'bin_strategy': 'uniform', 'num_epochs': 7}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:25:31,353]\u001b[0m Trial 98 finished with value: 0.29946252703666687 and parameters: {'embedding_dim': 10, 'step_size': 0.018564930846237025, 'batch_size': 24, 'num_bins': 93, 'bin_strategy': 'uniform', 'num_epochs': 7}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:28:52,888]\u001b[0m Trial 99 finished with value: 0.29889658093452454 and parameters: {'embedding_dim': 10, 'step_size': 0.019834766929900396, 'batch_size': 19, 'num_bins': 97, 'bin_strategy': 'uniform', 'num_epochs': 7}. Best is trial 29 with value: 0.29582858085632324.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# Hyperparameter search for the 'bins' study: minimise the value returned by\n",
    "# test_bins_objective over 100 trials, using a TPE sampler seeded with 42.\n",
    "# NOTE(review): the seed only fixes the sampler; the training run inside the objective\n",
    "# may still be non-deterministic, so trial values can differ between runs - confirm.\n",
    "study_bins = optuna.create_study(\n",
    "    study_name='bins',\n",
    "    direction='minimize',\n",
    "    sampler=optuna.samplers.TPESampler(seed=42),\n",
    ")\n",
    "study_bins.optimize(test_bins_objective, n_trials=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": false,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'embedding_dim': 8,\n",
       " 'step_size': 0.013772781946733741,\n",
       " 'batch_size': 25,\n",
       " 'num_bins': 58,\n",
       " 'bin_strategy': 'uniform',\n",
       " 'num_epochs': 11}"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the hyperparameters of the best (lowest-objective) trial of the 'bins' study.\n",
    "study_bins.best_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": false,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test loss: 0.29582858085632324\n",
      "Best hyperparameters: {'embedding_dim': 8, 'step_size': 0.013772781946733741, 'batch_size': 25, 'num_bins': 58, 'bin_strategy': 'uniform', 'num_epochs': 11}\n"
     ]
    }
   ],
   "source": [
    "# Report the objective value and the hyperparameters of the study's best trial.\n",
    "trial = study_bins.best_trial\n",
    "\n",
    "print(f'Test loss: {trial.value}')\n",
    "print(f'Best hyperparameters: {trial.params}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": false,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.2982105612754822"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Re-run training once with the best hyperparameters found by the study.\n",
    "# The value shown below differs slightly from the study's best value, i.e. a single\n",
    "# run is not exactly repeatable.\n",
    "train_bin_ffm(**study_bins.best_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [1:02:58<00:00, 188.92s/it]\n"
     ]
    }
   ],
   "source": [
    "# Re-train with the best hyperparameters several times to measure the run-to-run\n",
    "# spread of the resulting loss.\n",
    "N_REPEATS = 20  # number of independent re-trainings\n",
    "\n",
    "# Look the best parameters up once: they do not change inside the loop.\n",
    "best_bin_params = study_bins.best_params\n",
    "\n",
    "# Append one result at a time so that partial results survive an interrupted run.\n",
    "bin_losses = []\n",
    "for _run in trange(N_REPEATS):\n",
    "    bin_losses.append(train_bin_ffm(**best_bin_params))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.2994536757469177,\n",
       " 0.3009205758571625,\n",
       " 0.296610563993454,\n",
       " 0.2973214089870453,\n",
       " 0.2992994785308838,\n",
       " 0.2983883023262024,\n",
       " 0.29931512475013733,\n",
       " 0.2996693253517151,\n",
       " 0.2996058762073517,\n",
       " 0.2982383966445923,\n",
       " 0.30039355158805847,\n",
       " 0.3006632924079895,\n",
       " 0.29591503739356995,\n",
       " 0.2998608648777008,\n",
       " 0.2998100817203522,\n",
       " 0.2978292405605316,\n",
       " 0.29889991879463196,\n",
       " 0.2984970808029175,\n",
       " 0.29809093475341797,\n",
       " 0.3017430901527405]"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the loss returned by each of the repeated training runs.\n",
    "bin_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.2990262910723686, 0.001413742829358982)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean and standard deviation of the repeated-run losses.\n",
    "# np.std is called with its default ddof (0 per the NumPy docs), i.e. this is the\n",
    "# population standard deviation; pass ddof=1 for the sample estimate.\n",
    "np.mean(bin_losses), np.std(bin_losses)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
