import optuna
import os
import joblib
import sys
from nn.abelian_group_net import (AbelianGroupNetwork, AbelianSemigroupNetwork1,
                                  AbelianSemigroupNetwork2, AbelianSemigroupNetwork3)
from nn.monotonic_net import MonotonicNetwork, MultiMonotonicNetwork
from nn.deep_sets import DeepSet, MLP
from train import train
from setting import Xs, xs, func1, func2, func3, func4, func5, func6, func7, func8, func9, func10, func11

# Session configuration: the toy target being fitted, the prefix for the
# per-model study names ("<prefix>_<model>"), and the Optuna storage URL
# shared by every worker process. The prefix embeds "f2", so change it
# together with the target function.
func, study_name_base, storage_path = (
    func2,
    "toy_f2",
    "sqlite:///./optuna_toy.db",
)


def tune_deepset(trial):
    """Optuna objective: sample a DeepSet architecture, train it, return the score.

    Sampled hyperparameters (bounds inclusive):
      * ``mid_dim`` in [2, 32]: output width of the first MLP / input width of
        the second.
      * ``hidden_dim`` in [2, 32]: hidden width shared by both MLPs.
      * ``n_layer_phi`` / ``n_layer_rho`` in [2, 8]: depth of the first
        (1 -> mid_dim) and second (mid_dim -> 1) MLP respectively.

    Returns whatever ``train`` returns for the module-level target ``func``.
    Presumably this is an error the study minimizes (``create_study`` is called
    without a direction), but confirm against ``train``.
    """
    print(func, DeepSet, file=sys.stderr)
    latent = trial.suggest_int('mid_dim', 2, 32)
    width = trial.suggest_int('hidden_dim', 2, 32)
    depth_phi = trial.suggest_int('n_layer_phi', 2, 8)
    depth_rho = trial.suggest_int('n_layer_rho', 2, 8)

    phi = MLP(1, latent, hidden_dim=width, layer_num=depth_phi)
    rho = MLP(latent, 1, hidden_dim=width, layer_num=depth_rho)
    model = DeepSet(phi, rho)
    return train(model, func, print_error_test=True)


def tune_agn(trial):
    """Optuna objective for ``AbelianGroupNetwork`` over a ``MultiMonotonicNetwork``.

    Samples ``n_layer`` in [1, 8] and ``n_group`` / ``n_each`` in [2, 32]
    (inclusive). Builds the inner network with ``const_sign=None`` and returns
    the result of ``train`` on the module-level target ``func``.
    """
    print(func, file=sys.stderr)
    # A dict literal evaluates its values left to right, so the suggest order
    # is the same as three sequential assignments.
    arch = {
        "n_layer": trial.suggest_int("n_layer", 1, 8),
        "n_group": trial.suggest_int('n_group', 2, 32),
        "n_each": trial.suggest_int('n_each', 2, 32),
    }
    inner = MultiMonotonicNetwork(**arch, const_sign=None)
    return train(AbelianGroupNetwork(inner), func, print_error_test=True)


def tune_asn(trial):
    """Optuna objective for ``AbelianSemigroupNetwork2`` over a ``MultiMonotonicNetwork``.

    Uses the same search space as the group-network tuner: ``n_layer`` in
    [1, 8] and ``n_group`` / ``n_each`` in [2, 32] (inclusive), with
    ``const_sign=None``. Returns the result of ``train`` on the module-level
    target ``func``.
    """
    print(func, file=sys.stderr)
    # (name, low, high) in the order the parameters are suggested.
    search_space = (("n_layer", 1, 8), ("n_group", 2, 32), ("n_each", 2, 32))
    sampled = {key: trial.suggest_int(key, low, high) for key, low, high in search_space}
    backbone = MultiMonotonicNetwork(**sampled, const_sign=None)
    return train(AbelianSemigroupNetwork2(backbone), func, print_error_test=True)


# Module-level defaults for ``run()``. The ``__main__`` loop rebinds both before
# each batch of workers. ``None`` makes a call to ``run()`` with no objective fail
# with a clear error instead of passing a non-callable to ``study.optimize``.
objective = None
study_name = "dummy"


def run(objective_fn=None, name=None, n_trials=5):
    """Run ``n_trials`` trials of an objective against an existing shared study.

    Intended as a joblib worker. Each call loads the study from ``storage_path``,
    so several processes can add trials to the same study, and then optimizes it.

    Args:
        objective_fn: Optuna objective (a callable taking a ``Trial``). Defaults
            to the module-level ``objective``.
        name: Name of a study that already exists in ``storage_path``. Defaults
            to the module-level ``study_name``.
        n_trials: Number of trials this call runs.

    Returns:
        The PID of the process that ran the trials.

    Raises:
        TypeError: If the resolved objective is not callable.
    """
    if objective_fn is None:
        objective_fn = objective
    if name is None:
        name = study_name
    if not callable(objective_fn):
        raise TypeError(
            f"run() needs a callable objective, got {objective_fn!r}; "
            "pass objective_fn or set the module-level `objective`"
        )
    study = optuna.load_study(study_name=name, storage=storage_path)
    study.optimize(objective_fn, n_trials=n_trials)
    return os.getpid()


if __name__ == "__main__":
    # Explicit name -> objective table instead of eval(f"tune_{model_name}").
    tuners = {"deepset": tune_deepset, "agn": tune_agn, "asn": tune_asn}
    n_workers = 16
    # NOTE(review): "asn" appears twice, so its study receives two batches of
    # trials. Confirm this is intended rather than a typo.
    for model_name in ["asn", "deepset", "agn", "asn"]:
        objective = tuners[model_name]
        study_name = f"{study_name_base}_{model_name}"
        # Create (or reuse) the study up front so the workers can load it.
        study = optuna.create_study(
            study_name=study_name,
            storage=storage_path,
            load_if_exists=True,
        )
        # Hand the objective and study name to each worker explicitly, rather
        # than relying on how the joblib backend serializes run()'s globals.
        # NOTE(review): Optuna discourages SQLite storage for heavily parallel
        # runs. If workers hit "database is locked", point storage_path at a
        # server-backed RDB.
        joblib.Parallel(n_jobs=n_workers)(
            joblib.delayed(run)(objective, study_name) for _ in range(n_workers)
        )

        try:
            best = study.best_trial
        except ValueError:
            # No trial of this study completed (e.g. every trial failed).
            # Report it and continue with the next model instead of aborting.
            print(f"{study_name}: no completed trials", file=sys.stderr)
            continue
        print(best.value)
        print(best.params)
