{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:31.944221Z",
     "start_time": "2023-05-14T19:59:31.869652Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.insert(0, \"../utils\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.291538Z",
     "start_time": "2023-05-14T19:59:31.876778Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import sklearn.datasets as skds\n",
    "from sklearn.preprocessing import QuantileTransformer, KBinsDiscretizer, OrdinalEncoder, LabelEncoder, StandardScaler\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from transformation import BSplineTransformer, spline_transform_dataset\n",
    "from trainers import FFMTrainer, FMTrainer\n",
    "import math\n",
    "import optuna\n",
    "import optuna.samplers\n",
    "from typing import Callable\n",
    "from sklearn.model_selection import train_test_split\n",
    "import torch\n",
    "from torch.utils.data import TensorDataset\n",
    "from tqdm import trange"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.291828Z",
     "start_time": "2023-05-14T19:59:35.238271Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.291910Z",
     "start_time": "2023-05-14T19:59:35.244997Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(42)\n",
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "raw_df = pd.read_csv(\"../data/Ye_millionsongdataset/Training_set_songs.csv\",\n",
    "                     names=['Year', 'TA01', 'TA02', 'TA03', 'TA04', 'TA05', 'TA06', 'TA07', 'TA08', 'TA09', 'TA10', 'TA11', 'TA12', 'TC01', 'TC02', 'TC03', 'TC04', 'TC05', 'TC06', 'TC07', 'TC08', 'TC09', 'TC10', 'TC11', 'TC12', 'TC13', 'TC14', 'TC15', 'TC16', 'TC17', 'TC18', 'TC19', 'TC20', 'TC21', 'TC22', 'TC23', 'TC24', 'TC25', 'TC26', 'TC27', 'TC28', 'TC29', 'TC30', 'TC31', 'TC32', 'TC33', 'TC34', 'TC35', 'TC36', 'TC37', 'TC38', 'TC39', 'TC40', 'TC41', 'TC42', 'TC43', 'TC44', 'TC45', 'TC46', 'TC47', 'TC48', 'TC49', 'TC50', 'TC51', 'TC52', 'TC53', 'TC54', 'TC55', 'TC56', 'TC57', 'TC58', 'TC59', 'TC60', 'TC61', 'TC62', 'TC63', 'TC64', 'TC65', 'TC66', 'TC67', 'TC68', 'TC69', 'TC70', 'TC71', 'TC72', 'TC73', 'TC74', 'TC75', 'TC76', 'TC77', 'TC78'],\n",
    "                     dtype={0:int, 1:float, 2:float, 3:float, 4:float, 5:float, 6:float, 7:float, 8:float, 9:float, 10:float, 11:float, 12:float, 13:float, 14:float, 15:float, 16:float, 17:float, 18:float, 19:float, 20:float, 21:float, 22:float, 23:float, 24:float, 25:float, 26:float, 27:float, 28:float, 29:float, 30:float, 31:float, 32:float, 33:float, 34:float, 35:float, 36:float, 37:float, 38:float, 39:float, 40:float, 41:float, 42:float, 43:float, 44:float, 45:float, 46:float, 47:float, 48:float, 49:float, 50:float, 51:float, 52:float, 53:float, 54:float, 55:float, 56:float, 57:float, 58:float, 59:float, 60:float, 61:float, 62:float, 63:float, 64:float, 65:float, 66:float, 67:float, 68:float, 69:float, 70:float, 71:float, 72:float, 73:float, 74:float, 75:float, 76:float, 77:float, 78:float, 79:float, 80:float, 81:float, 82:float, 83:float, 84:float, 85:float, 86:float, 87:float, 88:float, 89:float, 90:float},\n",
    "                     na_values=\"?\", skiprows=1)  # TODO: only 3000 lines are loaded in the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.465430Z",
     "start_time": "2023-05-14T19:59:35.384297Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Year</th>\n",
       "      <th>TA01</th>\n",
       "      <th>TA02</th>\n",
       "      <th>TA03</th>\n",
       "      <th>TA04</th>\n",
       "      <th>TA05</th>\n",
       "      <th>TA06</th>\n",
       "      <th>TA07</th>\n",
       "      <th>TA08</th>\n",
       "      <th>TA09</th>\n",
       "      <th>...</th>\n",
       "      <th>TC69</th>\n",
       "      <th>TC70</th>\n",
       "      <th>TC71</th>\n",
       "      <th>TC72</th>\n",
       "      <th>TC73</th>\n",
       "      <th>TC74</th>\n",
       "      <th>TC75</th>\n",
       "      <th>TC76</th>\n",
       "      <th>TC77</th>\n",
       "      <th>TC78</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>332595</th>\n",
       "      <td>2004</td>\n",
       "      <td>51.19589</td>\n",
       "      <td>18.54077</td>\n",
       "      <td>12.55200</td>\n",
       "      <td>3.25000</td>\n",
       "      <td>-7.79542</td>\n",
       "      <td>-19.63461</td>\n",
       "      <td>0.27662</td>\n",
       "      <td>-1.03994</td>\n",
       "      <td>-1.57231</td>\n",
       "      <td>...</td>\n",
       "      <td>-16.43458</td>\n",
       "      <td>-37.05054</td>\n",
       "      <td>-100.93186</td>\n",
       "      <td>49.26529</td>\n",
       "      <td>-1.47467</td>\n",
       "      <td>-39.25996</td>\n",
       "      <td>-22.44389</td>\n",
       "      <td>-16.37939</td>\n",
       "      <td>-49.67664</td>\n",
       "      <td>3.87491</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>230573</th>\n",
       "      <td>1989</td>\n",
       "      <td>51.44573</td>\n",
       "      <td>53.42621</td>\n",
       "      <td>35.83483</td>\n",
       "      <td>-16.72867</td>\n",
       "      <td>-48.13185</td>\n",
       "      <td>-13.04248</td>\n",
       "      <td>-45.74081</td>\n",
       "      <td>-6.18791</td>\n",
       "      <td>16.60100</td>\n",
       "      <td>...</td>\n",
       "      <td>9.94220</td>\n",
       "      <td>-71.82280</td>\n",
       "      <td>59.40250</td>\n",
       "      <td>45.14201</td>\n",
       "      <td>-2.72343</td>\n",
       "      <td>35.62292</td>\n",
       "      <td>37.99939</td>\n",
       "      <td>7.04862</td>\n",
       "      <td>21.08678</td>\n",
       "      <td>2.07779</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>364530</th>\n",
       "      <td>1987</td>\n",
       "      <td>44.85215</td>\n",
       "      <td>33.51052</td>\n",
       "      <td>27.44631</td>\n",
       "      <td>24.99783</td>\n",
       "      <td>-15.02508</td>\n",
       "      <td>12.29223</td>\n",
       "      <td>10.57365</td>\n",
       "      <td>7.06412</td>\n",
       "      <td>-1.29649</td>\n",
       "      <td>...</td>\n",
       "      <td>5.05131</td>\n",
       "      <td>56.78609</td>\n",
       "      <td>136.98499</td>\n",
       "      <td>-30.38374</td>\n",
       "      <td>24.51034</td>\n",
       "      <td>62.77834</td>\n",
       "      <td>44.80304</td>\n",
       "      <td>30.34866</td>\n",
       "      <td>105.17991</td>\n",
       "      <td>2.58183</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>82857</th>\n",
       "      <td>2002</td>\n",
       "      <td>50.25653</td>\n",
       "      <td>59.83236</td>\n",
       "      <td>37.80210</td>\n",
       "      <td>-0.57762</td>\n",
       "      <td>-1.64064</td>\n",
       "      <td>-11.77884</td>\n",
       "      <td>-2.25574</td>\n",
       "      <td>-4.17007</td>\n",
       "      <td>4.05922</td>\n",
       "      <td>...</td>\n",
       "      <td>-4.12555</td>\n",
       "      <td>21.34620</td>\n",
       "      <td>29.72172</td>\n",
       "      <td>56.71419</td>\n",
       "      <td>2.61590</td>\n",
       "      <td>56.47152</td>\n",
       "      <td>-26.05716</td>\n",
       "      <td>-0.77059</td>\n",
       "      <td>29.40943</td>\n",
       "      <td>-0.02311</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>108108</th>\n",
       "      <td>1971</td>\n",
       "      <td>45.48775</td>\n",
       "      <td>-5.49790</td>\n",
       "      <td>9.56187</td>\n",
       "      <td>9.36977</td>\n",
       "      <td>-11.15726</td>\n",
       "      <td>-3.63341</td>\n",
       "      <td>11.00297</td>\n",
       "      <td>-6.36722</td>\n",
       "      <td>5.37455</td>\n",
       "      <td>...</td>\n",
       "      <td>-29.83405</td>\n",
       "      <td>228.02525</td>\n",
       "      <td>53.68190</td>\n",
       "      <td>-47.38609</td>\n",
       "      <td>14.62809</td>\n",
       "      <td>124.90797</td>\n",
       "      <td>-26.61476</td>\n",
       "      <td>5.08838</td>\n",
       "      <td>295.42035</td>\n",
       "      <td>19.74883</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>446568</th>\n",
       "      <td>2005</td>\n",
       "      <td>49.59492</td>\n",
       "      <td>35.40110</td>\n",
       "      <td>-8.11273</td>\n",
       "      <td>-13.40502</td>\n",
       "      <td>3.18931</td>\n",
       "      <td>-14.05923</td>\n",
       "      <td>3.04436</td>\n",
       "      <td>-1.43375</td>\n",
       "      <td>3.60354</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.40689</td>\n",
       "      <td>-91.80271</td>\n",
       "      <td>-29.51019</td>\n",
       "      <td>-6.40710</td>\n",
       "      <td>6.99983</td>\n",
       "      <td>81.51023</td>\n",
       "      <td>-83.44656</td>\n",
       "      <td>2.81049</td>\n",
       "      <td>24.80947</td>\n",
       "      <td>-10.32877</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>6 rows × 91 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        Year      TA01      TA02      TA03      TA04      TA05      TA06  \\\n",
       "332595  2004  51.19589  18.54077  12.55200   3.25000  -7.79542 -19.63461   \n",
       "230573  1989  51.44573  53.42621  35.83483 -16.72867 -48.13185 -13.04248   \n",
       "364530  1987  44.85215  33.51052  27.44631  24.99783 -15.02508  12.29223   \n",
       "82857   2002  50.25653  59.83236  37.80210  -0.57762  -1.64064 -11.77884   \n",
       "108108  1971  45.48775  -5.49790   9.56187   9.36977 -11.15726  -3.63341   \n",
       "446568  2005  49.59492  35.40110  -8.11273 -13.40502   3.18931 -14.05923   \n",
       "\n",
       "            TA07     TA08      TA09  ...      TC69       TC70       TC71  \\\n",
       "332595   0.27662 -1.03994  -1.57231  ... -16.43458  -37.05054 -100.93186   \n",
       "230573 -45.74081 -6.18791  16.60100  ...   9.94220  -71.82280   59.40250   \n",
       "364530  10.57365  7.06412  -1.29649  ...   5.05131   56.78609  136.98499   \n",
       "82857   -2.25574 -4.17007   4.05922  ...  -4.12555   21.34620   29.72172   \n",
       "108108  11.00297 -6.36722   5.37455  ... -29.83405  228.02525   53.68190   \n",
       "446568   3.04436 -1.43375   3.60354  ...  -0.40689  -91.80271  -29.51019   \n",
       "\n",
       "            TC72      TC73       TC74      TC75      TC76       TC77      TC78  \n",
       "332595  49.26529  -1.47467  -39.25996 -22.44389 -16.37939  -49.67664   3.87491  \n",
       "230573  45.14201  -2.72343   35.62292  37.99939   7.04862   21.08678   2.07779  \n",
       "364530 -30.38374  24.51034   62.77834  44.80304  30.34866  105.17991   2.58183  \n",
       "82857   56.71419   2.61590   56.47152 -26.05716  -0.77059   29.40943  -0.02311  \n",
       "108108 -47.38609  14.62809  124.90797 -26.61476   5.08838  295.42035  19.74883  \n",
       "446568  -6.40710   6.99983   81.51023 -83.44656   2.81049   24.80947 -10.32877  \n",
       "\n",
       "[6 rows x 91 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_df.sample(6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(463715, 91)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.465923Z",
     "start_time": "2023-05-14T19:59:35.408593Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['Year', 'TA01', 'TA02', 'TA03', 'TA04', 'TA05', 'TA06', 'TA07', 'TA08',\n",
       "       'TA09', 'TA10', 'TA11', 'TA12', 'TC01', 'TC02', 'TC03', 'TC04', 'TC05',\n",
       "       'TC06', 'TC07', 'TC08', 'TC09', 'TC10', 'TC11', 'TC12', 'TC13', 'TC14',\n",
       "       'TC15', 'TC16', 'TC17', 'TC18', 'TC19', 'TC20', 'TC21', 'TC22', 'TC23',\n",
       "       'TC24', 'TC25', 'TC26', 'TC27', 'TC28', 'TC29', 'TC30', 'TC31', 'TC32',\n",
       "       'TC33', 'TC34', 'TC35', 'TC36', 'TC37', 'TC38', 'TC39', 'TC40', 'TC41',\n",
       "       'TC42', 'TC43', 'TC44', 'TC45', 'TC46', 'TC47', 'TC48', 'TC49', 'TC50',\n",
       "       'TC51', 'TC52', 'TC53', 'TC54', 'TC55', 'TC56', 'TC57', 'TC58', 'TC59',\n",
       "       'TC60', 'TC61', 'TC62', 'TC63', 'TC64', 'TC65', 'TC66', 'TC67', 'TC68',\n",
       "       'TC69', 'TC70', 'TC71', 'TC72', 'TC73', 'TC74', 'TC75', 'TC76', 'TC77',\n",
       "       'TC78'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.772301Z",
     "start_time": "2023-05-14T19:59:35.617641Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "train, test = train_test_split(raw_df, test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.772529Z",
     "start_time": "2023-05-14T19:59:35.642783Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "tr_feats = train.drop(\"Year\", axis=1)\n",
    "tr_target = train[\"Year\"].values\n",
    "te_feats = test.drop(\"Year\", axis=1)\n",
    "te_target = test[\"Year\"].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_scaler = StandardScaler()\n",
    "tr_target = target_scaler.fit_transform(tr_target.reshape(-1, 1)).reshape(-1)\n",
    "te_target = target_scaler.transform(te_target.reshape(-1, 1)).reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.773052Z",
     "start_time": "2023-05-14T19:59:35.647130Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "quant_transform = QuantileTransformer(output_distribution='uniform',\n",
    "                                      n_quantiles=10000,\n",
    "                                      subsample=len(tr_feats),\n",
    "                                      random_state=42)\n",
    "X_train_qs = quant_transform.fit_transform(tr_feats)\n",
    "X_test_qs = quant_transform.transform(te_feats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.773208Z",
     "start_time": "2023-05-14T19:59:35.735974Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def train_spline_fm(embedding_dim: int, step_size: float, batch_size: int, num_knots: int, num_epochs: int,\n",
    "                     callback: Callable[[int, float], None]=None):\n",
    "    bs = BSplineTransformer(num_knots, 3)\n",
    "    tr_indices, tr_weights, tr_offsets, tr_fields = spline_transform_dataset(X_train_qs, bs)\n",
    "    te_indices, te_weights, te_offsets, te_fields = spline_transform_dataset(X_test_qs, bs)\n",
    "\n",
    "    num_fields = X_train_qs.shape[1]\n",
    "    num_embeddings = int(max(np.max(tr_indices), np.max(te_indices)) + 1)\n",
    "\n",
    "    train_ds = TensorDataset(\n",
    "        torch.tensor(tr_indices, dtype=torch.int64),\n",
    "        torch.tensor(tr_weights, dtype=torch.float32),\n",
    "        torch.tensor(tr_offsets, dtype=torch.int64),\n",
    "        torch.tensor(tr_fields, dtype=torch.int64),\n",
    "        torch.tensor(tr_target, dtype=torch.float32))\n",
    "\n",
    "    test_ds = TensorDataset(\n",
    "        torch.tensor(te_indices, dtype=torch.int64),\n",
    "        torch.tensor(te_weights, dtype=torch.float32),\n",
    "        torch.tensor(te_offsets, dtype=torch.int64),\n",
    "        torch.tensor(te_fields, dtype=torch.int64),\n",
    "        torch.tensor(te_target, dtype=torch.float32))\n",
    "\n",
    "\n",
    "    trainer = FMTrainer(embedding_dim, step_size, batch_size, num_epochs, callback)\n",
    "    return trainer.train(num_fields, num_embeddings, train_ds, test_ds, torch.nn.MSELoss(), device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T19:59:35.773274Z",
     "start_time": "2023-05-14T19:59:35.744532Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def train_spline_objective(trial: optuna.Trial):\n",
    "    embedding_dim = trial.suggest_int('embedding_dim', 1, 10)\n",
    "    step_size = trial.suggest_float('step_size', 1e-2, 0.5, log=True)\n",
    "    batch_size = trial.suggest_int('batch_size', 32, 256)\n",
    "    num_knots = trial.suggest_int('num_knots', 3, 48)\n",
    "    num_epochs = trial.suggest_int('num_epochs', 5, 15)\n",
    "\n",
    "    def callback(epoch: int, loss: float):\n",
    "        trial.report(math.sqrt(loss), epoch)\n",
    "        if trial.should_prune():\n",
    "            raise optuna.TrialPruned()\n",
    "\n",
    "    return math.sqrt(train_spline_fm(embedding_dim, step_size, batch_size, num_knots, num_epochs,\n",
    "                                     callback=callback))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-05-14T19:59:35.750585Z"
    },
    "collapsed": false,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[I 2023-05-16 18:53:23,460]\u001b[0m A new study created in memory with name: splines\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 18:56:43,426]\u001b[0m Trial 0 finished with value: 1.2386136255730438 and parameters: {'embedding_dim': 4, 'step_size': 0.4123206532618726, 'batch_size': 196, 'num_knots': 30, 'num_epochs': 6}. Best is trial 0 with value: 1.2386136255730438.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:03:07,293]\u001b[0m Trial 1 finished with value: 9.578139736045964 and parameters: {'embedding_dim': 2, 'step_size': 0.012551115172973842, 'batch_size': 226, 'num_knots': 30, 'num_epochs': 12}. Best is trial 0 with value: 1.2386136255730438.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:06:49,945]\u001b[0m Trial 2 finished with value: 1.4629122193885566 and parameters: {'embedding_dim': 1, 'step_size': 0.44447541666908114, 'batch_size': 219, 'num_knots': 12, 'num_epochs': 7}. Best is trial 0 with value: 1.2386136255730438.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:13:15,304]\u001b[0m Trial 3 finished with value: 5.51720045029323 and parameters: {'embedding_dim': 2, 'step_size': 0.0328774741399112, 'batch_size': 150, 'num_knots': 22, 'num_epochs': 8}. Best is trial 0 with value: 1.2386136255730438.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:23:13,052]\u001b[0m Trial 4 finished with value: 2.7195761291430305 and parameters: {'embedding_dim': 7, 'step_size': 0.017258215396625, 'batch_size': 97, 'num_knots': 19, 'num_epochs': 10}. Best is trial 0 with value: 1.2386136255730438.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:24:34,830]\u001b[0m Trial 5 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:26:59,399]\u001b[0m Trial 6 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:39:46,136]\u001b[0m Trial 7 finished with value: 1.3228858834036774 and parameters: {'embedding_dim': 9, 'step_size': 0.032925293631105246, 'batch_size': 53, 'num_knots': 34, 'num_epochs': 9}. Best is trial 0 with value: 1.2386136255730438.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:52:35,118]\u001b[0m Trial 8 finished with value: 1.550295967639816 and parameters: {'embedding_dim': 2, 'step_size': 0.06938901412739397, 'batch_size': 39, 'num_knots': 44, 'num_epochs': 7}. Best is trial 0 with value: 1.2386136255730438.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:53:51,842]\u001b[0m Trial 9 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 19:57:07,263]\u001b[0m Trial 10 finished with value: 1.1966288679541184 and parameters: {'embedding_dim': 4, 'step_size': 0.38450727047302585, 'batch_size': 188, 'num_knots': 3, 'num_epochs': 5}. Best is trial 10 with value: 1.1966288679541184.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:00:16,413]\u001b[0m Trial 11 finished with value: 1.166368747320352 and parameters: {'embedding_dim': 4, 'step_size': 0.45819653342095856, 'batch_size': 192, 'num_knots': 6, 'num_epochs': 5}. Best is trial 11 with value: 1.166368747320352.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:03:34,162]\u001b[0m Trial 12 finished with value: 1.2247194669600276 and parameters: {'embedding_dim': 4, 'step_size': 0.20992321291982594, 'batch_size': 184, 'num_knots': 4, 'num_epochs': 5}. Best is trial 11 with value: 1.166368747320352.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:09:21,405]\u001b[0m Trial 13 finished with value: 1.0230244392652823 and parameters: {'embedding_dim': 5, 'step_size': 0.25231843076021476, 'batch_size': 256, 'num_knots': 3, 'num_epochs': 12}. Best is trial 13 with value: 1.0230244392652823.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:15:25,054]\u001b[0m Trial 14 finished with value: 1.1263353581741014 and parameters: {'embedding_dim': 5, 'step_size': 0.22060040200689754, 'batch_size': 256, 'num_knots': 12, 'num_epochs': 12}. Best is trial 13 with value: 1.0230244392652823.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:22:02,080]\u001b[0m Trial 15 finished with value: 1.0999773131545263 and parameters: {'embedding_dim': 6, 'step_size': 0.18312582558712856, 'batch_size': 256, 'num_knots': 13, 'num_epochs': 13}. Best is trial 13 with value: 1.0230244392652823.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:28:42,275]\u001b[0m Trial 16 finished with value: 1.0410950936336398 and parameters: {'embedding_dim': 10, 'step_size': 0.1475792084364477, 'batch_size': 247, 'num_knots': 12, 'num_epochs': 13}. Best is trial 13 with value: 1.0230244392652823.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:30:11,072]\u001b[0m Trial 17 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:42:32,391]\u001b[0m Trial 18 finished with value: 0.9837241670350367 and parameters: {'embedding_dim': 10, 'step_size': 0.11027358472172152, 'batch_size': 101, 'num_knots': 8, 'num_epochs': 13}. Best is trial 18 with value: 0.9837241670350367.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 20:52:31,862]\u001b[0m Trial 19 finished with value: 1.0135361538325767 and parameters: {'embedding_dim': 9, 'step_size': 0.09653067289462239, 'batch_size': 106, 'num_knots': 8, 'num_epochs': 11}. Best is trial 18 with value: 0.9837241670350367.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:02:01,178]\u001b[0m Trial 20 finished with value: 1.0474427804363677 and parameters: {'embedding_dim': 9, 'step_size': 0.08886733479770562, 'batch_size': 100, 'num_knots': 9, 'num_epochs': 10}. Best is trial 18 with value: 0.9837241670350367.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:03:17,069]\u001b[0m Trial 21 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:12:46,324]\u001b[0m Trial 22 finished with value: 0.9591751996469187 and parameters: {'embedding_dim': 8, 'step_size': 0.2688719862751585, 'batch_size': 77, 'num_knots': 16, 'num_epochs': 11}. Best is trial 22 with value: 0.9591751996469187.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:21:46,922]\u001b[0m Trial 23 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:32:40,408]\u001b[0m Trial 24 finished with value: 1.013309950705173 and parameters: {'embedding_dim': 10, 'step_size': 0.1381692732684226, 'batch_size': 122, 'num_knots': 20, 'num_epochs': 14}. Best is trial 22 with value: 0.9591751996469187.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:44:32,521]\u001b[0m Trial 25 finished with value: 0.9766103503901942 and parameters: {'embedding_dim': 10, 'step_size': 0.30666487353085764, 'batch_size': 120, 'num_knots': 22, 'num_epochs': 14}. Best is trial 22 with value: 0.9591751996469187.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 21:57:07,523]\u001b[0m Trial 26 finished with value: 0.9631696687473242 and parameters: {'embedding_dim': 8, 'step_size': 0.30221758573737095, 'batch_size': 74, 'num_knots': 26, 'num_epochs': 14}. Best is trial 22 with value: 0.9591751996469187.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:08:48,665]\u001b[0m Trial 27 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:22:53,824]\u001b[0m Trial 28 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:27:42,916]\u001b[0m Trial 29 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:33:49,484]\u001b[0m Trial 30 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:44:19,069]\u001b[0m Trial 31 finished with value: 0.9559326717458044 and parameters: {'embedding_dim': 10, 'step_size': 0.1845435818497936, 'batch_size': 83, 'num_knots': 15, 'num_epochs': 13}. Best is trial 31 with value: 0.9559326717458044.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:45:30,091]\u001b[0m Trial 32 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-16 22:59:50,871]\u001b[0m Trial 33 finished with value: 0.9604147753141788 and parameters: {'embedding_dim': 10, 'step_size': 0.288261536780601, 'batch_size': 57, 'num_knots': 22, 'num_epochs': 13}. Best is trial 31 with value: 0.9559326717458044.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:15:48,223]\u001b[0m Trial 34 finished with value: 0.9742608985868826 and parameters: {'embedding_dim': 8, 'step_size': 0.17867376960796208, 'batch_size': 59, 'num_knots': 26, 'num_epochs': 13}. Best is trial 31 with value: 0.9559326717458044.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:37:28,738]\u001b[0m Trial 35 finished with value: 0.9171861274877934 and parameters: {'embedding_dim': 9, 'step_size': 0.2645542020430195, 'batch_size': 32, 'num_knots': 19, 'num_epochs': 10}. Best is trial 35 with value: 0.9171861274877934.\u001b[0m\n",
      "\u001b[32m[I 2023-05-16 23:55:05,064]\u001b[0m Trial 36 finished with value: 1.0132628334036484 and parameters: {'embedding_dim': 9, 'step_size': 0.40545523268469424, 'batch_size': 35, 'num_knots': 20, 'num_epochs': 9}. Best is trial 35 with value: 0.9171861274877934.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:05:42,393]\u001b[0m Trial 37 finished with value: 0.9345916139732143 and parameters: {'embedding_dim': 10, 'step_size': 0.25690624184860333, 'batch_size': 59, 'num_knots': 14, 'num_epochs': 10}. Best is trial 35 with value: 0.9171861274877934.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:20:05,720]\u001b[0m Trial 38 finished with value: 0.9352145948828553 and parameters: {'embedding_dim': 9, 'step_size': 0.22319832065362558, 'batch_size': 49, 'num_knots': 14, 'num_epochs': 10}. Best is trial 35 with value: 0.9171861274877934.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:26:44,288]\u001b[0m Trial 39 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:41:30,166]\u001b[0m Trial 40 finished with value: 0.9663647234913916 and parameters: {'embedding_dim': 9, 'step_size': 0.15557203352959853, 'batch_size': 48, 'num_knots': 18, 'num_epochs': 10}. Best is trial 35 with value: 0.9171861274877934.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:49:39,157]\u001b[0m Trial 41 finished with value: 0.9422524159401244 and parameters: {'embedding_dim': 10, 'step_size': 0.24528852921151278, 'batch_size': 63, 'num_knots': 10, 'num_epochs': 8}. Best is trial 35 with value: 0.9171861274877934.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 00:58:35,285]\u001b[0m Trial 42 finished with value: 0.9372499450353899 and parameters: {'embedding_dim': 10, 'step_size': 0.2395050154883088, 'batch_size': 57, 'num_knots': 10, 'num_epochs': 8}. Best is trial 35 with value: 0.9171861274877934.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:06:42,836]\u001b[0m Trial 43 finished with value: 0.9408245901462334 and parameters: {'embedding_dim': 10, 'step_size': 0.24135658451497335, 'batch_size': 64, 'num_knots': 10, 'num_epochs': 8}. Best is trial 35 with value: 0.9171861274877934.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:18:11,857]\u001b[0m Trial 44 finished with value: 0.9583412080247737 and parameters: {'embedding_dim': 9, 'step_size': 0.39856936883703, 'batch_size': 42, 'num_knots': 10, 'num_epochs': 8}. Best is trial 35 with value: 0.9171861274877934.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:25:08,180]\u001b[0m Trial 45 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:30:13,286]\u001b[0m Trial 46 finished with value: 0.9433470225660775 and parameters: {'embedding_dim': 9, 'step_size': 0.3592001787790354, 'batch_size': 63, 'num_knots': 6, 'num_epochs': 6}. Best is trial 35 with value: 0.9171861274877934.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:42:03,625]\u001b[0m Trial 47 finished with value: 0.931486584866254 and parameters: {'embedding_dim': 10, 'step_size': 0.225118064987353, 'batch_size': 43, 'num_knots': 14, 'num_epochs': 9}. Best is trial 35 with value: 0.9171861274877934.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:54:11,688]\u001b[0m Trial 48 finished with value: 0.9828667060204228 and parameters: {'embedding_dim': 10, 'step_size': 0.45874685485020067, 'batch_size': 42, 'num_knots': 14, 'num_epochs': 9}. Best is trial 35 with value: 0.9171861274877934.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 01:58:13,715]\u001b[0m Trial 49 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 02:09:46,741]\u001b[0m Trial 50 finished with value: 0.9406328625524666 and parameters: {'embedding_dim': 9, 'step_size': 0.20765904624110154, 'batch_size': 50, 'num_knots': 13, 'num_epochs': 10}. Best is trial 35 with value: 0.9171861274877934.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 02:21:07,688]\u001b[0m Trial 51 finished with value: 0.9405742149964395 and parameters: {'embedding_dim': 9, 'step_size': 0.1969866703171254, 'batch_size': 51, 'num_knots': 13, 'num_epochs': 10}. Best is trial 35 with value: 0.9171861274877934.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 02:36:19,135]\u001b[0m Trial 52 finished with value: 0.8842285075359627 and parameters: {'embedding_dim': 10, 'step_size': 0.2645915336156327, 'batch_size': 33, 'num_knots': 5, 'num_epochs': 9}. Best is trial 52 with value: 0.8842285075359627.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 02:50:02,113]\u001b[0m Trial 53 finished with value: 0.8876547893343609 and parameters: {'embedding_dim': 10, 'step_size': 0.2727717956425, 'batch_size': 36, 'num_knots': 5, 'num_epochs': 9}. Best is trial 52 with value: 0.8842285075359627.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 03:04:02,753]\u001b[0m Trial 54 finished with value: 0.8829253470615361 and parameters: {'embedding_dim': 10, 'step_size': 0.2762817857390965, 'batch_size': 35, 'num_knots': 4, 'num_epochs': 9}. Best is trial 54 with value: 0.8829253470615361.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 03:19:30,440]\u001b[0m Trial 55 finished with value: 0.8949441081981341 and parameters: {'embedding_dim': 10, 'step_size': 0.2824114142172617, 'batch_size': 32, 'num_knots': 5, 'num_epochs': 9}. Best is trial 54 with value: 0.8829253470615361.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 03:34:35,882]\u001b[0m Trial 56 finished with value: 0.8915999914951418 and parameters: {'embedding_dim': 10, 'step_size': 0.40545429992131593, 'batch_size': 33, 'num_knots': 3, 'num_epochs': 9}. Best is trial 54 with value: 0.8829253470615361.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 03:45:22,948]\u001b[0m Trial 57 finished with value: 0.9174837819882632 and parameters: {'embedding_dim': 10, 'step_size': 0.42181265669918006, 'batch_size': 36, 'num_knots': 5, 'num_epochs': 7}. Best is trial 54 with value: 0.8829253470615361.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:01:08,351]\u001b[0m Trial 58 finished with value: 0.8812943291843862 and parameters: {'embedding_dim': 10, 'step_size': 0.36073913832745863, 'batch_size': 32, 'num_knots': 3, 'num_epochs': 9}. Best is trial 58 with value: 0.8812943291843862.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:02:48,388]\u001b[0m Trial 59 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:03:31,415]\u001b[0m Trial 60 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:18:57,676]\u001b[0m Trial 61 finished with value: 0.8945438434097501 and parameters: {'embedding_dim': 9, 'step_size': 0.2875851804909289, 'batch_size': 33, 'num_knots': 6, 'num_epochs': 9}. Best is trial 58 with value: 0.8812943291843862.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:34:24,387]\u001b[0m Trial 62 finished with value: 0.9080907608582166 and parameters: {'embedding_dim': 10, 'step_size': 0.2859064435511358, 'batch_size': 33, 'num_knots': 6, 'num_epochs': 9}. Best is trial 58 with value: 0.8812943291843862.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 04:46:47,936]\u001b[0m Trial 63 finished with value: 0.923213213745785 and parameters: {'embedding_dim': 10, 'step_size': 0.3574927238399868, 'batch_size': 41, 'num_knots': 8, 'num_epochs': 9}. Best is trial 58 with value: 0.8812943291843862.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:02:15,375]\u001b[0m Trial 64 finished with value: 0.9303926950883815 and parameters: {'embedding_dim': 9, 'step_size': 0.4984824913019317, 'batch_size': 32, 'num_knots': 5, 'num_epochs': 9}. Best is trial 58 with value: 0.8812943291843862.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:12:20,767]\u001b[0m Trial 65 finished with value: 0.8836246232120878 and parameters: {'embedding_dim': 10, 'step_size': 0.2918598742790804, 'batch_size': 45, 'num_knots': 3, 'num_epochs': 8}. Best is trial 58 with value: 0.8812943291843862.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:13:23,409]\u001b[0m Trial 66 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:23:17,634]\u001b[0m Trial 67 finished with value: 0.9203880084785563 and parameters: {'embedding_dim': 10, 'step_size': 0.2978806762773378, 'batch_size': 46, 'num_knots': 7, 'num_epochs': 8}. Best is trial 58 with value: 0.8812943291843862.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:24:07,220]\u001b[0m Trial 68 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:25:26,541]\u001b[0m Trial 69 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:38:54,831]\u001b[0m Trial 70 finished with value: 0.8885545925677477 and parameters: {'embedding_dim': 6, 'step_size': 0.3377290795161348, 'batch_size': 40, 'num_knots': 4, 'num_epochs': 8}. Best is trial 58 with value: 0.8812943291843862.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:53:02,508]\u001b[0m Trial 71 finished with value: 0.898234601537262 and parameters: {'embedding_dim': 6, 'step_size': 0.32131412729525716, 'batch_size': 39, 'num_knots': 4, 'num_epochs': 8}. Best is trial 58 with value: 0.8812943291843862.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:55:20,996]\u001b[0m Trial 72 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 05:57:42,907]\u001b[0m Trial 73 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:09:46,746]\u001b[0m Trial 74 finished with value: 0.9336780206801858 and parameters: {'embedding_dim': 10, 'step_size': 0.3363325348006269, 'batch_size': 38, 'num_knots': 8, 'num_epochs': 7}. Best is trial 58 with value: 0.8812943291843862.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:10:46,772]\u001b[0m Trial 75 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:23:23,832]\u001b[0m Trial 76 finished with value: 0.9035483242637248 and parameters: {'embedding_dim': 10, 'step_size': 0.3023068642066013, 'batch_size': 46, 'num_knots': 6, 'num_epochs': 8}. Best is trial 58 with value: 0.8812943291843862.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:34:00,400]\u001b[0m Trial 77 finished with value: 0.9356443477813984 and parameters: {'embedding_dim': 8, 'step_size': 0.3707428489721444, 'batch_size': 53, 'num_knots': 8, 'num_epochs': 9}. Best is trial 58 with value: 0.8812943291843862.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:40:20,237]\u001b[0m Trial 78 finished with value: 0.8893062800372801 and parameters: {'embedding_dim': 9, 'step_size': 0.27560158762743114, 'batch_size': 66, 'num_knots': 3, 'num_epochs': 10}. Best is trial 58 with value: 0.8812943291843862.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:48:23,467]\u001b[0m Trial 79 finished with value: 0.8840000301227305 and parameters: {'embedding_dim': 10, 'step_size': 0.2621309491504174, 'batch_size': 68, 'num_knots': 3, 'num_epochs': 11}. Best is trial 58 with value: 0.8812943291843862.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:49:38,654]\u001b[0m Trial 80 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 06:50:44,419]\u001b[0m Trial 81 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 07:01:59,621]\u001b[0m Trial 82 finished with value: 0.8794329950964015 and parameters: {'embedding_dim': 10, 'step_size': 0.24654623010907453, 'batch_size': 68, 'num_knots': 3, 'num_epochs': 12}. Best is trial 82 with value: 0.8794329950964015.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 07:03:01,386]\u001b[0m Trial 83 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 07:04:27,287]\u001b[0m Trial 84 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 07:05:37,113]\u001b[0m Trial 85 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 07:06:37,147]\u001b[0m Trial 86 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 07:07:52,641]\u001b[0m Trial 87 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 07:10:41,882]\u001b[0m Trial 88 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 07:12:02,122]\u001b[0m Trial 89 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 07:13:00,456]\u001b[0m Trial 90 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 07:25:04,550]\u001b[0m Trial 91 finished with value: 0.8899198718915379 and parameters: {'embedding_dim': 10, 'step_size': 0.38574848982787274, 'batch_size': 37, 'num_knots': 3, 'num_epochs': 8}. Best is trial 82 with value: 0.8794329950964015.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 07:36:52,813]\u001b[0m Trial 92 finished with value: 0.8794658998359767 and parameters: {'embedding_dim': 10, 'step_size': 0.24069912203863186, 'batch_size': 38, 'num_knots': 3, 'num_epochs': 8}. Best is trial 82 with value: 0.8794329950964015.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 07:44:42,769]\u001b[0m Trial 93 finished with value: 0.9155397785613877 and parameters: {'embedding_dim': 10, 'step_size': 0.2388690985342639, 'batch_size': 53, 'num_knots': 6, 'num_epochs': 7}. Best is trial 82 with value: 0.8794329950964015.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 07:45:35,447]\u001b[0m Trial 94 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 07:58:02,324]\u001b[0m Trial 95 finished with value: 0.9002698705631274 and parameters: {'embedding_dim': 9, 'step_size': 0.22749984739763673, 'batch_size': 46, 'num_knots': 5, 'num_epochs': 10}. Best is trial 82 with value: 0.8794329950964015.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 08:00:11,700]\u001b[0m Trial 96 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 08:08:32,642]\u001b[0m Trial 97 finished with value: 0.8877704114380222 and parameters: {'embedding_dim': 9, 'step_size': 0.2560666131932691, 'batch_size': 71, 'num_knots': 3, 'num_epochs': 11}. Best is trial 82 with value: 0.8794329950964015.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 08:09:38,192]\u001b[0m Trial 98 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 08:14:32,944]\u001b[0m Trial 99 pruned. \u001b[0m\n"
     ]
    }
   ],
   "source": [
    "study = optuna.create_study(study_name='splines',\n",
    "                            direction='minimize',\n",
    "                            sampler=optuna.samplers.TPESampler(seed=42))\n",
    "study.optimize(train_spline_objective, n_trials=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test loss: 0.8794329950964015\n",
      "Best hyperparameters: {'embedding_dim': 10, 'step_size': 0.24654623010907453, 'batch_size': 68, 'num_knots': 3, 'num_epochs': 12}\n"
     ]
    }
   ],
   "source": [
    "trial = study.best_trial\n",
    "\n",
    "print('Test loss: {}'.format(trial.value))\n",
    "print(\"Best hyperparameters: {}\".format(trial.params))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": false,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'embedding_dim': 10,\n",
       " 'step_size': 0.24654623010907453,\n",
       " 'batch_size': 68,\n",
       " 'num_knots': 3,\n",
       " 'num_epochs': 12}"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "study.best_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7787345051765442"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_spline_fm(**study.best_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [2:39:59<00:00, 479.95s/it]  \n"
     ]
    }
   ],
   "source": [
    "spline_losses = []\n",
    "for i in trange(20):\n",
    "    loss = train_spline_fm(**study.best_params)\n",
    "    spline_losses.append(math.sqrt(loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.881695640108332,\n",
       " 0.881986518583829,\n",
       " 0.8812350129821337,\n",
       " 0.8835068731661306,\n",
       " 0.8756028550892196,\n",
       " 0.8795998443241684,\n",
       " 0.8812374141146075,\n",
       " 0.8793456949910461,\n",
       " 0.879686949816137,\n",
       " 0.8792196435620432,\n",
       " 0.8800452079323017,\n",
       " 0.8800887905106133,\n",
       " 0.8787611778058435,\n",
       " 0.8798512104892138,\n",
       " 0.8808272328401603,\n",
       " 0.8801772358534317,\n",
       " 0.8815416286395354,\n",
       " 0.8799189857558083,\n",
       " 0.8787488669310817,\n",
       " 0.8843083494702224]"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "spline_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.880369256648293, 0.0018009996668723587, 0.2045732121234053)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(spline_losses), np.std(spline_losses), 100 * np.std(spline_losses) / np.mean(spline_losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def train_bin_fm(embedding_dim: int, step_size: float, batch_size: int,\n",
    "                  num_bins: int, bin_strategy: str, num_epochs: int,\n",
    "                  callback: Callable[[int, float], None]=None):\n",
    "    num_fields = tr_feats.shape[1]\n",
    "    num_embeddings = num_fields * num_bins\n",
    "    index_offsets = np.arange(0, num_fields) * num_bins\n",
    "\n",
    "    discretizer = KBinsDiscretizer(num_bins, encode='ordinal', strategy=bin_strategy, random_state=42)\n",
    "    discretizer.fit(tr_feats)\n",
    "\n",
    "    tr_indices = discretizer.transform(tr_feats)\n",
    "    tr_indices += np.tile(index_offsets, (tr_indices.shape[0], 1))\n",
    "    tr_weights = np.ones_like(tr_indices)\n",
    "    tr_fields = np.tile(np.arange(0, num_fields), (tr_indices.shape[0], 1))\n",
    "    tr_offsets = tr_fields.copy()\n",
    "\n",
    "    te_indices = discretizer.transform(te_feats)\n",
    "    te_indices += np.tile(index_offsets, (te_indices.shape[0], 1))\n",
    "    te_weights = np.ones_like(te_indices)\n",
    "    te_fields = np.tile(np.arange(0, num_fields), (te_indices.shape[0], 1))\n",
    "    te_offsets = te_fields.copy()\n",
    "\n",
    "    train_ds = TensorDataset(\n",
    "        torch.tensor(tr_indices, dtype=torch.int64),\n",
    "        torch.tensor(tr_weights, dtype=torch.float32),\n",
    "        torch.tensor(tr_offsets, dtype=torch.int64),\n",
    "        torch.tensor(tr_fields, dtype=torch.int64),\n",
    "        torch.tensor(tr_target, dtype=torch.float32))\n",
    "\n",
    "    test_ds = TensorDataset(\n",
    "        torch.tensor(te_indices, dtype=torch.int64),\n",
    "        torch.tensor(te_weights, dtype=torch.float32),\n",
    "        torch.tensor(te_offsets, dtype=torch.int64),\n",
    "        torch.tensor(te_fields, dtype=torch.int64),\n",
    "        torch.tensor(te_target, dtype=torch.float32))\n",
    "\n",
    "    trainer = FMTrainer(embedding_dim, step_size, batch_size, num_epochs, callback)\n",
    "    return trainer.train(num_fields, num_embeddings, train_ds, test_ds, torch.nn.MSELoss(), device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def test_bins_objective(trial: optuna.Trial):\n",
    "    embedding_dim = trial.suggest_int('embedding_dim', 1, 10)\n",
    "    step_size = trial.suggest_float('step_size', 1e-2, 0.5, log=True)\n",
    "    batch_size = trial.suggest_int('batch_size', 32, 256)\n",
    "    num_bins = trial.suggest_int('num_bins', 2, 100)\n",
    "    bin_strategy = trial.suggest_categorical('bin_strategy', ['uniform', 'quantile'])\n",
    "    num_epochs = trial.suggest_int('num_epochs', 5, 15)\n",
    "\n",
    "    def callback(epoch: int, loss: float):\n",
    "        trial.report(math.sqrt(loss), epoch)\n",
    "        if trial.should_prune():\n",
    "            raise optuna.TrialPruned()\n",
    "\n",
    "    return math.sqrt(train_bin_fm(embedding_dim, step_size, batch_size, num_bins, bin_strategy, num_epochs,\n",
    "                                  callback=callback))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": false,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[I 2023-05-17 08:23:55,739]\u001b[0m A new study created in memory with name: bins\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 08:25:38,113]\u001b[0m Trial 0 finished with value: 1.5271531398304155 and parameters: {'embedding_dim': 4, 'step_size': 0.4123206532618726, 'batch_size': 196, 'num_bins': 61, 'bin_strategy': 'uniform', 'num_epochs': 5}. Best is trial 0 with value: 1.5271531398304155.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 08:28:04,884]\u001b[0m Trial 1 finished with value: 0.9448509350530763 and parameters: {'embedding_dim': 9, 'step_size': 0.10502105436744279, 'batch_size': 191, 'num_bins': 4, 'bin_strategy': 'uniform', 'num_epochs': 7}. Best is trial 1 with value: 0.9448509350530763.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 08:34:45,404]\u001b[0m Trial 2 finished with value: 5.603986016331611 and parameters: {'embedding_dim': 2, 'step_size': 0.020492680115417352, 'batch_size': 100, 'num_bins': 53, 'bin_strategy': 'uniform', 'num_epochs': 11}. Best is trial 1 with value: 0.9448509350530763.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 08:40:16,860]\u001b[0m Trial 3 finished with value: 3.487825336116525 and parameters: {'embedding_dim': 2, 'step_size': 0.03135775732257745, 'batch_size': 114, 'num_bins': 47, 'bin_strategy': 'uniform', 'num_epochs': 10}. Best is trial 1 with value: 0.9448509350530763.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:209: FutureWarning: In version 1.3 onwards, subsample=2e5 will be used by default. Set subsample explicitly to silence this warning in the mean time. Set subsample=None to disable subsampling explicitly.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 08:45:51,720]\u001b[0m Trial 4 finished with value: 5.078675976720081 and parameters: {'embedding_dim': 6, 'step_size': 0.011992724522955167, 'batch_size': 168, 'num_bins': 18, 'bin_strategy': 'quantile', 'num_epochs': 15}. Best is trial 1 with value: 0.9448509350530763.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 08:56:25,390]\u001b[0m Trial 5 finished with value: 1.5473596603529345 and parameters: {'embedding_dim': 9, 'step_size': 0.032925293631105246, 'batch_size': 53, 'num_bins': 69, 'bin_strategy': 'uniform', 'num_epochs': 10}. Best is trial 1 with value: 0.9448509350530763.\u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:209: FutureWarning: In version 1.3 onwards, subsample=2e5 will be used by default. Set subsample explicitly to silence this warning in the mean time. Set subsample=None to disable subsampling explicitly.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 09:03:33,049]\u001b[0m Trial 6 finished with value: 1.2903412458488763 and parameters: {'embedding_dim': 1, 'step_size': 0.35067764992972184, 'batch_size': 90, 'num_bins': 67, 'bin_strategy': 'quantile', 'num_epochs': 11}. Best is trial 1 with value: 0.9448509350530763.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 09:08:53,444]\u001b[0m Trial 7 finished with value: 1.204044238579979 and parameters: {'embedding_dim': 2, 'step_size': 0.4439102767051397, 'batch_size': 206, 'num_bins': 95, 'bin_strategy': 'uniform', 'num_epochs': 15}. Best is trial 1 with value: 0.9448509350530763.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 09:10:11,410]\u001b[0m Trial 8 pruned. \u001b[0m\n",
      "\u001b[32m[I 2023-05-17 09:10:36,489]\u001b[0m Trial 9 pruned. \u001b[0m\n",
      "/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:209: FutureWarning: In version 1.3 onwards, subsample=2e5 will be used by default. Set subsample explicitly to silence this warning in the mean time. Set subsample=None to disable subsampling explicitly.\n",
      "  warnings.warn(\n",
      "\u001b[32m[I 2023-05-17 09:12:05,260]\u001b[0m Trial 10 finished with value: 1.3519945281047807 and parameters: {'embedding_dim': 10, 'step_size': 0.12983523609376665, 'batch_size': 250, 'num_bins': 3, 'bin_strategy': 'quantile', 'num_epochs': 5}. Best is trial 1 with value: 0.9448509350530763.\u001b[0m\n",
      "\u001b[32m[I 2023-05-17 09:14:21,947]\u001b[0m Trial 11 finished with value: 1.5510122285726062 and parameters: {'embedding_dim': 8, 'step_size': 0.1376872218807011, 'batch_size': 216, 'num_bins': 100, 'bin_strategy': 'uniform', 'num_epochs': 7}. Best is trial 1 with value: 0.9448509350530763.\u001b[0m\n",
      "\u001b[33m[W 2023-05-17 09:14:22,281]\u001b[0m Trial 12 failed with parameters: {'embedding_dim': 7, 'step_size': 0.21367114396906817, 'batch_size': 202, 'num_bins': 95, 'bin_strategy': 'uniform', 'num_epochs': 8} because of the following error: KeyboardInterrupt().\u001b[0m\n",
      "Traceback (most recent call last):\n",
      "  File \"/usr/local/lib/python3.9/site-packages/optuna/study/_optimize.py\", line 200, in _run_trial\n",
      "    value_or_values = func(trial)\n",
      "  File \"/tmp/ipykernel_258955/2787357413.py\", line 14, in test_bins_objective\n",
      "    return math.sqrt(train_bin_fm(embedding_dim, step_size, batch_size, num_bins, bin_strategy, num_epochs,\n",
      "  File \"/tmp/ipykernel_258955/1864156579.py\", line 11, in train_bin_fm\n",
      "    tr_indices = discretizer.transform(tr_feats)\n",
      "  File \"/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py\", line 375, in transform\n",
      "    Xt[:, jj] = np.searchsorted(bin_edges[jj][1:-1], Xt[:, jj], side=\"right\")\n",
      "  File \"<__array_function__ internals>\", line 180, in searchsorted\n",
      "  File \"/usr/local/lib64/python3.9/site-packages/numpy/core/fromnumeric.py\", line 1387, in searchsorted\n",
      "    return _wrapfunc(a, 'searchsorted', v, side=side, sorter=sorter)\n",
      "  File \"/usr/local/lib64/python3.9/site-packages/numpy/core/fromnumeric.py\", line 57, in _wrapfunc\n",
      "    return bound(*args, **kwds)\n",
      "KeyboardInterrupt\n",
      "\u001b[33m[W 2023-05-17 09:14:22,283]\u001b[0m Trial 12 failed with value None.\u001b[0m\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Input \u001b[0;32mIn [21]\u001b[0m, in \u001b[0;36m<cell line: 4>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m study_bins \u001b[38;5;241m=\u001b[39m optuna\u001b[38;5;241m.\u001b[39mcreate_study(study_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbins\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[1;32m      2\u001b[0m                                  direction\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mminimize\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[1;32m      3\u001b[0m                                  sampler\u001b[38;5;241m=\u001b[39moptuna\u001b[38;5;241m.\u001b[39msamplers\u001b[38;5;241m.\u001b[39mTPESampler(seed\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m42\u001b[39m))\n\u001b[0;32m----> 4\u001b[0m \u001b[43mstudy_bins\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimize\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_bins_objective\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_trials\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m100\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.9/site-packages/optuna/study/study.py:425\u001b[0m, in \u001b[0;36mStudy.optimize\u001b[0;34m(self, func, n_trials, timeout, n_jobs, catch, callbacks, gc_after_trial, show_progress_bar)\u001b[0m\n\u001b[1;32m    321\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21moptimize\u001b[39m(\n\u001b[1;32m    322\u001b[0m     \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m    323\u001b[0m     func: ObjectiveFuncType,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    330\u001b[0m     show_progress_bar: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m    331\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    332\u001b[0m     \u001b[38;5;124;03m\"\"\"Optimize an objective function.\u001b[39;00m\n\u001b[1;32m    333\u001b[0m \n\u001b[1;32m    334\u001b[0m \u001b[38;5;124;03m    Optimization is done by choosing a suitable set of hyperparameter values from a given\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    422\u001b[0m \u001b[38;5;124;03m            If nested invocation of this method occurs.\u001b[39;00m\n\u001b[1;32m    423\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[0;32m--> 425\u001b[0m     \u001b[43m_optimize\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    426\u001b[0m \u001b[43m        \u001b[49m\u001b[43mstudy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m    427\u001b[0m \u001b[43m        \u001b[49m\u001b[43mfunc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    428\u001b[0m \u001b[43m        \u001b[49m\u001b[43mn_trials\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_trials\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    429\u001b[0m \u001b[43m        
\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    430\u001b[0m \u001b[43m        \u001b[49m\u001b[43mn_jobs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_jobs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    431\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mtuple\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcatch\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43misinstance\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mIterable\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mcatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    432\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallbacks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    433\u001b[0m \u001b[43m        \u001b[49m\u001b[43mgc_after_trial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgc_after_trial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    434\u001b[0m \u001b[43m        \u001b[49m\u001b[43mshow_progress_bar\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mshow_progress_bar\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    435\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.9/site-packages/optuna/study/_optimize.py:66\u001b[0m, in \u001b[0;36m_optimize\u001b[0;34m(study, func, n_trials, timeout, n_jobs, catch, callbacks, gc_after_trial, show_progress_bar)\u001b[0m\n\u001b[1;32m     64\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m     65\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m n_jobs \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[0;32m---> 66\u001b[0m         \u001b[43m_optimize_sequential\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     67\u001b[0m \u001b[43m            \u001b[49m\u001b[43mstudy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     68\u001b[0m \u001b[43m            \u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     69\u001b[0m \u001b[43m            \u001b[49m\u001b[43mn_trials\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     70\u001b[0m \u001b[43m            \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     71\u001b[0m \u001b[43m            \u001b[49m\u001b[43mcatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     72\u001b[0m \u001b[43m            \u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     73\u001b[0m \u001b[43m            \u001b[49m\u001b[43mgc_after_trial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     74\u001b[0m \u001b[43m            \u001b[49m\u001b[43mreseed_sampler_rng\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m     75\u001b[0m \u001b[43m            \u001b[49m\u001b[43mtime_start\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m     76\u001b[0m \u001b[43m            \u001b[49m\u001b[43mprogress_bar\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprogress_bar\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     77\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     78\u001b[0m     
\u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m     79\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m n_jobs \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m:\n",
      "File \u001b[0;32m/usr/local/lib/python3.9/site-packages/optuna/study/_optimize.py:163\u001b[0m, in \u001b[0;36m_optimize_sequential\u001b[0;34m(study, func, n_trials, timeout, catch, callbacks, gc_after_trial, reseed_sampler_rng, time_start, progress_bar)\u001b[0m\n\u001b[1;32m    160\u001b[0m         \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[1;32m    162\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 163\u001b[0m     frozen_trial \u001b[38;5;241m=\u001b[39m \u001b[43m_run_trial\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstudy\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    164\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m    165\u001b[0m     \u001b[38;5;66;03m# The following line mitigates memory problems that can be occurred in some\u001b[39;00m\n\u001b[1;32m    166\u001b[0m     \u001b[38;5;66;03m# environments (e.g., services that use computing containers such as GitHub Actions).\u001b[39;00m\n\u001b[1;32m    167\u001b[0m     \u001b[38;5;66;03m# Please refer to the following PR for further details:\u001b[39;00m\n\u001b[1;32m    168\u001b[0m     \u001b[38;5;66;03m# https://github.com/optuna/optuna/pull/325.\u001b[39;00m\n\u001b[1;32m    169\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m gc_after_trial:\n",
      "File \u001b[0;32m/usr/local/lib/python3.9/site-packages/optuna/study/_optimize.py:251\u001b[0m, in \u001b[0;36m_run_trial\u001b[0;34m(study, func, catch)\u001b[0m\n\u001b[1;32m    244\u001b[0m         \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mShould not reach.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    246\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m    247\u001b[0m     frozen_trial\u001b[38;5;241m.\u001b[39mstate \u001b[38;5;241m==\u001b[39m TrialState\u001b[38;5;241m.\u001b[39mFAIL\n\u001b[1;32m    248\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m func_err \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m    249\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(func_err, catch)\n\u001b[1;32m    250\u001b[0m ):\n\u001b[0;32m--> 251\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m func_err\n\u001b[1;32m    252\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m frozen_trial\n",
      "File \u001b[0;32m/usr/local/lib/python3.9/site-packages/optuna/study/_optimize.py:200\u001b[0m, in \u001b[0;36m_run_trial\u001b[0;34m(study, func, catch)\u001b[0m\n\u001b[1;32m    198\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m get_heartbeat_thread(trial\u001b[38;5;241m.\u001b[39m_trial_id, study\u001b[38;5;241m.\u001b[39m_storage):\n\u001b[1;32m    199\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 200\u001b[0m         value_or_values \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    201\u001b[0m     \u001b[38;5;28;01mexcept\u001b[39;00m exceptions\u001b[38;5;241m.\u001b[39mTrialPruned \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m    202\u001b[0m         \u001b[38;5;66;03m# TODO(mamu): Handle multi-objective cases.\u001b[39;00m\n\u001b[1;32m    203\u001b[0m         state \u001b[38;5;241m=\u001b[39m TrialState\u001b[38;5;241m.\u001b[39mPRUNED\n",
      "Input \u001b[0;32mIn [20]\u001b[0m, in \u001b[0;36mtest_bins_objective\u001b[0;34m(trial)\u001b[0m\n\u001b[1;32m     11\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m trial\u001b[38;5;241m.\u001b[39mshould_prune():\n\u001b[1;32m     12\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m optuna\u001b[38;5;241m.\u001b[39mTrialPruned()\n\u001b[0;32m---> 14\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m math\u001b[38;5;241m.\u001b[39msqrt(\u001b[43mtrain_bin_fm\u001b[49m\u001b[43m(\u001b[49m\u001b[43membedding_dim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstep_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_bins\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbin_strategy\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     15\u001b[0m \u001b[43m                              \u001b[49m\u001b[43mcallback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallback\u001b[49m\u001b[43m)\u001b[49m)\n",
      "Input \u001b[0;32mIn [19]\u001b[0m, in \u001b[0;36mtrain_bin_fm\u001b[0;34m(embedding_dim, step_size, batch_size, num_bins, bin_strategy, num_epochs, callback)\u001b[0m\n\u001b[1;32m      8\u001b[0m discretizer \u001b[38;5;241m=\u001b[39m KBinsDiscretizer(num_bins, encode\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mordinal\u001b[39m\u001b[38;5;124m'\u001b[39m, strategy\u001b[38;5;241m=\u001b[39mbin_strategy, random_state\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m42\u001b[39m)\n\u001b[1;32m      9\u001b[0m discretizer\u001b[38;5;241m.\u001b[39mfit(tr_feats)\n\u001b[0;32m---> 11\u001b[0m tr_indices \u001b[38;5;241m=\u001b[39m \u001b[43mdiscretizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtr_feats\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     12\u001b[0m tr_indices \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mtile(index_offsets, (tr_indices\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;241m1\u001b[39m))\n\u001b[1;32m     13\u001b[0m tr_weights \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mones_like(tr_indices)\n",
      "File \u001b[0;32m/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:375\u001b[0m, in \u001b[0;36mKBinsDiscretizer.transform\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m    373\u001b[0m bin_edges \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbin_edges_\n\u001b[1;32m    374\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m jj \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(Xt\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m]):\n\u001b[0;32m--> 375\u001b[0m     Xt[:, jj] \u001b[38;5;241m=\u001b[39m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msearchsorted\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbin_edges\u001b[49m\u001b[43m[\u001b[49m\u001b[43mjj\u001b[49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m:\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mXt\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjj\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mside\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mright\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m    377\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mencode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mordinal\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m    378\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m Xt\n",
      "File \u001b[0;32m<__array_function__ internals>:180\u001b[0m, in \u001b[0;36msearchsorted\u001b[0;34m(*args, **kwargs)\u001b[0m\n",
      "File \u001b[0;32m/usr/local/lib64/python3.9/site-packages/numpy/core/fromnumeric.py:1387\u001b[0m, in \u001b[0;36msearchsorted\u001b[0;34m(a, v, side, sorter)\u001b[0m\n\u001b[1;32m   1319\u001b[0m \u001b[38;5;129m@array_function_dispatch\u001b[39m(_searchsorted_dispatcher)\n\u001b[1;32m   1320\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msearchsorted\u001b[39m(a, v, side\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mleft\u001b[39m\u001b[38;5;124m'\u001b[39m, sorter\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m   1321\u001b[0m     \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m   1322\u001b[0m \u001b[38;5;124;03m    Find indices where elements should be inserted to maintain order.\u001b[39;00m\n\u001b[1;32m   1323\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1385\u001b[0m \n\u001b[1;32m   1386\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1387\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_wrapfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43msearchsorted\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mside\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mside\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msorter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msorter\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib64/python3.9/site-packages/numpy/core/fromnumeric.py:57\u001b[0m, in \u001b[0;36m_wrapfunc\u001b[0;34m(obj, method, *args, **kwds)\u001b[0m\n\u001b[1;32m     54\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m _wrapit(obj, method, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds)\n\u001b[1;32m     56\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 57\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mbound\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     58\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m     59\u001b[0m     \u001b[38;5;66;03m# A TypeError occurs if the object does have such a method in its\u001b[39;00m\n\u001b[1;32m     60\u001b[0m     \u001b[38;5;66;03m# class, but its signature is not identical to that of NumPy's. This\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     64\u001b[0m     \u001b[38;5;66;03m# Call _wrapit from within the except clause to ensure a potential\u001b[39;00m\n\u001b[1;32m     65\u001b[0m     \u001b[38;5;66;03m# exception has a traceback chain.\u001b[39;00m\n\u001b[1;32m     66\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m _wrapit(obj, method, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Hyperparameter search driven by test_bins_objective: its return value\n",
    "# (the square root of train_bin_fm's result) is minimised for up to 100 trials.\n",
    "# The TPE sampler is created with a fixed seed (42).\n",
    "study_bins = optuna.create_study(\n",
    "    sampler=optuna.samplers.TPESampler(seed=42),\n",
    "    direction='minimize',\n",
    "    study_name='bins',\n",
    ")\n",
    "study_bins.optimize(test_bins_objective, n_trials=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Display the hyperparameters of the best trial recorded in study_bins.\n",
    "study_bins.best_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Report the best trial of study_bins: its objective value, then its hyperparameters.\n",
    "trial = study_bins.best_trial\n",
    "\n",
    "print(\n",
    "    'Test loss: {}'.format(trial.value),\n",
    "    \"Best hyperparameters: {}\".format(trial.params),\n",
    "    sep='\\n',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "is_executing": true,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Train once with the best hyperparameters; the cell displays train_bin_fm's raw return value\n",
    "# (the study objective and the 20-run loop in this notebook take its square root instead).\n",
    "# NOTE(review): assumes the keys of best_params match train_bin_fm's parameter names - confirm.\n",
    "train_bin_fm(**study_bins.best_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train 20 times with the best hyperparameters from study_bins and collect\n",
    "# the square root of each value returned by train_bin_fm.\n",
    "bin_losses = []\n",
    "for i in trange(20):\n",
    "    bin_losses.append(math.sqrt(train_bin_fm(**study_bins.best_params)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the per-run values collected in bin_losses.\n",
    "bin_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and standard deviation of the values in bin_losses.\n",
    "# np.std uses its default ddof=0 here, i.e. the population standard deviation.\n",
    "np.mean(bin_losses), np.std(bin_losses)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
