import logging
import sys
import warnings

import botorch

from _test_functions.get_function import get_args
from mobo_osd.mobo_osd import MOBO_OSD
from mobo_osd.helper import set_seed

# Quiet the console for the run: switch off BoTorch's own warnings, then
# ignore every UserWarning process-wide.
# NOTE: these filters are installed after the imports above have executed,
# so anything warned while importing is not covered by them.
botorch.settings.suppress_botorch_warnings(True)
warnings.simplefilter("ignore", UserWarning)


def main() -> None:
    """Run one MOBO-OSD optimization with the settings from ``get_args()``.

    Reads the objective function, batch size, evaluation budget, seed,
    ``n_beta`` and ``n_init`` from the dictionary returned by ``get_args()``,
    applies the seed, routes INFO-level log records to stdout, logs a
    one-line summary of the run configuration, and then builds a
    ``MOBO_OSD`` instance and calls its ``run_optimization()``.
    """
    input_dicts = get_args()
    objective_function = input_dicts['objective_function']
    batch = input_dicts['batch']
    max_evals = input_dicts['max_evals']
    seed = input_dicts['seed']
    n_beta = input_dicts['n_beta']
    n_init = input_dicts['n_init']

    # Problem sizes are only needed for the summary line below.
    dim = objective_function.input_dims
    num_objectives = objective_function.num_objectives

    # Apply the requested seed before any other work is done.
    set_seed(seed)

    # Send INFO-and-above records from the root logger to stdout. The handler
    # has no formatter attached, so each record is printed as its bare message.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(logging.StreamHandler(sys.stdout))

    logging.info(
        f'Running MOBO-OSD with {objective_function.name} '
        f'D={dim}, M={num_objectives}, '
        f'n_init={n_init}, n_beta={n_beta}, '
        f'batch_size={batch}, n_maxeval={max_evals}'
    )

    mobo_osd = MOBO_OSD(
        objective_function=objective_function,
        n_maxeval=max_evals,
        batch_size=batch,
        n_init=n_init,
        n_beta=n_beta,
    )
    mobo_osd.run_optimization()


if __name__ == "__main__":
    main()