{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:50:53.880580Z",
     "start_time": "2023-05-14T19:50:53.821459Z"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Put ../utils first on the import path so the project-local modules imported in the\n",
    "# next cell (weighted_fm, trainers, transformation) can be found -- presumably they\n",
    "# live in that directory; confirm if the repository layout changes.\n",
    "sys.path.insert(0, \"../utils\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:50:53.880825Z",
     "start_time": "2023-05-14T19:50:53.831491Z"
    }
   },
   "outputs": [],
   "source": [
    "import sklearn.datasets as skds\n",
    "from sklearn.preprocessing import QuantileTransformer, KBinsDiscretizer, StandardScaler\n",
    "from sklearn.model_selection import train_test_split\n",
    "from weighted_fm import WeightedFFM, WeightedFM\n",
    "from trainers import FFMTrainer\n",
    "from transformation import BSplineTransformer, spline_transform_dataset\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "import math\n",
    "import optuna\n",
    "import optuna.samplers\n",
    "from typing import Callable\n",
    "from tqdm import tqdm, trange"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:50:53.880991Z",
     "start_time": "2023-05-14T19:50:53.835514Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "# Train on the first GPU when one is visible, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:50:53.881070Z",
     "start_time": "2023-05-14T19:50:53.842689Z"
    }
   },
   "outputs": [],
   "source": [
    "# Seed the global torch and numpy RNGs once, up front. Nothing re-seeds between the\n",
    "# repeated trainings further down, so those runs still differ from one another.\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tarfile\n",
    "import joblib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write ../data/cal_housing_py3.pkz from the local archive; the\n",
    "# fetch_california_housing(data_home=\"../data\") cell below reads its data from\n",
    "# that directory (presumably this avoids a network download).\n",
    "with tarfile.open(name=\"../data/cal_housing.tgz\", mode=\"r:gz\") as archive:\n",
    "    member = archive.extractfile(\"CaliforniaHousing/cal_housing.data\")\n",
    "    raw_rows = np.loadtxt(member, delimiter=\",\")\n",
    "\n",
    "# Columns are not in the same order compared to the previous\n",
    "# URL resource on lib.stat.cmu.edu, so permute them.\n",
    "columns_index = [8, 7, 2, 3, 4, 5, 6, 1, 0]\n",
    "cal_housing = raw_rows[:, columns_index]\n",
    "# Trailing semicolon keeps the returned file list from being displayed.\n",
    "joblib.dump(cal_housing, \"../data/cal_housing_py3.pkz\", compress=6);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:50:53.881631Z",
     "start_time": "2023-05-14T19:50:53.858870Z"
    }
   },
   "outputs": [],
   "source": [
    "# Load the dataset with ../data as scikit-learn's data home; the previous cell wrote\n",
    "# cal_housing_py3.pkz into that directory.\n",
    "ds = skds.fetch_california_housing(data_home=\"../data\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:50:53.896771Z",
     "start_time": "2023-05-14T19:50:53.885514Z"
    }
   },
   "outputs": [],
   "source": [
    "# Hold out 20% of the rows as the test set; the fixed random_state makes the split reproducible.\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    ds.data, ds.target, test_size=0.2, random_state=42\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standardize the target using statistics from the training split only.\n",
    "# NOTE: this rebinds y_train / y_test, so the cell is not idempotent -- re-run the\n",
    "# train/test split cell before re-running this one, otherwise target_scaler gets\n",
    "# fitted on already-standardized values.\n",
    "target_scaler = StandardScaler()\n",
    "y_train = target_scaler.fit_transform(y_train[:, np.newaxis]).ravel()\n",
    "y_test = target_scaler.transform(y_test[:, np.newaxis]).ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:50:55.384716Z",
     "start_time": "2023-05-14T19:50:53.898710Z"
    }
   },
   "outputs": [],
   "source": [
    "# Map each feature to a uniform [0, 1] scale via its empirical quantiles, fitted on\n",
    "# X_train only; the spline encoding in train_spline_ffm consumes these values.\n",
    "quant_transform = QuantileTransformer(\n",
    "    n_quantiles=10000,\n",
    "    output_distribution='uniform',\n",
    "    subsample=len(X_train),  # every training row is used to estimate the quantiles\n",
    "    random_state=42,\n",
    ")\n",
    "X_train_qs = quant_transform.fit_transform(X_train)\n",
    "X_test_qs = quant_transform.transform(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:50:55.393511Z",
     "start_time": "2023-05-14T19:50:55.391799Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "def train_spline_ffm(embedding_dim: int, step_size: float, batch_size: int, num_knots: int, num_epochs: int,\n",
    "                     callback: Callable[[int, float], None]=None):\n",
    "    \"\"\"Train an FFM on B-spline encodings of the quantile-transformed features.\n",
    "\n",
    "    X_train_qs / X_test_qs are expanded with BSplineTransformer(num_knots, 3) (the 3 is\n",
    "    presumably the spline degree), using one field per original feature column.\n",
    "\n",
    "    Args:\n",
    "        embedding_dim, step_size, batch_size, num_epochs: forwarded to FFMTrainer.\n",
    "        num_knots: number of knots of the spline basis.\n",
    "        callback: optional hook forwarded to FFMTrainer; the Optuna objective passes a\n",
    "            function of (epoch, loss) that reports to the trial and may prune it.\n",
    "\n",
    "    Returns:\n",
    "        Whatever FFMTrainer.train returns; callers treat it as the test-set MSE on the\n",
    "        standardized target.\n",
    "    \"\"\"\n",
    "    spline = BSplineTransformer(num_knots, 3)\n",
    "\n",
    "    def to_dataset(features, target):\n",
    "        # TensorDataset columns: (indices, weights, offsets, fields, target)\n",
    "        indices, weights, offsets, fields = spline_transform_dataset(features, spline)\n",
    "        return TensorDataset(\n",
    "            torch.tensor(indices, dtype=torch.int64),\n",
    "            torch.tensor(weights, dtype=torch.float32),\n",
    "            torch.tensor(offsets, dtype=torch.int64),\n",
    "            torch.tensor(fields, dtype=torch.int64),\n",
    "            torch.tensor(target, dtype=torch.float32))\n",
    "\n",
    "    train_ds = to_dataset(X_train_qs, y_train)\n",
    "    test_ds = to_dataset(X_test_qs, y_test)\n",
    "\n",
    "    num_fields = ds['data'].shape[1]\n",
    "    num_embeddings = num_fields * spline.basis_size()\n",
    "    trainer = FFMTrainer(embedding_dim, step_size, batch_size, num_epochs, callback)\n",
    "    return trainer.train(num_fields, num_embeddings, train_ds, test_ds, torch.nn.MSELoss(), device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:50:55.402879Z",
     "start_time": "2023-05-14T19:50:55.395568Z"
    }
   },
   "outputs": [],
   "source": [
    "def train_spline_objective(trial: optuna.Trial):\n",
    "    \"\"\"Optuna objective for the spline FFM; returns the test RMSE of one training run.\n",
    "\n",
    "    NOTE(review): hyperparameters are selected on the test split (the same split whose\n",
    "    score is later printed as 'Test loss'), so the best value is optimistically biased;\n",
    "    a separate validation split would give an unbiased final estimate.\n",
    "    \"\"\"\n",
    "    # Parameters are suggested in a fixed order (dict arguments evaluate left to right).\n",
    "    params = dict(\n",
    "        embedding_dim=trial.suggest_int('embedding_dim', 1, 10),\n",
    "        step_size=trial.suggest_float('step_size', 1e-2, 0.5, log=True),\n",
    "        batch_size=trial.suggest_int('batch_size', 2, 32),\n",
    "        num_knots=trial.suggest_int('num_knots', 3, 48),\n",
    "        num_epochs=trial.suggest_int('num_epochs', 5, 15),\n",
    "    )\n",
    "\n",
    "    def report_epoch(epoch: int, loss: float):\n",
    "        # Report per-epoch RMSE so the study's pruner can stop unpromising trials early.\n",
    "        trial.report(math.sqrt(loss), epoch)\n",
    "        if trial.should_prune():\n",
    "            raise optuna.TrialPruned()\n",
    "\n",
    "    return math.sqrt(train_spline_ffm(**params, callback=report_epoch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:52:10.650270Z",
     "start_time": "2023-05-14T19:50:55.403316Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[I 2023-05-16 19:06:04,290]\u001b[0m A new study created in memory with name: splines\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:07:06,616]\u001b[0m Trial 0 finished with value: 0.4704895326349986 and parameters: {'embedding_dim': 4, 'step_size': 0.4123206532618726, 'batch_size': 24, 'num_knots': 30, 'num_epochs': 6}. Best is trial 0 with value: 0.4704895326349986.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:08:50,430]\u001b[0m Trial 1 finished with value: 0.5324508054601986 and parameters: {'embedding_dim': 2, 'step_size': 0.012551115172973842, 'batch_size': 28, 'num_knots': 30, 'num_epochs': 12}. Best is trial 0 with value: 0.4704895326349986.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:09:57,194]\u001b[0m Trial 2 finished with value: 0.45800390831558263 and parameters: {'embedding_dim': 1, 'step_size': 0.44447541666908114, 'batch_size': 27, 'num_knots': 12, 'num_epochs': 7}. Best is trial 2 with value: 0.45800390831558263.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:11:51,841]\u001b[0m Trial 3 finished with value: 0.46901124985568104 and parameters: {'embedding_dim': 2, 'step_size': 0.0328774741399112, 'batch_size': 18, 'num_knots': 22, 'num_epochs': 8}. Best is trial 2 with value: 0.45800390831558263.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:15:35,305]\u001b[0m Trial 4 finished with value: 0.4475541496014591 and parameters: {'embedding_dim': 7, 'step_size': 0.017258215396625, 'batch_size': 11, 'num_knots': 19, 'num_epochs': 10}. Best is trial 4 with value: 0.4475541496014591.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:16:53,720]\u001b[0m Trial 5 finished with value: 0.45202152127010237 and parameters: {'embedding_dim': 8, 'step_size': 0.021839352923182977, 'batch_size': 17, 'num_knots': 30, 'num_epochs': 5}. Best is trial 4 with value: 0.4475541496014591.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:20:13,198]\u001b[0m Trial 6 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:26:56,554]\u001b[0m Trial 7 finished with value: 0.4374097662924263 and parameters: {'embedding_dim': 9, 'step_size': 0.032925293631105246, 'batch_size': 5, 'num_knots': 34, 'num_epochs': 9}. Best is trial 7 with value: 0.4374097662924263.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:27:54,756]\u001b[0m Trial 8 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:29:23,192]\u001b[0m Trial 9 finished with value: 0.4333838736728957 and parameters: {'embedding_dim': 7, 'step_size': 0.033852267834519785, 'batch_size': 18, 'num_knots': 28, 'num_epochs': 7}. Best is trial 9 with value: 0.4333838736728957.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:29:37,723]\u001b[0m Trial 10 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:32:24,187]\u001b[0m Trial 11 finished with value: 0.4360523626300907 and parameters: {'embedding_dim': 10, 'step_size': 0.04337690983089577, 'batch_size': 12, 'num_knots': 37, 'num_epochs': 9}. Best is trial 9 with value: 0.4333838736728957.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:34:57,499]\u001b[0m Trial 12 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:35:09,175]\u001b[0m Trial 13 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:37:00,063]\u001b[0m Trial 14 finished with value: 0.44198953018367365 and parameters: {'embedding_dim': 5, 'step_size': 0.04195290839392635, 'batch_size': 10, 'num_knots': 15, 'num_epochs': 5}. Best is trial 9 with value: 0.4333838736728957.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:40:59,893]\u001b[0m Trial 15 finished with value: 0.44867919363263115 and parameters: {'embedding_dim': 8, 'step_size': 0.07625186899625158, 'batch_size': 13, 'num_knots': 25, 'num_epochs': 15}. Best is trial 9 with value: 0.4333838736728957.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:44:05,914]\u001b[0m Trial 16 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:44:19,552]\u001b[0m Trial 17 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:44:29,726]\u001b[0m Trial 18 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:46:10,377]\u001b[0m Trial 19 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:50:45,295]\u001b[0m Trial 20 finished with value: 0.4304292518785922 and parameters: {'embedding_dim': 7, 'step_size': 0.04499014800396729, 'batch_size': 7, 'num_knots': 25, 'num_epochs': 9}. Best is trial 20 with value: 0.4304292518785922.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:55:19,993]\u001b[0m Trial 21 finished with value: 0.4317104576548742 and parameters: {'embedding_dim': 7, 'step_size': 0.04748562470151251, 'batch_size': 7, 'num_knots': 25, 'num_epochs': 9}. Best is trial 20 with value: 0.4304292518785922.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:59:56,006]\u001b[0m Trial 22 finished with value: 0.43209864623593314 and parameters: {'embedding_dim': 7, 'step_size': 0.05467652110187646, 'batch_size': 7, 'num_knots': 25, 'num_epochs': 8}. Best is trial 20 with value: 0.4304292518785922.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:05:31,702]\u001b[0m Trial 23 finished with value: 0.43518819708811474 and parameters: {'embedding_dim': 5, 'step_size': 0.061380216632269086, 'batch_size': 7, 'num_knots': 23, 'num_epochs': 10}. Best is trial 20 with value: 0.4304292518785922.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:21:14,196]\u001b[0m Trial 24 finished with value: 0.4282563738885848 and parameters: {'embedding_dim': 8, 'step_size': 0.050729607059104745, 'batch_size': 2, 'num_knots': 18, 'num_epochs': 8}. Best is trial 24 with value: 0.4282563738885848.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:22:23,305]\u001b[0m Trial 25 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:28:34,067]\u001b[0m Trial 26 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:32:23,398]\u001b[0m Trial 27 finished with value: 0.4342443722594684 and parameters: {'embedding_dim': 8, 'step_size': 0.04633322962813203, 'batch_size': 6, 'num_knots': 20, 'num_epochs': 6}. Best is trial 24 with value: 0.4282563738885848.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:36:20,024]\u001b[0m Trial 28 finished with value: 0.4303967258133841 and parameters: {'embedding_dim': 4, 'step_size': 0.10498851807328576, 'batch_size': 9, 'num_knots': 17, 'num_epochs': 8}. Best is trial 24 with value: 0.4282563738885848.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:36:51,318]\u001b[0m Trial 29 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:54:01,502]\u001b[0m Trial 30 finished with value: 0.432384126802336 and parameters: {'embedding_dim': 3, 'step_size': 0.2224714974302765, 'batch_size': 2, 'num_knots': 17, 'num_epochs': 8}. Best is trial 24 with value: 0.4282563738885848.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:57:04,089]\u001b[0m Trial 31 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:57:58,437]\u001b[0m Trial 32 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:03:33,448]\u001b[0m Trial 33 finished with value: 0.4318301617578173 and parameters: {'embedding_dim': 6, 'step_size': 0.05610785831672759, 'batch_size': 6, 'num_knots': 28, 'num_epochs': 8}. Best is trial 24 with value: 0.4282563738885848.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:08:11,614]\u001b[0m Trial 34 finished with value: 0.4385583881926396 and parameters: {'embedding_dim': 7, 'step_size': 0.10397008430678391, 'batch_size': 8, 'num_knots': 18, 'num_epochs': 10}. Best is trial 24 with value: 0.4282563738885848.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:08:32,897]\u001b[0m Trial 35 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:09:17,581]\u001b[0m Trial 36 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:10:13,339]\u001b[0m Trial 37 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:10:28,002]\u001b[0m Trial 38 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:13:17,265]\u001b[0m Trial 39 finished with value: 0.4332768080903596 and parameters: {'embedding_dim': 7, 'step_size': 0.049927006498105705, 'batch_size': 10, 'num_knots': 31, 'num_epochs': 8}. Best is trial 24 with value: 0.4282563738885848.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:13:27,097]\u001b[0m Trial 40 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:14:04,662]\u001b[0m Trial 41 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:14:37,254]\u001b[0m Trial 42 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:22:56,495]\u001b[0m Trial 43 finished with value: 0.4355067418514952 and parameters: {'embedding_dim': 7, 'step_size': 0.06736094208217702, 'batch_size': 4, 'num_knots': 31, 'num_epochs': 9}. Best is trial 24 with value: 0.4282563738885848.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:25:10,548]\u001b[0m Trial 44 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:34:48,334]\u001b[0m Trial 45 finished with value: 0.428379686852566 and parameters: {'embedding_dim': 6, 'step_size': 0.05663408617479354, 'batch_size': 4, 'num_knots': 26, 'num_epochs': 9}. Best is trial 24 with value: 0.4282563738885848.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:35:53,613]\u001b[0m Trial 46 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:37:25,006]\u001b[0m Trial 47 finished with value: 0.4282606362493427 and parameters: {'embedding_dim': 7, 'step_size': 0.07608335371533785, 'batch_size': 28, 'num_knots': 26, 'num_epochs': 9}. Best is trial 24 with value: 0.4282563738885848.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:37:35,335]\u001b[0m Trial 48 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:38:55,705]\u001b[0m Trial 49 finished with value: 0.4345063751848249 and parameters: {'embedding_dim': 8, 'step_size': 0.0681880972757457, 'batch_size': 31, 'num_knots': 23, 'num_epochs': 9}. Best is trial 24 with value: 0.4282563738885848.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:41:42,130]\u001b[0m Trial 50 finished with value: 0.44851676160691284 and parameters: {'embedding_dim': 7, 'step_size': 0.09456589098657536, 'batch_size': 20, 'num_knots': 26, 'num_epochs': 12}. Best is trial 24 with value: 0.4282563738885848.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:41:53,916]\u001b[0m Trial 51 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:45:02,791]\u001b[0m Trial 52 finished with value: 0.42911772664965314 and parameters: {'embedding_dim': 6, 'step_size': 0.0598936840450053, 'batch_size': 13, 'num_knots': 21, 'num_epochs': 10}. Best is trial 24 with value: 0.4282563738885848.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:45:19,028]\u001b[0m Trial 53 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:45:39,786]\u001b[0m Trial 54 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:45:56,382]\u001b[0m Trial 55 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:46:08,651]\u001b[0m Trial 56 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:46:34,431]\u001b[0m Trial 57 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:46:42,515]\u001b[0m Trial 58 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:50:12,016]\u001b[0m Trial 59 finished with value: 0.4367239404141295 and parameters: {'embedding_dim': 8, 'step_size': 0.10688913438019201, 'batch_size': 12, 'num_knots': 15, 'num_epochs': 11}. Best is trial 24 with value: 0.4282563738885848.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:51:25,554]\u001b[0m Trial 60 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:56:17,668]\u001b[0m Trial 61 finished with value: 0.4293793353510158 and parameters: {'embedding_dim': 7, 'step_size': 0.052279480161766326, 'batch_size': 7, 'num_knots': 25, 'num_epochs': 9}. Best is trial 24 with value: 0.4282563738885848.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:00:51,493]\u001b[0m Trial 62 finished with value: 0.42994092923668487 and parameters: {'embedding_dim': 7, 'step_size': 0.05748379333415699, 'batch_size': 8, 'num_knots': 21, 'num_epochs': 10}. Best is trial 24 with value: 0.4282563738885848.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:05:34,761]\u001b[0m Trial 63 finished with value: 0.42706043678253136 and parameters: {'embedding_dim': 7, 'step_size': 0.05868656247354761, 'batch_size': 8, 'num_knots': 22, 'num_epochs': 10}. Best is trial 63 with value: 0.42706043678253136.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:10:31,959]\u001b[0m Trial 64 finished with value: 0.43217209422359965 and parameters: {'embedding_dim': 7, 'step_size': 0.05987477205374452, 'batch_size': 8, 'num_knots': 21, 'num_epochs': 11}. Best is trial 63 with value: 0.42706043678253136.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:17:56,158]\u001b[0m Trial 65 finished with value: 0.4350829454793019 and parameters: {'embedding_dim': 8, 'step_size': 0.05256357897612862, 'batch_size': 5, 'num_knots': 27, 'num_epochs': 10}. Best is trial 63 with value: 0.42706043678253136.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:30:35,599]\u001b[0m Trial 66 finished with value: 0.434589308419973 and parameters: {'embedding_dim': 7, 'step_size': 0.07823068649760494, 'batch_size': 3, 'num_knots': 22, 'num_epochs': 10}. Best is trial 63 with value: 0.42706043678253136.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:31:27,696]\u001b[0m Trial 67 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:31:42,320]\u001b[0m Trial 68 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:36:04,117]\u001b[0m Trial 69 finished with value: 0.42859933221611907 and parameters: {'embedding_dim': 6, 'step_size': 0.055681143765690076, 'batch_size': 10, 'num_knots': 24, 'num_epochs': 11}. Best is trial 63 with value: 0.42706043678253136.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:40:13,326]\u001b[0m Trial 70 finished with value: 0.4299545152028061 and parameters: {'embedding_dim': 6, 'step_size': 0.05054589162131087, 'batch_size': 10, 'num_knots': 24, 'num_epochs': 11}. Best is trial 63 with value: 0.42706043678253136.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:40:31,967]\u001b[0m Trial 71 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:45:04,936]\u001b[0m Trial 72 finished with value: 0.4300412888997274 and parameters: {'embedding_dim': 6, 'step_size': 0.06533761080426782, 'batch_size': 8, 'num_knots': 22, 'num_epochs': 10}. Best is trial 63 with value: 0.42706043678253136.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:45:38,758]\u001b[0m Trial 73 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:48:56,131]\u001b[0m Trial 74 finished with value: 0.44036213426531196 and parameters: {'embedding_dim': 7, 'step_size': 0.08362911227618235, 'batch_size': 10, 'num_knots': 25, 'num_epochs': 9}. Best is trial 63 with value: 0.42706043678253136.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:49:33,812]\u001b[0m Trial 75 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:53:03,111]\u001b[0m Trial 76 finished with value: 0.43593629611269524 and parameters: {'embedding_dim': 9, 'step_size': 0.07164141513191531, 'batch_size': 11, 'num_knots': 18, 'num_epochs': 10}. Best is trial 63 with value: 0.42706043678253136.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:53:31,764]\u001b[0m Trial 77 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:56:48,594]\u001b[0m Trial 78 finished with value: 0.42568818221361704 and parameters: {'embedding_dim': 7, 'step_size': 0.052749809167358816, 'batch_size': 14, 'num_knots': 21, 'num_epochs': 12}. Best is trial 78 with value: 0.42568818221361704.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:57:05,323]\u001b[0m Trial 79 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:57:20,072]\u001b[0m Trial 80 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:57:38,013]\u001b[0m Trial 81 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:58:10,526]\u001b[0m Trial 82 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:59:42,180]\u001b[0m Trial 83 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:01:19,935]\u001b[0m Trial 84 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:01:41,489]\u001b[0m Trial 85 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:01:50,579]\u001b[0m Trial 86 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:05:15,605]\u001b[0m Trial 87 finished with value: 0.4339707428227233 and parameters: {'embedding_dim': 8, 'step_size': 0.06642377759575374, 'batch_size': 12, 'num_knots': 24, 'num_epochs': 11}. Best is trial 78 with value: 0.42568818221361704.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:08:33,489]\u001b[0m Trial 88 finished with value: 0.42828476563061135 and parameters: {'embedding_dim': 6, 'step_size': 0.0782610810110209, 'batch_size': 9, 'num_knots': 26, 'num_epochs': 8}. Best is trial 78 with value: 0.42568818221361704.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:09:04,670]\u001b[0m Trial 89 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:16:16,226]\u001b[0m Trial 90 finished with value: 0.4314792037783839 and parameters: {'embedding_dim': 6, 'step_size': 0.07535244582825283, 'batch_size': 4, 'num_knots': 26, 'num_epochs': 8}. Best is trial 78 with value: 0.42568818221361704.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:16:40,833]\u001b[0m Trial 91 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:17:12,669]\u001b[0m Trial 92 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:20:33,451]\u001b[0m Trial 93 finished with value: 0.4322334637200849 and parameters: {'embedding_dim': 7, 'step_size': 0.06619156418718045, 'batch_size': 11, 'num_knots': 23, 'num_epochs': 10}. Best is trial 78 with value: 0.42568818221361704.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:20:58,149]\u001b[0m Trial 94 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:22:56,512]\u001b[0m Trial 95 finished with value: 0.4317012761514909 and parameters: {'embedding_dim': 6, 'step_size': 0.07919381789878355, 'batch_size': 15, 'num_knots': 21, 'num_epochs': 8}. Best is trial 78 with value: 0.42568818221361704.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:24:09,483]\u001b[0m Trial 96 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:25:31,914]\u001b[0m Trial 97 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:26:07,732]\u001b[0m Trial 98 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:33:25,357]\u001b[0m Trial 99 finished with value: 0.43331236773293147 and parameters: {'embedding_dim': 8, 'step_size': 0.06415431672405765, 'batch_size': 5, 'num_knots': 25, 'num_epochs': 10}. Best is trial 78 with value: 0.42568818221361704.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# 100 trials with a seeded TPE sampler; trials can be pruned early from the per-epoch\n",
    "# reports made by train_spline_objective. The recorded run took roughly 4.5 hours.\n",
    "spline_sampler = optuna.samplers.TPESampler(seed=42)\n",
    "study = optuna.create_study(\n",
    "    study_name='splines',\n",
    "    direction='minimize',\n",
    "    sampler=spline_sampler,\n",
    ")\n",
    "study.optimize(train_spline_objective, n_trials=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test loss: 0.42568818221361704\n",
      "Best hyperparameters: {'embedding_dim': 7, 'step_size': 0.052749809167358816, 'batch_size': 14, 'num_knots': 21, 'num_epochs': 12}\n"
     ]
    }
   ],
   "source": [
    "# trial.value is the objective's return value: the test RMSE on the standardized\n",
    "# target (not in the original units of ds['target']).\n",
    "trial = study.best_trial\n",
    "print(f'Test loss: {trial.value}')\n",
    "print(f\"Best hyperparameters: {trial.params}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'embedding_dim': 7,\n",
       " 'step_size': 0.052749809167358816,\n",
       " 'batch_size': 14,\n",
       " 'num_knots': 21,\n",
       " 'num_epochs': 12}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Best hyperparameters as a dict (equivalent to study.best_params), via rich display.\n",
    "study.best_trial.params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [1:01:29<00:00, 184.47s/it]\n"
     ]
    }
   ],
   "source": [
    "# Retrain 20 times with the best hyperparameters to measure the run-to-run spread of\n",
    "# the test RMSE (about an hour in the recorded run).\n",
    "spline_losses = [math.sqrt(train_spline_ffm(**study.best_params)) for _ in trange(20)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.42679040640311267,\n",
       " 0.42634108249853464,\n",
       " 0.4261458174783907,\n",
       " 0.43225683700092105,\n",
       " 0.431104387357407,\n",
       " 0.4288870899249733,\n",
       " 0.4277180881970324,\n",
       " 0.4311970464012093,\n",
       " 0.4338799640302308,\n",
       " 0.4241321194256667,\n",
       " 0.42784070300256505,\n",
       " 0.43046144662743624,\n",
       " 0.4316383465322712,\n",
       " 0.42870344707227903,\n",
       " 0.4286306736221006,\n",
       " 0.4301517752661648,\n",
       " 0.4326343059382102,\n",
       " 0.428075541124546,\n",
       " 0.42944209286761614,\n",
       " 0.4320622969331268]"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test RMSE of each repeated run: the sqrt of what train_spline_ffm returns (presumably\n",
    "# test MSE). The target was standardized, so values are in units of its training-set\n",
    "# standard deviation.\n",
    "spline_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.42940467338518973,\n",
       " 0.007393634011334159,\n",
       " 0.43679830739652387,\n",
       " 0.4220110393738556)"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean test RMSE over the repeated runs, and a +/- 3 standard-deviation band.\n",
    "# ddof=1 gives the sample standard deviation, the usual estimate of run-to-run spread\n",
    "# from a finite number of runs (np.std defaults to the population formula, ddof=0).\n",
    "rmse_mean = np.mean(spline_losses)\n",
    "rmse_3sd = 3 * np.std(spline_losses, ddof=1)\n",
    "rmse_mean, rmse_3sd, rmse_mean + rmse_3sd, rmse_mean - rmse_3sd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_bin_ffm(embedding_dim: int, step_size: float, batch_size: int,\n",
    "                  num_bins: int, bin_strategy: str, num_epochs: int,\n",
    "                  callback: Callable[[int, float], None]=None):\n",
    "    \"\"\"Train an FFM on discretized raw features (binning baseline for the spline model).\n",
    "\n",
    "    A KBinsDiscretizer with `num_bins` bins and `bin_strategy` is fitted on X_train.\n",
    "    Each row then has one active index per field: the field's bin id shifted by\n",
    "    field * num_bins, so every field owns a disjoint range of embedding rows, and the\n",
    "    weight is 1.\n",
    "\n",
    "    Returns:\n",
    "        Whatever FFMTrainer.train returns; callers treat it as the test-set MSE.\n",
    "    \"\"\"\n",
    "    num_fields = X_train.shape[1]\n",
    "    field_ids = np.arange(num_fields)\n",
    "    discretizer = KBinsDiscretizer(num_bins, encode='ordinal', strategy=bin_strategy, random_state=42)\n",
    "    discretizer.fit(X_train)\n",
    "\n",
    "    def to_dataset(features, target):\n",
    "        num_rows = features.shape[0]\n",
    "        indices = discretizer.transform(features) + np.tile(field_ids * num_bins, (num_rows, 1))\n",
    "        weights = np.ones_like(indices)\n",
    "        fields = np.tile(field_ids, (num_rows, 1))\n",
    "        # NOTE(review): `fields` also fills the offsets slot (train_spline_ffm passes\n",
    "        # (indices, weights, offsets, fields, target)). With exactly one index per field,\n",
    "        # per-row start positions would be 0..num_fields-1, i.e. equal to the field ids --\n",
    "        # presumably why this works; confirm against FFMTrainer.\n",
    "        return TensorDataset(\n",
    "            torch.tensor(indices, dtype=torch.int64),\n",
    "            torch.tensor(weights, dtype=torch.float32),\n",
    "            torch.tensor(fields, dtype=torch.int64),\n",
    "            torch.tensor(fields, dtype=torch.int64),\n",
    "            torch.tensor(target, dtype=torch.float32))\n",
    "\n",
    "    train_ds = to_dataset(X_train, y_train)\n",
    "    test_ds = to_dataset(X_test, y_test)\n",
    "    trainer = FFMTrainer(embedding_dim, step_size, batch_size, num_epochs, callback)\n",
    "    return trainer.train(num_fields, num_fields * num_bins, train_ds, test_ds, torch.nn.MSELoss(), device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_bins_objective(trial: optuna.Trial):\n",
    "    \"\"\"Optuna objective for the binned-feature FFM; returns the test RMSE of one run.\n",
    "\n",
    "    Searches `num_bins` and `bin_strategy` in place of the spline model's `num_knots`.\n",
    "    NOTE(review): hyperparameters are selected on the test split, so the best value\n",
    "    is optimistically biased; a separate validation split would avoid this.\n",
    "    \"\"\"\n",
    "    # Parameters are suggested in a fixed order (dict arguments evaluate left to right).\n",
    "    params = dict(\n",
    "        embedding_dim=trial.suggest_int('embedding_dim', 1, 10),\n",
    "        step_size=trial.suggest_float('step_size', 1e-2, 0.5, log=True),\n",
    "        batch_size=trial.suggest_int('batch_size', 2, 32),\n",
    "        num_bins=trial.suggest_int('num_bins', 2, 100),\n",
    "        bin_strategy=trial.suggest_categorical('bin_strategy', ['uniform', 'quantile']),\n",
    "        num_epochs=trial.suggest_int('num_epochs', 5, 15),\n",
    "    )\n",
    "\n",
    "    def report_epoch(epoch: int, mse: float):\n",
    "        # Report per-epoch RMSE so the study's pruner can stop unpromising trials early.\n",
    "        trial.report(math.sqrt(mse), epoch)\n",
    "        if trial.should_prune():\n",
    "            raise optuna.TrialPruned()\n",
    "\n",
    "    return math.sqrt(train_bin_ffm(**params, callback=report_epoch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[I 2023-05-16 23:36:36,044]\u001b[0m A new study created in memory with name: bins\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:37:23,402]\u001b[0m Trial 0 finished with value: 0.5280771877157359 and parameters: {'embedding_dim': 4, 'step_size': 0.4123206532618726, 'batch_size': 24, 'num_bins': 61, 'bin_strategy': 'uniform', 'num_epochs': 5}. Best is trial 0 with value: 0.5280771877157359.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:38:21,498]\u001b[0m Trial 1 finished with value: 0.7002097411644254 and parameters: {'embedding_dim': 9, 'step_size': 0.10502105436744279, 'batch_size': 23, 'num_bins': 4, 'bin_strategy': 'uniform', 'num_epochs': 7}. Best is trial 0 with value: 0.5280771877157359.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:42:02,363]\u001b[0m Trial 2 finished with value: 0.5551966895575728 and parameters: {'embedding_dim': 2, 'step_size': 0.020492680115417352, 'batch_size': 11, 'num_bins': 53, 'bin_strategy': 'uniform', 'num_epochs': 11}. Best is trial 0 with value: 0.5280771877157359.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:44:53,104]\u001b[0m Trial 3 finished with value: 0.541053027020405 and parameters: {'embedding_dim': 2, 'step_size': 0.03135775732257745, 'batch_size': 13, 'num_bins': 47, 'bin_strategy': 'uniform', 'num_epochs': 10}. Best is trial 0 with value: 0.5280771877157359.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-16 23:47:41,642]\u001b[0m Trial 4 finished with value: 0.5161959059594673 and parameters: {'embedding_dim': 6, 'step_size': 0.011992724522955167, 'batch_size': 20, 'num_bins': 18, 'bin_strategy': 'quantile', 'num_epochs': 15}. Best is trial 4 with value: 0.5161959059594673.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:55:02,608]\u001b[0m Trial 5 finished with value: 0.4970072048574563 and parameters: {'embedding_dim': 9, 'step_size': 0.032925293631105246, 'batch_size': 5, 'num_bins': 69, 'bin_strategy': 'uniform', 'num_epochs': 10}. Best is trial 5 with value: 0.4970072048574563.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-16 23:59:00,824]\u001b[0m Trial 6 finished with value: 0.4909980418745024 and parameters: {'embedding_dim': 1, 'step_size': 0.35067764992972184, 'batch_size': 10, 'num_bins': 67, 'bin_strategy': 'quantile', 'num_epochs': 11}. Best is trial 6 with value: 0.4909980418745024.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:01:14,977]\u001b[0m Trial 7 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:02:29,776]\u001b[0m Trial 8 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:03:20,527]\u001b[0m Trial 9 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:03:28,015]\u001b[0m Trial 10 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:05:20,936]\u001b[0m Trial 11 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:05:44,832]\u001b[0m Trial 12 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:06:08,219]\u001b[0m Trial 13 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:09:05,756]\u001b[0m Trial 14 finished with value: 0.502395642627435 and parameters: {'embedding_dim': 4, 'step_size': 0.05011263930396099, 'batch_size': 15, 'num_bins': 43, 'bin_strategy': 'quantile', 'num_epochs': 12}. Best is trial 6 with value: 0.4909980418745024.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:09:42,896]\u001b[0m Trial 15 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:10:19,842]\u001b[0m Trial 16 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 3 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:10:39,667]\u001b[0m Trial 17 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:10:55,796]\u001b[0m Trial 18 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:13:51,299]\u001b[0m Trial 19 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:14:16,169]\u001b[0m Trial 20 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:14:30,740]\u001b[0m Trial 21 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:14:43,525]\u001b[0m Trial 22 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:15:03,802]\u001b[0m Trial 23 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:15:31,644]\u001b[0m Trial 24 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 3 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:15:44,864]\u001b[0m Trial 25 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:18:28,847]\u001b[0m Trial 26 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:19:02,639]\u001b[0m Trial 27 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:22:31,925]\u001b[0m Trial 28 finished with value: 0.5196768477788142 and parameters: {'embedding_dim': 7, 'step_size': 0.14061965231931225, 'batch_size': 11, 'num_bins': 85, 'bin_strategy': 'uniform', 'num_epochs': 11}. Best is trial 6 with value: 0.4909980418745024.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:23:55,180]\u001b[0m Trial 29 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:24:09,656]\u001b[0m Trial 30 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:24:20,135]\u001b[0m Trial 31 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:25:03,259]\u001b[0m Trial 32 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:26:44,073]\u001b[0m Trial 33 finished with value: 0.4875798948549539 and parameters: {'embedding_dim': 6, 'step_size': 0.020684208264680124, 'batch_size': 24, 'num_bins': 11, 'bin_strategy': 'quantile', 'num_epochs': 11}. Best is trial 33 with value: 0.4875798948549539.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:26:50,369]\u001b[0m Trial 34 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:27:11,599]\u001b[0m Trial 35 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:27:19,713]\u001b[0m Trial 36 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:27:38,434]\u001b[0m Trial 37 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:28:03,184]\u001b[0m Trial 38 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:28:10,949]\u001b[0m Trial 39 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:28:34,929]\u001b[0m Trial 40 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:28:43,881]\u001b[0m Trial 41 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:30:51,364]\u001b[0m Trial 42 finished with value: 0.4912414844638768 and parameters: {'embedding_dim': 6, 'step_size': 0.02366328080027664, 'batch_size': 25, 'num_bins': 12, 'bin_strategy': 'quantile', 'num_epochs': 14}. Best is trial 33 with value: 0.4875798948549539.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:32:38,553]\u001b[0m Trial 43 finished with value: 0.4814953574890661 and parameters: {'embedding_dim': 6, 'step_size': 0.02353530079691863, 'batch_size': 30, 'num_bins': 11, 'bin_strategy': 'quantile', 'num_epochs': 14}. Best is trial 43 with value: 0.4814953574890661.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:34:18,009]\u001b[0m Trial 44 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:34:24,647]\u001b[0m Trial 45 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:36:18,448]\u001b[0m Trial 46 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:36:56,906]\u001b[0m Trial 47 finished with value: 0.49383070805204027 and parameters: {'embedding_dim': 6, 'step_size': 0.025434555654167735, 'batch_size': 30, 'num_bins': 15, 'bin_strategy': 'quantile', 'num_epochs': 5}. Best is trial 43 with value: 0.4814953574890661.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:37:04,635]\u001b[0m Trial 48 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:37:53,269]\u001b[0m Trial 49 finished with value: 0.48899480942348683 and parameters: {'embedding_dim': 6, 'step_size': 0.028209809445833324, 'batch_size': 28, 'num_bins': 15, 'bin_strategy': 'quantile', 'num_epochs': 6}. Best is trial 43 with value: 0.4814953574890661.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:38:01,503]\u001b[0m Trial 50 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:38:48,458]\u001b[0m Trial 51 finished with value: 0.48648655467159996 and parameters: {'embedding_dim': 6, 'step_size': 0.027729848277686826, 'batch_size': 29, 'num_bins': 15, 'bin_strategy': 'quantile', 'num_epochs': 6}. Best is trial 43 with value: 0.4814953574890661.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:39:43,237]\u001b[0m Trial 52 finished with value: 0.48784862426097797 and parameters: {'embedding_dim': 6, 'step_size': 0.028651818424917735, 'batch_size': 25, 'num_bins': 17, 'bin_strategy': 'quantile', 'num_epochs': 6}. Best is trial 43 with value: 0.4814953574890661.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:40:31,815]\u001b[0m Trial 53 finished with value: 0.4868366998634225 and parameters: {'embedding_dim': 7, 'step_size': 0.030842284746007004, 'batch_size': 28, 'num_bins': 18, 'bin_strategy': 'quantile', 'num_epochs': 6}. Best is trial 43 with value: 0.4814953574890661.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:40:48,258]\u001b[0m Trial 54 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:41:08,129]\u001b[0m Trial 55 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:41:16,092]\u001b[0m Trial 56 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:41:51,479]\u001b[0m Trial 57 finished with value: 0.47951060024732495 and parameters: {'embedding_dim': 8, 'step_size': 0.03929383281600887, 'batch_size': 32, 'num_bins': 17, 'bin_strategy': 'quantile', 'num_epochs': 6}. Best is trial 57 with value: 0.47951060024732495.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:42:43,233]\u001b[0m Trial 58 finished with value: 0.4845081577007928 and parameters: {'embedding_dim': 8, 'step_size': 0.04246678658964534, 'batch_size': 32, 'num_bins': 25, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 57 with value: 0.47951060024732495.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:43:34,879]\u001b[0m Trial 59 finished with value: 0.49187994233475485 and parameters: {'embedding_dim': 8, 'step_size': 0.04312872060983268, 'batch_size': 32, 'num_bins': 25, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 57 with value: 0.47951060024732495.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:44:35,589]\u001b[0m Trial 60 finished with value: 0.4815114345371291 and parameters: {'embedding_dim': 8, 'step_size': 0.04837698369243848, 'batch_size': 31, 'num_bins': 11, 'bin_strategy': 'quantile', 'num_epochs': 8}. Best is trial 57 with value: 0.47951060024732495.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:45:34,500]\u001b[0m Trial 61 finished with value: 0.49070868780647986 and parameters: {'embedding_dim': 8, 'step_size': 0.049335977465450875, 'batch_size': 32, 'num_bins': 12, 'bin_strategy': 'quantile', 'num_epochs': 8}. Best is trial 57 with value: 0.47951060024732495.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:45:42,276]\u001b[0m Trial 62 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:46:20,071]\u001b[0m Trial 63 finished with value: 0.491216049016063 and parameters: {'embedding_dim': 7, 'step_size': 0.03515178497648949, 'batch_size': 31, 'num_bins': 19, 'bin_strategy': 'quantile', 'num_epochs': 5}. Best is trial 57 with value: 0.47951060024732495.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:47:08,678]\u001b[0m Trial 64 finished with value: 0.4849947072261395 and parameters: {'embedding_dim': 8, 'step_size': 0.04805048391692695, 'batch_size': 29, 'num_bins': 11, 'bin_strategy': 'quantile', 'num_epochs': 6}. Best is trial 57 with value: 0.47951060024732495.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:47:57,305]\u001b[0m Trial 65 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:48:50,313]\u001b[0m Trial 66 finished with value: 0.476914166633194 and parameters: {'embedding_dim': 8, 'step_size': 0.03971654057761182, 'batch_size': 31, 'num_bins': 23, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 66 with value: 0.476914166633194.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:49:13,452]\u001b[0m Trial 67 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:50:01,305]\u001b[0m Trial 68 finished with value: 0.482084411088513 and parameters: {'embedding_dim': 9, 'step_size': 0.06121422232448351, 'batch_size': 29, 'num_bins': 23, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 66 with value: 0.476914166633194.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:50:54,972]\u001b[0m Trial 69 finished with value: 0.4851239628268377 and parameters: {'embedding_dim': 10, 'step_size': 0.06349724505888375, 'batch_size': 31, 'num_bins': 22, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 66 with value: 0.476914166633194.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:51:46,909]\u001b[0m Trial 70 finished with value: 0.48660264444334267 and parameters: {'embedding_dim': 9, 'step_size': 0.056829217265857324, 'batch_size': 32, 'num_bins': 26, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 66 with value: 0.476914166633194.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:52:40,173]\u001b[0m Trial 71 finished with value: 0.48210371390904005 and parameters: {'embedding_dim': 10, 'step_size': 0.06862730361201547, 'batch_size': 31, 'num_bins': 23, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 66 with value: 0.476914166633194.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:53:13,224]\u001b[0m Trial 72 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:54:06,179]\u001b[0m Trial 73 finished with value: 0.47648125878431646 and parameters: {'embedding_dim': 8, 'step_size': 0.0527994326893901, 'batch_size': 31, 'num_bins': 23, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 73 with value: 0.47648125878431646.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:54:13,921]\u001b[0m Trial 74 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:55:14,666]\u001b[0m Trial 75 finished with value: 0.4832143413749891 and parameters: {'embedding_dim': 8, 'step_size': 0.05312163758261429, 'batch_size': 27, 'num_bins': 24, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 73 with value: 0.47648125878431646.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:55:40,298]\u001b[0m Trial 76 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:56:43,399]\u001b[0m Trial 77 finished with value: 0.48646381124588367 and parameters: {'embedding_dim': 10, 'step_size': 0.06784387379069022, 'batch_size': 30, 'num_bins': 23, 'bin_strategy': 'quantile', 'num_epochs': 8}. Best is trial 73 with value: 0.47648125878431646.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:56:52,544]\u001b[0m Trial 78 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:58:01,068]\u001b[0m Trial 79 finished with value: 0.49141271856808333 and parameters: {'embedding_dim': 9, 'step_size': 0.0831574285124721, 'batch_size': 31, 'num_bins': 19, 'bin_strategy': 'quantile', 'num_epochs': 9}. Best is trial 73 with value: 0.47648125878431646.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:58:35,449]\u001b[0m Trial 80 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 00:59:18,893]\u001b[0m Trial 81 finished with value: 0.48632571974316574 and parameters: {'embedding_dim': 8, 'step_size': 0.043784056828969076, 'batch_size': 32, 'num_bins': 26, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 73 with value: 0.47648125878431646.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 01:00:13,887]\u001b[0m Trial 82 finished with value: 0.4808140576078414 and parameters: {'embedding_dim': 9, 'step_size': 0.0396166901001848, 'batch_size': 30, 'num_bins': 17, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 73 with value: 0.47648125878431646.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 01:01:16,749]\u001b[0m Trial 83 finished with value: 0.484505758786694 and parameters: {'embedding_dim': 10, 'step_size': 0.035889673059122974, 'batch_size': 30, 'num_bins': 17, 'bin_strategy': 'quantile', 'num_epochs': 8}. Best is trial 73 with value: 0.47648125878431646.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:01:24,973]\u001b[0m Trial 84 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:02:03,746]\u001b[0m Trial 85 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:02:26,934]\u001b[0m Trial 86 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 01:03:17,388]\u001b[0m Trial 87 finished with value: 0.4778754175051123 and parameters: {'embedding_dim': 8, 'step_size': 0.05100157729747126, 'batch_size': 28, 'num_bins': 22, 'bin_strategy': 'quantile', 'num_epochs': 6}. Best is trial 73 with value: 0.47648125878431646.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 01:03:59,348]\u001b[0m Trial 88 finished with value: 0.47865184479910533 and parameters: {'embedding_dim': 10, 'step_size': 0.046509370024280926, 'batch_size': 28, 'num_bins': 21, 'bin_strategy': 'quantile', 'num_epochs': 5}. Best is trial 73 with value: 0.47648125878431646.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 01:04:42,129]\u001b[0m Trial 89 finished with value: 0.48706426403950454 and parameters: {'embedding_dim': 8, 'step_size': 0.044832751812685284, 'batch_size': 28, 'num_bins': 18, 'bin_strategy': 'quantile', 'num_epochs': 5}. Best is trial 73 with value: 0.47648125878431646.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:04:50,840]\u001b[0m Trial 90 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 01:05:38,872]\u001b[0m Trial 91 finished with value: 0.4764637610283817 and parameters: {'embedding_dim': 10, 'step_size': 0.04717310364019017, 'batch_size': 30, 'num_bins': 22, 'bin_strategy': 'quantile', 'num_epochs': 6}. Best is trial 91 with value: 0.4764637610283817.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:06:26,018]\u001b[0m Trial 92 finished with value: 0.4725330759763331 and parameters: {'embedding_dim': 10, 'step_size': 0.04746441356701626, 'batch_size': 30, 'num_bins': 13, 'bin_strategy': 'quantile', 'num_epochs': 6}. Best is trial 92 with value: 0.4725330759763331.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:06:41,549]\u001b[0m Trial 93 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 01:07:12,239]\u001b[0m Trial 94 finished with value: 0.4796635929325749 and parameters: {'embedding_dim': 10, 'step_size': 0.0477274890995249, 'batch_size': 30, 'num_bins': 17, 'bin_strategy': 'quantile', 'num_epochs': 5}. Best is trial 92 with value: 0.4725330759763331.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:07:53,602]\u001b[0m Trial 95 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 01:08:37,636]\u001b[0m Trial 96 finished with value: 0.4850260604414627 and parameters: {'embedding_dim': 10, 'step_size': 0.046947828609809845, 'batch_size': 26, 'num_bins': 19, 'bin_strategy': 'quantile', 'num_epochs': 5}. Best is trial 92 with value: 0.4725330759763331.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 01:08:51,371]\u001b[0m Trial 97 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 01:09:30,112]\u001b[0m Trial 98 finished with value: 0.4824779956622831 and parameters: {'embedding_dim': 10, 'step_size': 0.04156484777805755, 'batch_size': 30, 'num_bins': 17, 'bin_strategy': 'quantile', 'num_epochs': 5}. Best is trial 92 with value: 0.4725330759763331.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:09:54,902]\u001b[0m Trial 99 pruned. \u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# Tune the binned FFM with a seeded TPE sampler so the search is repeatable.\n",
    "study_bins = optuna.create_study(\n",
    "    study_name='bins',\n",
    "    direction='minimize',\n",
    "    sampler=optuna.samplers.TPESampler(seed=42),\n",
    ")\n",
    "study_bins.optimize(test_bins_objective, n_trials=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'embedding_dim': 10,\n",
       " 'step_size': 0.04746441356701626,\n",
       " 'batch_size': 30,\n",
       " 'num_bins': 13,\n",
       " 'bin_strategy': 'quantile',\n",
       " 'num_epochs': 6}"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hyperparameters of the lowest-objective trial in the study.\n",
    "study_bins.best_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test loss: 0.4725330759763331\n",
      "Best hyperparameters: {'embedding_dim': 10, 'step_size': 0.04746441356701626, 'batch_size': 30, 'num_bins': 13, 'bin_strategy': 'quantile', 'num_epochs': 6}\n"
     ]
    }
   ],
   "source": [
    "# Report the objective value and parameters of the best trial.\n",
    "trial = study_bins.best_trial\n",
    "\n",
    "print(f'Test loss: {trial.value}')\n",
    "print(f'Best hyperparameters: {trial.params}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.22700029611587524"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Retrain once with the best hyperparameters. The returned value is presumably an MSE\n",
    "# (the repeated-run cell takes its square root before comparing with the study values).\n",
    "train_bin_ffm(**study_bins.best_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [15:42<00:00, 47.11s/it]\n"
     ]
    }
   ],
   "source": [
    "# Retrain 20 times with the best hyperparameters and record the square root of each\n",
    "# returned loss, matching the RMSE-like scale of the study's trial values.\n",
    "bin_losses = [\n",
    "    math.sqrt(train_bin_ffm(**study_bins.best_params))\n",
    "    for _run in trange(20)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.47253831069784313,\n",
       " 0.47436226585193625,\n",
       " 0.4746327178614912,\n",
       " 0.47469055988463177,\n",
       " 0.4737954243998599,\n",
       " 0.4738253487224963,\n",
       " 0.47137653094042825,\n",
       " 0.47257028538839857,\n",
       " 0.47029610602108984,\n",
       " 0.47444127877551856,\n",
       " 0.4718063374009628,\n",
       " 0.4723477105300527,\n",
       " 0.47160880620469847,\n",
       " 0.477920785313591,\n",
       " 0.4717265988354307,\n",
       " 0.47360833488214976,\n",
       " 0.4717850814181646,\n",
       " 0.4717952515621083,\n",
       " 0.4710633573328178,\n",
       " 0.47449408806774845]"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-run square-rooted losses from the repeated training above.\n",
    "bin_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.47303425900457097,\n",
       " 0.0051958734734685146,\n",
       " 0.47823013247803947,\n",
       " 0.46783838553110246)"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean of the repeated runs, the 3-sigma spread, and the upper / lower 3-sigma bounds.\n",
    "# NOTE(review): np.std defaults to ddof=0 (population std); ddof=1 would give the\n",
    "# sample estimate for these 20 runs and a slightly wider band.\n",
    "mean_loss = np.mean(bin_losses)\n",
    "three_sigma = 3 * np.std(bin_losses)\n",
    "mean_loss, three_sigma, mean_loss + three_sigma, mean_loss - three_sigma"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
