import sys
import numpy
import torch
import optuna
from nn.abelian_group_net import (AbelianGroupNetwork, AbelianSemigroupNetwork1,
                                  AbelianSemigroupNetwork2, AbelianSemigroupNetwork3)
from nn.monotonic_net import MonotonicNetwork, MultiMonotonicNetwork
from nn.deep_sets import DeepSet, MLP
from train import train
from setting import Xs, xs, func1, func2, func3, func4, func5, func6, func7, func8, func9, func10, func11


# Target function that every tuning objective below trains against; swap in
# another funcN imported from ``setting`` to tune on a different target.
func = func11


def tune_deepset(trial):
    """Optuna objective: train a DeepSet with sampled sizes on ``func``.

    Samples four integers from ``trial``:

    - ``mid_dim`` in [2, 32]: output width of the first MLP and input width
      of the second.
    - ``hidden_dim`` in [2, 32]: hidden width shared by both MLPs.
    - ``n_layer_phi`` and ``n_layer_rho`` in [2, 8]: layer counts of the
      first and second MLP respectively.

    Args:
        trial: the Optuna trial supplying the hyperparameters.

    Returns:
        Whatever ``train`` returns; Optuna uses it as the objective value.
    """
    print(func, DeepSet, file=sys.stderr)
    # Keep the suggest_int call order unchanged: with the seeded
    # RandomSampler it likely determines which value each name receives.
    mid_dim, hidden_dim = (
        trial.suggest_int(key, 2, 32) for key in ('mid_dim', 'hidden_dim')
    )
    depth_phi, depth_rho = (
        trial.suggest_int(key, 2, 8) for key in ('n_layer_phi', 'n_layer_rho')
    )
    # Build the 1 -> mid_dim MLP before the mid_dim -> 1 MLP, as originally,
    # in case weight initialisation draws from torch's seeded global RNG.
    phi = MLP(1, mid_dim, hidden_dim=hidden_dim, layer_num=depth_phi)
    rho = MLP(mid_dim, 1, hidden_dim=hidden_dim, layer_num=depth_rho)
    model = DeepSet(phi, rho)
    return train(model, func, print_error_test=True)


def tune_agn(trial):
    """Optuna objective: train an AbelianGroupNetwork on ``func``.

    Samples ``n_group`` and ``n_each`` (both in [2, 32]) from ``trial``.
    These configure the inner ``MonotonicNetwork``, which is wrapped in an
    ``AbelianGroupNetwork`` and trained.

    Args:
        trial: the Optuna trial supplying the hyperparameters.

    Returns:
        Whatever ``train`` returns; Optuna uses it as the objective value.
    """
    # Log the model class next to the target, as tune_deepset does, so stderr
    # output from different objectives can be told apart.
    print(func, AbelianGroupNetwork, file=sys.stderr)
    n_group = trial.suggest_int('n_group', 2, 32)
    n_each = trial.suggest_int('n_each', 2, 32)
    # Alternative inner network (not active): MultiMonotonicNetwork(n_layer=...,
    # n_group=..., n_each=..., const_sign=None) with n_layer sampled in [1, 8].
    agn = AbelianGroupNetwork(MonotonicNetwork(n_group=n_group, n_each=n_each))
    return train(agn, func, print_error_test=True)


def tune_asn(trial):
    """Optuna objective: train an AbelianSemigroupNetwork2 on ``func``.

    Samples ``n_group`` and ``n_each`` (both in [2, 32]) from ``trial``.
    These configure the inner ``MonotonicNetwork``, which is wrapped in an
    ``AbelianSemigroupNetwork2`` and trained.

    Args:
        trial: the Optuna trial supplying the hyperparameters.

    Returns:
        Whatever ``train`` returns; Optuna uses it as the objective value.
    """
    # Log the concrete model class next to the target, as tune_deepset does.
    # This matters here because several semigroup variants are imported.
    print(func, AbelianSemigroupNetwork2, file=sys.stderr)
    n_group = trial.suggest_int('n_group', 2, 32)
    n_each = trial.suggest_int('n_each', 2, 32)
    # Alternative inner network (not active): MultiMonotonicNetwork(n_layer=...,
    # n_group=..., n_each=..., const_sign=None) with n_layer sampled in [1, 8].
    asn = AbelianSemigroupNetwork2(MonotonicNetwork(n_group=n_group, n_each=n_each))
    return train(asn, func, print_error_test=True)


if __name__ == "__main__":
    # Seed the global NumPy and torch RNGs once, before any trial runs,
    # intended to make the whole study reproducible.
    numpy.random.seed(0)
    torch.manual_seed(0)
    # Seeded random search; create_study uses its default (minimize) direction.
    random_search = optuna.samplers.RandomSampler(seed=0)
    study = optuna.create_study(sampler=random_search)
    # Pass tune_deepset or tune_asn instead to tune a different model.
    study.optimize(tune_agn, n_trials=10)
