import optuna
from utils import get_device
from wv_wrapper import AGNWord2VecWrapper, MLPWord2VecWrapper, MLPConcatWord2VecWrapper
from word_analogy.train import load_bats_for_train, train


def tune_agn(trial):
    """Optuna objective for ``AGNWord2VecWrapper``.

    Draws ``hidden_dim`` (8 to 256), ``layer_num`` (2 to 6) and
    ``weight_decay`` (0.0 to 1e-3) from ``trial``, builds the wrapper, moves
    it to the device returned by ``get_device()``, and trains it on the data
    returned by ``load_bats_for_train``.

    Args:
        trial: Optuna trial used to sample the hyperparameters.

    Returns:
        The negated value returned by ``train``.
        NOTE(review): presumably ``train`` yields a higher-is-better score and
        the sign flip lets a minimizing study maximize it -- confirm.
    """
    # Keep the sampling order fixed: hidden_dim, layer_num, weight_decay.
    arch = {
        "hidden_dim": trial.suggest_int("hidden_dim", 8, 256),
        "layer_num": trial.suggest_int("layer_num", 2, 6),
    }
    decay = trial.suggest_float("weight_decay", 0.0, 1e-3)

    network = AGNWord2VecWrapper(**arch)
    network = network.to(get_device())

    # The fourth returned value is not needed while tuning.
    abcd, train_split, valid_split, _ = load_bats_for_train(network)
    score = train(network, abcd, train_split, valid_split, weight_decay=decay)
    return -score


def tune_mlp(trial):
    """Optuna objective for ``MLPWord2VecWrapper``.

    Samples ``hidden_dim`` (8 to 256), ``layer_num`` (2 to 6) and
    ``weight_decay`` (0.0 to 1e-3) from ``trial``, instantiates the wrapper on
    the device returned by ``get_device()``, and runs ``train`` on the data
    produced by ``load_bats_for_train``.

    Args:
        trial: Optuna trial used to sample the hyperparameters.

    Returns:
        The negated value returned by ``train``.
        NOTE(review): presumably negated so that a minimizing study maximizes
        the training score -- confirm what ``train`` returns.
    """
    width = trial.suggest_int("hidden_dim", 8, 256)
    depth = trial.suggest_int("layer_num", 2, 6)
    l2_penalty = trial.suggest_float("weight_decay", 0.0, 1e-3)

    wrapper = MLPWord2VecWrapper(hidden_dim=width, layer_num=depth)
    wrapper = wrapper.to(get_device())

    # Only the first three returned values are used for tuning.
    splits = load_bats_for_train(wrapper)
    analogies, training_set, validation_set, _ = splits

    result = train(
        wrapper,
        analogies,
        training_set,
        validation_set,
        weight_decay=l2_penalty,
    )
    return -result


def tune_mlp_c(trial):
    """Optuna objective for ``MLPConcatWord2VecWrapper``.

    Samples ``hidden_dim`` (8 to 768), ``layer_num`` (2 to 6) and
    ``weight_decay`` (0.0 to 1e-3) from ``trial``, builds the wrapper on the
    device returned by ``get_device()``, and trains it on the data returned
    by ``load_bats_for_train``.

    Args:
        trial: Optuna trial used to sample the hyperparameters.

    Returns:
        The negated value returned by ``train``.
        NOTE(review): presumably negated so that a minimizing study maximizes
        the training score -- confirm what ``train`` returns.
    """
    # Sampling order matters for reproducibility; keep it as listed.
    params = (
        trial.suggest_int("hidden_dim", 8, 768),
        trial.suggest_int("layer_num", 2, 6),
        trial.suggest_float("weight_decay", 0.0, 1e-3),
    )
    n_hidden, n_layers, wd = params

    net = MLPConcatWord2VecWrapper(hidden_dim=n_hidden, layer_num=n_layers)
    net = net.to(get_device())

    # The last returned value is not used during tuning.
    quads, fit_ds, val_ds, _ = load_bats_for_train(net)
    return -train(net, quads, fit_ds, val_ds, weight_decay=wd)


if __name__ == "__main__":
    # Run a 10-trial search over the AGN objective. "minimize" is Optuna's
    # default direction, spelled out here because the objectives return a
    # negated training result.
    study = optuna.create_study(direction="minimize")
    study.optimize(func=tune_agn, n_trials=10)
