In [1]:
import sys

sys.path.insert(0, "../utils")
In [24]:
import sklearn.datasets as skds
from sklearn.preprocessing import QuantileTransformer, KBinsDiscretizer, StandardScaler
from sklearn.model_selection import train_test_split
from weighted_fm import WeightedFFM, WeightedFM
from trainers import FFMTrainer
from transformation import BSplineTransformer, spline_transform_dataset
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
import math
import optuna
import optuna.samplers
from typing import Callable
from tqdm import tqdm, trange
In [3]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
cuda:0
In [4]:
torch.manual_seed(42)
np.random.seed(42)
In [5]:
import tarfile
import joblib
In [6]:
with tarfile.open(mode="r:gz", name="../data/cal_housing.tgz") as f:
    cal_housing = np.loadtxt(
        f.extractfile("CaliforniaHousing/cal_housing.data"), delimiter=","
    )
    # Columns are not in the same order compared to the previous
    # URL resource on lib.stat.cmu.edu
    columns_index = [8, 7, 2, 3, 4, 5, 6, 1, 0]
    cal_housing = cal_housing[:, columns_index]

    joblib.dump(cal_housing, "../data/cal_housing_py3.pkz", compress=6)
In [7]:
ds = skds.fetch_california_housing(data_home="../data")
In [8]:
X_train, X_test, y_train, y_test = train_test_split(ds['data'], ds['target'], test_size=0.2, random_state=42)
In [9]:
target_scaler = StandardScaler()
y_train = target_scaler.fit_transform(y_train.reshape(-1, 1)).reshape(-1)
y_test = target_scaler.transform(y_test.reshape(-1, 1)).reshape(-1)
In [10]:
quant_transform = QuantileTransformer(output_distribution='uniform',
                                      n_quantiles=10000,
                                      subsample=len(X_train),
                                      random_state=42)
X_train_qs = quant_transform.fit_transform(X_train)
X_test_qs = quant_transform.transform(X_test)
In [28]:
def train_spline_ffm(embedding_dim: int, step_size: float, batch_size: int, num_knots: int, num_epochs: int,
                     callback: Callable[[int, float], None]=None):
    bs = BSplineTransformer(num_knots, 3)
    train_indices, train_weights, train_offsets, train_fields = spline_transform_dataset(X_train_qs, bs)
    test_indices, test_weights, test_offsets, test_fields = spline_transform_dataset(X_test_qs, bs)
    num_fields = ds['data'].shape[1]
    num_embeddings = bs.basis_size() * num_fields

    train_ds = TensorDataset(
        torch.tensor(train_indices, dtype=torch.int64),
        torch.tensor(train_weights, dtype=torch.float32),
        torch.tensor(train_offsets, dtype=torch.int64),
        torch.tensor(train_fields, dtype=torch.int64),
        torch.tensor(y_train, dtype=torch.float32))
    test_ds = TensorDataset(
        torch.tensor(test_indices, dtype=torch.int64),
        torch.tensor(test_weights, dtype=torch.float32),
        torch.tensor(test_offsets, dtype=torch.int64),
        torch.tensor(test_fields, dtype=torch.int64),
        torch.tensor(y_test, dtype=torch.float32))

    criterion = torch.nn.MSELoss()
    trainer = FFMTrainer(embedding_dim, step_size, batch_size, num_epochs, callback)
    return trainer.train(num_fields, num_embeddings, train_ds, test_ds, criterion, device)
In [12]:
def train_spline_objective(trial: optuna.Trial):
    embedding_dim = trial.suggest_int('embedding_dim', 1, 10)
    step_size = trial.suggest_float('step_size', 1e-2, 0.5, log=True)
    batch_size = trial.suggest_int('batch_size', 2, 32)
    num_knots = trial.suggest_int('num_knots', 3, 48)
    num_epochs = trial.suggest_int('num_epochs', 5, 15)

    def callback(epoch: int, loss: float):
        rmse = math.sqrt(loss)
        trial.report(rmse, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

    mse = train_spline_ffm(embedding_dim, step_size, batch_size, num_knots, num_epochs,
                           callback=callback)
    return math.sqrt(mse)
In [13]:
study = optuna.create_study(study_name='splines',
                            direction='minimize',
                            sampler=optuna.samplers.TPESampler(seed=42))
study.optimize(train_spline_objective, n_trials=100)
[I 2023-05-16 19:06:04,290] A new study created in memory with name: splines
[I 2023-05-16 19:07:06,616] Trial 0 finished with value: 0.4704895326349986 and parameters: {'embedding_dim': 4, 'step_size': 0.4123206532618726, 'batch_size': 24, 'num_knots': 30, 'num_epochs': 6}. Best is trial 0 with value: 0.4704895326349986.
[I 2023-05-16 19:08:50,430] Trial 1 finished with value: 0.5324508054601986 and parameters: {'embedding_dim': 2, 'step_size': 0.012551115172973842, 'batch_size': 28, 'num_knots': 30, 'num_epochs': 12}. Best is trial 0 with value: 0.4704895326349986.
[I 2023-05-16 19:09:57,194] Trial 2 finished with value: 0.45800390831558263 and parameters: {'embedding_dim': 1, 'step_size': 0.44447541666908114, 'batch_size': 27, 'num_knots': 12, 'num_epochs': 7}. Best is trial 2 with value: 0.45800390831558263.
[I 2023-05-16 19:11:51,841] Trial 3 finished with value: 0.46901124985568104 and parameters: {'embedding_dim': 2, 'step_size': 0.0328774741399112, 'batch_size': 18, 'num_knots': 22, 'num_epochs': 8}. Best is trial 2 with value: 0.45800390831558263.
[I 2023-05-16 19:15:35,305] Trial 4 finished with value: 0.4475541496014591 and parameters: {'embedding_dim': 7, 'step_size': 0.017258215396625, 'batch_size': 11, 'num_knots': 19, 'num_epochs': 10}. Best is trial 4 with value: 0.4475541496014591.
[I 2023-05-16 19:16:53,720] Trial 5 finished with value: 0.45202152127010237 and parameters: {'embedding_dim': 8, 'step_size': 0.021839352923182977, 'batch_size': 17, 'num_knots': 30, 'num_epochs': 5}. Best is trial 4 with value: 0.4475541496014591.
[I 2023-05-16 19:20:13,198] Trial 6 pruned. 
[I 2023-05-16 19:26:56,554] Trial 7 finished with value: 0.4374097662924263 and parameters: {'embedding_dim': 9, 'step_size': 0.032925293631105246, 'batch_size': 5, 'num_knots': 34, 'num_epochs': 9}. Best is trial 7 with value: 0.4374097662924263.
[I 2023-05-16 19:27:54,756] Trial 8 pruned. 
[I 2023-05-16 19:29:23,192] Trial 9 finished with value: 0.4333838736728957 and parameters: {'embedding_dim': 7, 'step_size': 0.033852267834519785, 'batch_size': 18, 'num_knots': 28, 'num_epochs': 7}. Best is trial 9 with value: 0.4333838736728957.
[I 2023-05-16 19:29:37,723] Trial 10 pruned. 
[I 2023-05-16 19:32:24,187] Trial 11 finished with value: 0.4360523626300907 and parameters: {'embedding_dim': 10, 'step_size': 0.04337690983089577, 'batch_size': 12, 'num_knots': 37, 'num_epochs': 9}. Best is trial 9 with value: 0.4333838736728957.
[I 2023-05-16 19:34:57,499] Trial 12 pruned. 
[I 2023-05-16 19:35:09,175] Trial 13 pruned. 
[I 2023-05-16 19:37:00,063] Trial 14 finished with value: 0.44198953018367365 and parameters: {'embedding_dim': 5, 'step_size': 0.04195290839392635, 'batch_size': 10, 'num_knots': 15, 'num_epochs': 5}. Best is trial 9 with value: 0.4333838736728957.
[I 2023-05-16 19:40:59,893] Trial 15 finished with value: 0.44867919363263115 and parameters: {'embedding_dim': 8, 'step_size': 0.07625186899625158, 'batch_size': 13, 'num_knots': 25, 'num_epochs': 15}. Best is trial 9 with value: 0.4333838736728957.
[I 2023-05-16 19:44:05,914] Trial 16 pruned. 
[I 2023-05-16 19:44:19,552] Trial 17 pruned. 
[I 2023-05-16 19:44:29,726] Trial 18 pruned. 
[I 2023-05-16 19:46:10,377] Trial 19 pruned. 
[I 2023-05-16 19:50:45,295] Trial 20 finished with value: 0.4304292518785922 and parameters: {'embedding_dim': 7, 'step_size': 0.04499014800396729, 'batch_size': 7, 'num_knots': 25, 'num_epochs': 9}. Best is trial 20 with value: 0.4304292518785922.
[I 2023-05-16 19:55:19,993] Trial 21 finished with value: 0.4317104576548742 and parameters: {'embedding_dim': 7, 'step_size': 0.04748562470151251, 'batch_size': 7, 'num_knots': 25, 'num_epochs': 9}. Best is trial 20 with value: 0.4304292518785922.
[I 2023-05-16 19:59:56,006] Trial 22 finished with value: 0.43209864623593314 and parameters: {'embedding_dim': 7, 'step_size': 0.05467652110187646, 'batch_size': 7, 'num_knots': 25, 'num_epochs': 8}. Best is trial 20 with value: 0.4304292518785922.
[I 2023-05-16 20:05:31,702] Trial 23 finished with value: 0.43518819708811474 and parameters: {'embedding_dim': 5, 'step_size': 0.061380216632269086, 'batch_size': 7, 'num_knots': 23, 'num_epochs': 10}. Best is trial 20 with value: 0.4304292518785922.
[I 2023-05-16 20:21:14,196] Trial 24 finished with value: 0.4282563738885848 and parameters: {'embedding_dim': 8, 'step_size': 0.050729607059104745, 'batch_size': 2, 'num_knots': 18, 'num_epochs': 8}. Best is trial 24 with value: 0.4282563738885848.
[I 2023-05-16 20:22:23,305] Trial 25 pruned. 
[I 2023-05-16 20:28:34,067] Trial 26 pruned. 
[I 2023-05-16 20:32:23,398] Trial 27 finished with value: 0.4342443722594684 and parameters: {'embedding_dim': 8, 'step_size': 0.04633322962813203, 'batch_size': 6, 'num_knots': 20, 'num_epochs': 6}. Best is trial 24 with value: 0.4282563738885848.
[I 2023-05-16 20:36:20,024] Trial 28 finished with value: 0.4303967258133841 and parameters: {'embedding_dim': 4, 'step_size': 0.10498851807328576, 'batch_size': 9, 'num_knots': 17, 'num_epochs': 8}. Best is trial 24 with value: 0.4282563738885848.
[I 2023-05-16 20:36:51,318] Trial 29 pruned. 
[I 2023-05-16 20:54:01,502] Trial 30 finished with value: 0.432384126802336 and parameters: {'embedding_dim': 3, 'step_size': 0.2224714974302765, 'batch_size': 2, 'num_knots': 17, 'num_epochs': 8}. Best is trial 24 with value: 0.4282563738885848.
[I 2023-05-16 20:57:04,089] Trial 31 pruned. 
[I 2023-05-16 20:57:58,437] Trial 32 pruned. 
[I 2023-05-16 21:03:33,448] Trial 33 finished with value: 0.4318301617578173 and parameters: {'embedding_dim': 6, 'step_size': 0.05610785831672759, 'batch_size': 6, 'num_knots': 28, 'num_epochs': 8}. Best is trial 24 with value: 0.4282563738885848.
[I 2023-05-16 21:08:11,614] Trial 34 finished with value: 0.4385583881926396 and parameters: {'embedding_dim': 7, 'step_size': 0.10397008430678391, 'batch_size': 8, 'num_knots': 18, 'num_epochs': 10}. Best is trial 24 with value: 0.4282563738885848.
[I 2023-05-16 21:08:32,897] Trial 35 pruned. 
[I 2023-05-16 21:09:17,581] Trial 36 pruned. 
[I 2023-05-16 21:10:13,339] Trial 37 pruned. 
[I 2023-05-16 21:10:28,002] Trial 38 pruned. 
[I 2023-05-16 21:13:17,265] Trial 39 finished with value: 0.4332768080903596 and parameters: {'embedding_dim': 7, 'step_size': 0.049927006498105705, 'batch_size': 10, 'num_knots': 31, 'num_epochs': 8}. Best is trial 24 with value: 0.4282563738885848.
[I 2023-05-16 21:13:27,097] Trial 40 pruned. 
[I 2023-05-16 21:14:04,662] Trial 41 pruned. 
[I 2023-05-16 21:14:37,254] Trial 42 pruned. 
[I 2023-05-16 21:22:56,495] Trial 43 finished with value: 0.4355067418514952 and parameters: {'embedding_dim': 7, 'step_size': 0.06736094208217702, 'batch_size': 4, 'num_knots': 31, 'num_epochs': 9}. Best is trial 24 with value: 0.4282563738885848.
[I 2023-05-16 21:25:10,548] Trial 44 pruned. 
[I 2023-05-16 21:34:48,334] Trial 45 finished with value: 0.428379686852566 and parameters: {'embedding_dim': 6, 'step_size': 0.05663408617479354, 'batch_size': 4, 'num_knots': 26, 'num_epochs': 9}. Best is trial 24 with value: 0.4282563738885848.
[I 2023-05-16 21:35:53,613] Trial 46 pruned. 
[I 2023-05-16 21:37:25,006] Trial 47 finished with value: 0.4282606362493427 and parameters: {'embedding_dim': 7, 'step_size': 0.07608335371533785, 'batch_size': 28, 'num_knots': 26, 'num_epochs': 9}. Best is trial 24 with value: 0.4282563738885848.
[I 2023-05-16 21:37:35,335] Trial 48 pruned. 
[I 2023-05-16 21:38:55,705] Trial 49 finished with value: 0.4345063751848249 and parameters: {'embedding_dim': 8, 'step_size': 0.0681880972757457, 'batch_size': 31, 'num_knots': 23, 'num_epochs': 9}. Best is trial 24 with value: 0.4282563738885848.
[I 2023-05-16 21:41:42,130] Trial 50 finished with value: 0.44851676160691284 and parameters: {'embedding_dim': 7, 'step_size': 0.09456589098657536, 'batch_size': 20, 'num_knots': 26, 'num_epochs': 12}. Best is trial 24 with value: 0.4282563738885848.
[I 2023-05-16 21:41:53,916] Trial 51 pruned. 
[I 2023-05-16 21:45:02,791] Trial 52 finished with value: 0.42911772664965314 and parameters: {'embedding_dim': 6, 'step_size': 0.0598936840450053, 'batch_size': 13, 'num_knots': 21, 'num_epochs': 10}. Best is trial 24 with value: 0.4282563738885848.
[I 2023-05-16 21:45:19,028] Trial 53 pruned. 
[I 2023-05-16 21:45:39,786] Trial 54 pruned. 
[I 2023-05-16 21:45:56,382] Trial 55 pruned. 
[I 2023-05-16 21:46:08,651] Trial 56 pruned. 
[I 2023-05-16 21:46:34,431] Trial 57 pruned. 
[I 2023-05-16 21:46:42,515] Trial 58 pruned. 
[I 2023-05-16 21:50:12,016] Trial 59 finished with value: 0.4367239404141295 and parameters: {'embedding_dim': 8, 'step_size': 0.10688913438019201, 'batch_size': 12, 'num_knots': 15, 'num_epochs': 11}. Best is trial 24 with value: 0.4282563738885848.
[I 2023-05-16 21:51:25,554] Trial 60 pruned. 
[I 2023-05-16 21:56:17,668] Trial 61 finished with value: 0.4293793353510158 and parameters: {'embedding_dim': 7, 'step_size': 0.052279480161766326, 'batch_size': 7, 'num_knots': 25, 'num_epochs': 9}. Best is trial 24 with value: 0.4282563738885848.
[I 2023-05-16 22:00:51,493] Trial 62 finished with value: 0.42994092923668487 and parameters: {'embedding_dim': 7, 'step_size': 0.05748379333415699, 'batch_size': 8, 'num_knots': 21, 'num_epochs': 10}. Best is trial 24 with value: 0.4282563738885848.
[I 2023-05-16 22:05:34,761] Trial 63 finished with value: 0.42706043678253136 and parameters: {'embedding_dim': 7, 'step_size': 0.05868656247354761, 'batch_size': 8, 'num_knots': 22, 'num_epochs': 10}. Best is trial 63 with value: 0.42706043678253136.
[I 2023-05-16 22:10:31,959] Trial 64 finished with value: 0.43217209422359965 and parameters: {'embedding_dim': 7, 'step_size': 0.05987477205374452, 'batch_size': 8, 'num_knots': 21, 'num_epochs': 11}. Best is trial 63 with value: 0.42706043678253136.
[I 2023-05-16 22:17:56,158] Trial 65 finished with value: 0.4350829454793019 and parameters: {'embedding_dim': 8, 'step_size': 0.05256357897612862, 'batch_size': 5, 'num_knots': 27, 'num_epochs': 10}. Best is trial 63 with value: 0.42706043678253136.
[I 2023-05-16 22:30:35,599] Trial 66 finished with value: 0.434589308419973 and parameters: {'embedding_dim': 7, 'step_size': 0.07823068649760494, 'batch_size': 3, 'num_knots': 22, 'num_epochs': 10}. Best is trial 63 with value: 0.42706043678253136.
[I 2023-05-16 22:31:27,696] Trial 67 pruned. 
[I 2023-05-16 22:31:42,320] Trial 68 pruned. 
[I 2023-05-16 22:36:04,117] Trial 69 finished with value: 0.42859933221611907 and parameters: {'embedding_dim': 6, 'step_size': 0.055681143765690076, 'batch_size': 10, 'num_knots': 24, 'num_epochs': 11}. Best is trial 63 with value: 0.42706043678253136.
[I 2023-05-16 22:40:13,326] Trial 70 finished with value: 0.4299545152028061 and parameters: {'embedding_dim': 6, 'step_size': 0.05054589162131087, 'batch_size': 10, 'num_knots': 24, 'num_epochs': 11}. Best is trial 63 with value: 0.42706043678253136.
[I 2023-05-16 22:40:31,967] Trial 71 pruned. 
[I 2023-05-16 22:45:04,936] Trial 72 finished with value: 0.4300412888997274 and parameters: {'embedding_dim': 6, 'step_size': 0.06533761080426782, 'batch_size': 8, 'num_knots': 22, 'num_epochs': 10}. Best is trial 63 with value: 0.42706043678253136.
[I 2023-05-16 22:45:38,758] Trial 73 pruned. 
[I 2023-05-16 22:48:56,131] Trial 74 finished with value: 0.44036213426531196 and parameters: {'embedding_dim': 7, 'step_size': 0.08362911227618235, 'batch_size': 10, 'num_knots': 25, 'num_epochs': 9}. Best is trial 63 with value: 0.42706043678253136.
[I 2023-05-16 22:49:33,812] Trial 75 pruned. 
[I 2023-05-16 22:53:03,111] Trial 76 finished with value: 0.43593629611269524 and parameters: {'embedding_dim': 9, 'step_size': 0.07164141513191531, 'batch_size': 11, 'num_knots': 18, 'num_epochs': 10}. Best is trial 63 with value: 0.42706043678253136.
[I 2023-05-16 22:53:31,764] Trial 77 pruned. 
[I 2023-05-16 22:56:48,594] Trial 78 finished with value: 0.42568818221361704 and parameters: {'embedding_dim': 7, 'step_size': 0.052749809167358816, 'batch_size': 14, 'num_knots': 21, 'num_epochs': 12}. Best is trial 78 with value: 0.42568818221361704.
[I 2023-05-16 22:57:05,323] Trial 79 pruned. 
[I 2023-05-16 22:57:20,072] Trial 80 pruned. 
[I 2023-05-16 22:57:38,013] Trial 81 pruned. 
[I 2023-05-16 22:58:10,526] Trial 82 pruned. 
[I 2023-05-16 22:59:42,180] Trial 83 pruned. 
[I 2023-05-16 23:01:19,935] Trial 84 pruned. 
[I 2023-05-16 23:01:41,489] Trial 85 pruned. 
[I 2023-05-16 23:01:50,579] Trial 86 pruned. 
[I 2023-05-16 23:05:15,605] Trial 87 finished with value: 0.4339707428227233 and parameters: {'embedding_dim': 8, 'step_size': 0.06642377759575374, 'batch_size': 12, 'num_knots': 24, 'num_epochs': 11}. Best is trial 78 with value: 0.42568818221361704.
[I 2023-05-16 23:08:33,489] Trial 88 finished with value: 0.42828476563061135 and parameters: {'embedding_dim': 6, 'step_size': 0.0782610810110209, 'batch_size': 9, 'num_knots': 26, 'num_epochs': 8}. Best is trial 78 with value: 0.42568818221361704.
[I 2023-05-16 23:09:04,670] Trial 89 pruned. 
[I 2023-05-16 23:16:16,226] Trial 90 finished with value: 0.4314792037783839 and parameters: {'embedding_dim': 6, 'step_size': 0.07535244582825283, 'batch_size': 4, 'num_knots': 26, 'num_epochs': 8}. Best is trial 78 with value: 0.42568818221361704.
[I 2023-05-16 23:16:40,833] Trial 91 pruned. 
[I 2023-05-16 23:17:12,669] Trial 92 pruned. 
[I 2023-05-16 23:20:33,451] Trial 93 finished with value: 0.4322334637200849 and parameters: {'embedding_dim': 7, 'step_size': 0.06619156418718045, 'batch_size': 11, 'num_knots': 23, 'num_epochs': 10}. Best is trial 78 with value: 0.42568818221361704.
[I 2023-05-16 23:20:58,149] Trial 94 pruned. 
[I 2023-05-16 23:22:56,512] Trial 95 finished with value: 0.4317012761514909 and parameters: {'embedding_dim': 6, 'step_size': 0.07919381789878355, 'batch_size': 15, 'num_knots': 21, 'num_epochs': 8}. Best is trial 78 with value: 0.42568818221361704.
[I 2023-05-16 23:24:09,483] Trial 96 pruned. 
[I 2023-05-16 23:25:31,914] Trial 97 pruned. 
[I 2023-05-16 23:26:07,732] Trial 98 pruned. 
[I 2023-05-16 23:33:25,357] Trial 99 finished with value: 0.43331236773293147 and parameters: {'embedding_dim': 8, 'step_size': 0.06415431672405765, 'batch_size': 5, 'num_knots': 25, 'num_epochs': 10}. Best is trial 78 with value: 0.42568818221361704.
In [14]:
trial = study.best_trial

print('Test loss: {}'.format(trial.value))
print("Best hyperparameters: {}".format(trial.params))
Test loss: 0.42568818221361704
Best hyperparameters: {'embedding_dim': 7, 'step_size': 0.052749809167358816, 'batch_size': 14, 'num_knots': 21, 'num_epochs': 12}
In [15]:
study.best_params
Out[15]:
{'embedding_dim': 7,
 'step_size': 0.052749809167358816,
 'batch_size': 14,
 'num_knots': 21,
 'num_epochs': 12}
In [29]:
spline_losses = []
for i in trange(20):
    mse = train_spline_ffm(**study.best_params)
    spline_losses.append(math.sqrt(mse))
100%|██████████| 20/20 [1:01:29<00:00, 184.47s/it]
In [34]:
spline_losses
Out[34]:
[0.42679040640311267,
 0.42634108249853464,
 0.4261458174783907,
 0.43225683700092105,
 0.431104387357407,
 0.4288870899249733,
 0.4277180881970324,
 0.4311970464012093,
 0.4338799640302308,
 0.4241321194256667,
 0.42784070300256505,
 0.43046144662743624,
 0.4316383465322712,
 0.42870344707227903,
 0.4286306736221006,
 0.4301517752661648,
 0.4326343059382102,
 0.428075541124546,
 0.42944209286761614,
 0.4320622969331268]
In [39]:
np.mean(spline_losses), 3 * np.std(spline_losses), np.mean(spline_losses) + 3 * np.std(spline_losses), np.mean(spline_losses) - 3 * np.std(spline_losses)
Out[39]:
(0.42940467338518973,
 0.007393634011334159,
 0.43679830739652387,
 0.4220110393738556)
In [30]:
def train_bin_ffm(embedding_dim: int, step_size: float, batch_size: int,
                  num_bins: int, bin_strategy: str, num_epochs: int,
                  callback: Callable[[int, float], None]=None):
    num_fields = X_train.shape[1]
    offsets = np.arange(0, num_fields) * num_bins

    discretizer = KBinsDiscretizer(num_bins, encode='ordinal', strategy=bin_strategy, random_state=42)
    discretizer.fit(X_train)

    indices_train = discretizer.transform(X_train)
    indices_train += np.tile(offsets, (X_train.shape[0], 1))
    weights_train = np.ones_like(indices_train)
    fields_train = np.tile(np.arange(0, num_fields), (X_train.shape[0], 1))

    indices_test = discretizer.transform(X_test)
    indices_test += np.tile(offsets, (X_test.shape[0], 1))
    weights_test = np.ones_like(indices_test)
    fields_test = np.tile(np.arange(0, num_fields), (X_test.shape[0], 1))


    num_embeddings = num_fields * num_bins
    train_ds = TensorDataset(
        torch.tensor(indices_train, dtype=torch.int64),
        torch.tensor(weights_train, dtype=torch.float32),
        torch.tensor(fields_train, dtype=torch.int64),
        torch.tensor(fields_train, dtype=torch.int64),
        torch.tensor(y_train, dtype=torch.float32))

    test_ds = TensorDataset(
        torch.tensor(indices_test, dtype=torch.int64),
        torch.tensor(weights_test, dtype=torch.float32),
        torch.tensor(fields_test, dtype=torch.int64),
        torch.tensor(fields_test, dtype=torch.int64),
        torch.tensor(y_test, dtype=torch.float32))

    trainer = FFMTrainer(embedding_dim, step_size, batch_size, num_epochs, callback)
    return trainer.train(num_fields, num_embeddings, train_ds, test_ds, torch.nn.MSELoss(), device)
In [18]:
def test_bins_objective(trial: optuna.Trial):
    embedding_dim = trial.suggest_int('embedding_dim', 1, 10)
    step_size = trial.suggest_float('step_size', 1e-2, 0.5, log=True)
    batch_size = trial.suggest_int('batch_size', 2, 32)
    num_bins = trial.suggest_int('num_bins', 2, 100)
    bin_strategy = trial.suggest_categorical('bin_strategy', ['uniform', 'quantile'])
    num_epochs = trial.suggest_int('num_epochs', 5, 15)

    def callback(epoch: int, mse: float):
        rmse = math.sqrt(mse)
        trial.report(rmse, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

    mse = train_bin_ffm(embedding_dim, step_size, batch_size, num_bins, bin_strategy, num_epochs,
                        callback=callback)
    return math.sqrt(mse)
In [19]:
study_bins = optuna.create_study(study_name='bins',
                            direction='minimize',
                            sampler=optuna.samplers.TPESampler(seed=42))
study_bins.optimize(test_bins_objective, n_trials=100)
[I 2023-05-16 23:36:36,044] A new study created in memory with name: bins
[I 2023-05-16 23:37:23,402] Trial 0 finished with value: 0.5280771877157359 and parameters: {'embedding_dim': 4, 'step_size': 0.4123206532618726, 'batch_size': 24, 'num_bins': 61, 'bin_strategy': 'uniform', 'num_epochs': 5}. Best is trial 0 with value: 0.5280771877157359.
[I 2023-05-16 23:38:21,498] Trial 1 finished with value: 0.7002097411644254 and parameters: {'embedding_dim': 9, 'step_size': 0.10502105436744279, 'batch_size': 23, 'num_bins': 4, 'bin_strategy': 'uniform', 'num_epochs': 7}. Best is trial 0 with value: 0.5280771877157359.
[I 2023-05-16 23:42:02,363] Trial 2 finished with value: 0.5551966895575728 and parameters: {'embedding_dim': 2, 'step_size': 0.020492680115417352, 'batch_size': 11, 'num_bins': 53, 'bin_strategy': 'uniform', 'num_epochs': 11}. Best is trial 0 with value: 0.5280771877157359.
[I 2023-05-16 23:44:53,104] Trial 3 finished with value: 0.541053027020405 and parameters: {'embedding_dim': 2, 'step_size': 0.03135775732257745, 'batch_size': 13, 'num_bins': 47, 'bin_strategy': 'uniform', 'num_epochs': 10}. Best is trial 0 with value: 0.5280771877157359.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-16 23:47:41,642] Trial 4 finished with value: 0.5161959059594673 and parameters: {'embedding_dim': 6, 'step_size': 0.011992724522955167, 'batch_size': 20, 'num_bins': 18, 'bin_strategy': 'quantile', 'num_epochs': 15}. Best is trial 4 with value: 0.5161959059594673.
[I 2023-05-16 23:55:02,608] Trial 5 finished with value: 0.4970072048574563 and parameters: {'embedding_dim': 9, 'step_size': 0.032925293631105246, 'batch_size': 5, 'num_bins': 69, 'bin_strategy': 'uniform', 'num_epochs': 10}. Best is trial 5 with value: 0.4970072048574563.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-16 23:59:00,824] Trial 6 finished with value: 0.4909980418745024 and parameters: {'embedding_dim': 1, 'step_size': 0.35067764992972184, 'batch_size': 10, 'num_bins': 67, 'bin_strategy': 'quantile', 'num_epochs': 11}. Best is trial 6 with value: 0.4909980418745024.
[I 2023-05-17 00:01:14,977] Trial 7 pruned. 
[I 2023-05-17 00:02:29,776] Trial 8 pruned. 
[I 2023-05-17 00:03:20,527] Trial 9 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:03:28,015] Trial 10 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:05:20,936] Trial 11 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:05:44,832] Trial 12 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:06:08,219] Trial 13 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:09:05,756] Trial 14 finished with value: 0.502395642627435 and parameters: {'embedding_dim': 4, 'step_size': 0.05011263930396099, 'batch_size': 15, 'num_bins': 43, 'bin_strategy': 'quantile', 'num_epochs': 12}. Best is trial 6 with value: 0.4909980418745024.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:09:42,896] Trial 15 pruned. 
[I 2023-05-17 00:10:19,842] Trial 16 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 3 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:10:39,667] Trial 17 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:10:55,796] Trial 18 pruned. 
[I 2023-05-17 00:13:51,299] Trial 19 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:14:16,169] Trial 20 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:14:30,740] Trial 21 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:14:43,525] Trial 22 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:15:03,802] Trial 23 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:15:31,644] Trial 24 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 3 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:15:44,864] Trial 25 pruned. 
[I 2023-05-17 00:18:28,847] Trial 26 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:19:02,639] Trial 27 pruned. 
[I 2023-05-17 00:22:31,925] Trial 28 finished with value: 0.5196768477788142 and parameters: {'embedding_dim': 7, 'step_size': 0.14061965231931225, 'batch_size': 11, 'num_bins': 85, 'bin_strategy': 'uniform', 'num_epochs': 11}. Best is trial 6 with value: 0.4909980418745024.
[I 2023-05-17 00:23:55,180] Trial 29 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:24:09,656] Trial 30 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:24:20,135] Trial 31 pruned. 
[I 2023-05-17 00:25:03,259] Trial 32 pruned. 
[I 2023-05-17 00:26:44,073] Trial 33 finished with value: 0.4875798948549539 and parameters: {'embedding_dim': 6, 'step_size': 0.020684208264680124, 'batch_size': 24, 'num_bins': 11, 'bin_strategy': 'quantile', 'num_epochs': 11}. Best is trial 33 with value: 0.4875798948549539.
[I 2023-05-17 00:26:50,369] Trial 34 pruned. 
[I 2023-05-17 00:27:11,599] Trial 35 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:27:19,713] Trial 36 pruned. 
[I 2023-05-17 00:27:38,434] Trial 37 pruned. 
[I 2023-05-17 00:28:03,184] Trial 38 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:28:10,949] Trial 39 pruned. 
[I 2023-05-17 00:28:34,929] Trial 40 pruned. 
[I 2023-05-17 00:28:43,881] Trial 41 pruned. 
[I 2023-05-17 00:30:51,364] Trial 42 finished with value: 0.4912414844638768 and parameters: {'embedding_dim': 6, 'step_size': 0.02366328080027664, 'batch_size': 25, 'num_bins': 12, 'bin_strategy': 'quantile', 'num_epochs': 14}. Best is trial 33 with value: 0.4875798948549539.
[I 2023-05-17 00:32:38,553] Trial 43 finished with value: 0.4814953574890661 and parameters: {'embedding_dim': 6, 'step_size': 0.02353530079691863, 'batch_size': 30, 'num_bins': 11, 'bin_strategy': 'quantile', 'num_epochs': 14}. Best is trial 43 with value: 0.4814953574890661.
[I 2023-05-17 00:34:18,009] Trial 44 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:34:24,647] Trial 45 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:36:18,448] Trial 46 pruned. 
[I 2023-05-17 00:36:56,906] Trial 47 finished with value: 0.49383070805204027 and parameters: {'embedding_dim': 6, 'step_size': 0.025434555654167735, 'batch_size': 30, 'num_bins': 15, 'bin_strategy': 'quantile', 'num_epochs': 5}. Best is trial 43 with value: 0.4814953574890661.
[I 2023-05-17 00:37:04,635] Trial 48 pruned. 
[I 2023-05-17 00:37:53,269] Trial 49 finished with value: 0.48899480942348683 and parameters: {'embedding_dim': 6, 'step_size': 0.028209809445833324, 'batch_size': 28, 'num_bins': 15, 'bin_strategy': 'quantile', 'num_epochs': 6}. Best is trial 43 with value: 0.4814953574890661.
[I 2023-05-17 00:38:01,503] Trial 50 pruned. 
[I 2023-05-17 00:38:48,458] Trial 51 finished with value: 0.48648655467159996 and parameters: {'embedding_dim': 6, 'step_size': 0.027729848277686826, 'batch_size': 29, 'num_bins': 15, 'bin_strategy': 'quantile', 'num_epochs': 6}. Best is trial 43 with value: 0.4814953574890661.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:39:43,237] Trial 52 finished with value: 0.48784862426097797 and parameters: {'embedding_dim': 6, 'step_size': 0.028651818424917735, 'batch_size': 25, 'num_bins': 17, 'bin_strategy': 'quantile', 'num_epochs': 6}. Best is trial 43 with value: 0.4814953574890661.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:40:31,815] Trial 53 finished with value: 0.4868366998634225 and parameters: {'embedding_dim': 7, 'step_size': 0.030842284746007004, 'batch_size': 28, 'num_bins': 18, 'bin_strategy': 'quantile', 'num_epochs': 6}. Best is trial 43 with value: 0.4814953574890661.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:40:48,258] Trial 54 pruned. 
[I 2023-05-17 00:41:08,129] Trial 55 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:41:16,092] Trial 56 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:41:51,479] Trial 57 finished with value: 0.47951060024732495 and parameters: {'embedding_dim': 8, 'step_size': 0.03929383281600887, 'batch_size': 32, 'num_bins': 17, 'bin_strategy': 'quantile', 'num_epochs': 6}. Best is trial 57 with value: 0.47951060024732495.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:42:43,233] Trial 58 finished with value: 0.4845081577007928 and parameters: {'embedding_dim': 8, 'step_size': 0.04246678658964534, 'batch_size': 32, 'num_bins': 25, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 57 with value: 0.47951060024732495.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:43:34,879] Trial 59 finished with value: 0.49187994233475485 and parameters: {'embedding_dim': 8, 'step_size': 0.04312872060983268, 'batch_size': 32, 'num_bins': 25, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 57 with value: 0.47951060024732495.
[I 2023-05-17 00:44:35,589] Trial 60 finished with value: 0.4815114345371291 and parameters: {'embedding_dim': 8, 'step_size': 0.04837698369243848, 'batch_size': 31, 'num_bins': 11, 'bin_strategy': 'quantile', 'num_epochs': 8}. Best is trial 57 with value: 0.47951060024732495.
[I 2023-05-17 00:45:34,500] Trial 61 finished with value: 0.49070868780647986 and parameters: {'embedding_dim': 8, 'step_size': 0.049335977465450875, 'batch_size': 32, 'num_bins': 12, 'bin_strategy': 'quantile', 'num_epochs': 8}. Best is trial 57 with value: 0.47951060024732495.
[I 2023-05-17 00:45:42,276] Trial 62 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:46:20,071] Trial 63 finished with value: 0.491216049016063 and parameters: {'embedding_dim': 7, 'step_size': 0.03515178497648949, 'batch_size': 31, 'num_bins': 19, 'bin_strategy': 'quantile', 'num_epochs': 5}. Best is trial 57 with value: 0.47951060024732495.
[I 2023-05-17 00:47:08,678] Trial 64 finished with value: 0.4849947072261395 and parameters: {'embedding_dim': 8, 'step_size': 0.04805048391692695, 'batch_size': 29, 'num_bins': 11, 'bin_strategy': 'quantile', 'num_epochs': 6}. Best is trial 57 with value: 0.47951060024732495.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:47:57,305] Trial 65 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:48:50,313] Trial 66 finished with value: 0.476914166633194 and parameters: {'embedding_dim': 8, 'step_size': 0.03971654057761182, 'batch_size': 31, 'num_bins': 23, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 66 with value: 0.476914166633194.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:49:13,452] Trial 67 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:50:01,305] Trial 68 finished with value: 0.482084411088513 and parameters: {'embedding_dim': 9, 'step_size': 0.06121422232448351, 'batch_size': 29, 'num_bins': 23, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 66 with value: 0.476914166633194.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:50:54,972] Trial 69 finished with value: 0.4851239628268377 and parameters: {'embedding_dim': 10, 'step_size': 0.06349724505888375, 'batch_size': 31, 'num_bins': 22, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 66 with value: 0.476914166633194.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:51:46,909] Trial 70 finished with value: 0.48660264444334267 and parameters: {'embedding_dim': 9, 'step_size': 0.056829217265857324, 'batch_size': 32, 'num_bins': 26, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 66 with value: 0.476914166633194.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:52:40,173] Trial 71 finished with value: 0.48210371390904005 and parameters: {'embedding_dim': 10, 'step_size': 0.06862730361201547, 'batch_size': 31, 'num_bins': 23, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 66 with value: 0.476914166633194.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:53:13,224] Trial 72 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:54:06,179] Trial 73 finished with value: 0.47648125878431646 and parameters: {'embedding_dim': 8, 'step_size': 0.0527994326893901, 'batch_size': 31, 'num_bins': 23, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 73 with value: 0.47648125878431646.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:54:13,921] Trial 74 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:55:14,666] Trial 75 finished with value: 0.4832143413749891 and parameters: {'embedding_dim': 8, 'step_size': 0.05312163758261429, 'batch_size': 27, 'num_bins': 24, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 73 with value: 0.47648125878431646.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:55:40,298] Trial 76 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:56:43,399] Trial 77 finished with value: 0.48646381124588367 and parameters: {'embedding_dim': 10, 'step_size': 0.06784387379069022, 'batch_size': 30, 'num_bins': 23, 'bin_strategy': 'quantile', 'num_epochs': 8}. Best is trial 73 with value: 0.47648125878431646.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:56:52,544] Trial 78 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:58:01,068] Trial 79 finished with value: 0.49141271856808333 and parameters: {'embedding_dim': 9, 'step_size': 0.0831574285124721, 'batch_size': 31, 'num_bins': 19, 'bin_strategy': 'quantile', 'num_epochs': 9}. Best is trial 73 with value: 0.47648125878431646.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:58:35,449] Trial 80 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 00:59:18,893] Trial 81 finished with value: 0.48632571974316574 and parameters: {'embedding_dim': 8, 'step_size': 0.043784056828969076, 'batch_size': 32, 'num_bins': 26, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 73 with value: 0.47648125878431646.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 01:00:13,887] Trial 82 finished with value: 0.4808140576078414 and parameters: {'embedding_dim': 9, 'step_size': 0.0396166901001848, 'batch_size': 30, 'num_bins': 17, 'bin_strategy': 'quantile', 'num_epochs': 7}. Best is trial 73 with value: 0.47648125878431646.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 01:01:16,749] Trial 83 finished with value: 0.484505758786694 and parameters: {'embedding_dim': 10, 'step_size': 0.035889673059122974, 'batch_size': 30, 'num_bins': 17, 'bin_strategy': 'quantile', 'num_epochs': 8}. Best is trial 73 with value: 0.47648125878431646.
[I 2023-05-17 01:01:24,973] Trial 84 pruned. 
[I 2023-05-17 01:02:03,746] Trial 85 pruned. 
[I 2023-05-17 01:02:26,934] Trial 86 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 01:03:17,388] Trial 87 finished with value: 0.4778754175051123 and parameters: {'embedding_dim': 8, 'step_size': 0.05100157729747126, 'batch_size': 28, 'num_bins': 22, 'bin_strategy': 'quantile', 'num_epochs': 6}. Best is trial 73 with value: 0.47648125878431646.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 01:03:59,348] Trial 88 finished with value: 0.47865184479910533 and parameters: {'embedding_dim': 10, 'step_size': 0.046509370024280926, 'batch_size': 28, 'num_bins': 21, 'bin_strategy': 'quantile', 'num_epochs': 5}. Best is trial 73 with value: 0.47648125878431646.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 01:04:42,129] Trial 89 finished with value: 0.48706426403950454 and parameters: {'embedding_dim': 8, 'step_size': 0.044832751812685284, 'batch_size': 28, 'num_bins': 18, 'bin_strategy': 'quantile', 'num_epochs': 5}. Best is trial 73 with value: 0.47648125878431646.
[I 2023-05-17 01:04:50,840] Trial 90 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 01:05:38,872] Trial 91 finished with value: 0.4764637610283817 and parameters: {'embedding_dim': 10, 'step_size': 0.04717310364019017, 'batch_size': 30, 'num_bins': 22, 'bin_strategy': 'quantile', 'num_epochs': 6}. Best is trial 91 with value: 0.4764637610283817.
[I 2023-05-17 01:06:26,018] Trial 92 finished with value: 0.4725330759763331 and parameters: {'embedding_dim': 10, 'step_size': 0.04746441356701626, 'batch_size': 30, 'num_bins': 13, 'bin_strategy': 'quantile', 'num_epochs': 6}. Best is trial 92 with value: 0.4725330759763331.
[I 2023-05-17 01:06:41,549] Trial 93 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 01:07:12,239] Trial 94 finished with value: 0.4796635929325749 and parameters: {'embedding_dim': 10, 'step_size': 0.0477274890995249, 'batch_size': 30, 'num_bins': 17, 'bin_strategy': 'quantile', 'num_epochs': 5}. Best is trial 92 with value: 0.4725330759763331.
[I 2023-05-17 01:07:53,602] Trial 95 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 01:08:37,636] Trial 96 finished with value: 0.4850260604414627 and parameters: {'embedding_dim': 10, 'step_size': 0.046947828609809845, 'batch_size': 26, 'num_bins': 19, 'bin_strategy': 'quantile', 'num_epochs': 5}. Best is trial 92 with value: 0.4725330759763331.
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 01:08:51,371] Trial 97 pruned. 
/usr/local/lib64/python3.9/site-packages/sklearn/preprocessing/_discretization.py:291: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in feature 1 are removed. Consider decreasing the number of bins.
  warnings.warn(
[I 2023-05-17 01:09:30,112] Trial 98 finished with value: 0.4824779956622831 and parameters: {'embedding_dim': 10, 'step_size': 0.04156484777805755, 'batch_size': 30, 'num_bins': 17, 'bin_strategy': 'quantile', 'num_epochs': 5}. Best is trial 92 with value: 0.4725330759763331.
[I 2023-05-17 01:09:54,902] Trial 99 pruned. 
In [20]:
study_bins.best_params
Out[20]:
{'embedding_dim': 10,
 'step_size': 0.04746441356701626,
 'batch_size': 30,
 'num_bins': 13,
 'bin_strategy': 'quantile',
 'num_epochs': 6}
In [21]:
trial = study_bins.best_trial

print('Test loss: {}'.format(trial.value))
print("Best hyperparameters: {}".format(trial.params))
Test loss: 0.4725330759763331
Best hyperparameters: {'embedding_dim': 10, 'step_size': 0.04746441356701626, 'batch_size': 30, 'num_bins': 13, 'bin_strategy': 'quantile', 'num_epochs': 6}
In [22]:
train_bin_ffm(**study_bins.best_params)
Out[22]:
0.22700029611587524
In [32]:
bin_losses = []
for i in trange(20):
    mse = train_bin_ffm(**study_bins.best_params)
    bin_losses.append(math.sqrt(mse))
100%|██████████| 20/20 [15:42<00:00, 47.11s/it]
In [33]:
bin_losses
Out[33]:
[0.47253831069784313,
 0.47436226585193625,
 0.4746327178614912,
 0.47469055988463177,
 0.4737954243998599,
 0.4738253487224963,
 0.47137653094042825,
 0.47257028538839857,
 0.47029610602108984,
 0.47444127877551856,
 0.4718063374009628,
 0.4723477105300527,
 0.47160880620469847,
 0.477920785313591,
 0.4717265988354307,
 0.47360833488214976,
 0.4717850814181646,
 0.4717952515621083,
 0.4710633573328178,
 0.47449408806774845]
In [40]:
np.mean(bin_losses), 3 * np.std(bin_losses), np.mean(bin_losses) + 3 * np.std(bin_losses), np.mean(bin_losses) - 3 * np.std(bin_losses)
Out[40]:
(0.47303425900457097,
 0.0051958734734685146,
 0.47823013247803947,
 0.46783838553110246)