import numpy as np
import GPy
import safeopt
import gurobipy as grb
from util import sample_unit_vec, inner_product, try_sample_gp_func

# Large "big-M" constant. NOTE(review): not referenced by the code shown in
# this module (get_random_agent_config uses its own 1e20 penalty); presumably
# kept for gurobipy big-M formulations elsewhere -- verify before removing.
BIG_M = 1.0e10


def _grid_constrained_optimum(x_grid, obj_func, constraint_funcs, penalty):
    """Approximate a constrained minimum by exhaustive search over a grid.

    Each grid point ``x`` is scored as ``obj_func(x)``, plus ``penalty`` when
    the point is infeasible. A point is feasible when every function in
    ``constraint_funcs`` evaluates to ``<= 0`` there. The lowest score wins.

    Args:
        x_grid: Array of candidate points, iterated row by row.
        obj_func: Objective to minimise; called once per grid point.
        constraint_funcs: Constraint functions. An empty sequence makes every
            point feasible.
        penalty: Value added to the objective of infeasible points.

    Returns:
        ``(min_score, argmin_index)`` over the penalised scores. If no grid
        point is feasible, ``min_score`` includes ``penalty`` and a
        ``UserWarning`` is emitted.
    """
    import warnings

    # Evaluate every constraint exactly once per grid point.
    feasible = [
        bool(np.all([g(x) <= 0 for g in constraint_funcs]))
        for x in x_grid
    ]
    if not any(feasible):
        warnings.warn(
            'No feasible point found on the grid; the returned optimum '
            f'includes the infeasibility penalty {penalty:g}.'
        )

    # Explicit penalty instead of arithmetic on numpy booleans.
    scores = [
        obj_func(x) + (0.0 if is_feasible else penalty)
        for x, is_feasible in zip(x_grid, feasible)
    ]
    return np.min(scores), np.argmin(scores)


def get_random_agent_config():
    """Build the configuration dict for the ``random_agent`` test problem.

    The problem is one-dimensional, with ``x`` in ``[-1, 1]`` discretised into
    a 50-point grid. Three black-box functions are sampled with
    ``try_sample_gp_func``, each from its own RBF kernel:

    * index 0 is the objective;
    * indices 1..2 are constraints, and a point is feasible when every
      constraint value is ``<= 0``.

    The ground-truth optimum is approximated by a penalised grid search.

    Returns:
        dict: Problem configuration. It contains the grid, the sampled
        functions, their kernels, the initial point and its values, the
        ground-truth optimum, and the keys used by the CONFIG-style optimiser
        (``parameter_set``, ``obj``, ``bounds``, ...).
    """
    config = dict()
    config['problem_name'] = 'random_agent'
    print('Start constructing random_agent config!')

    # configurations on input x
    x_dim = 1
    x_range = (-1, 1)
    x_init = np.array([0.0])

    discrete_num_per_dim = 50
    x_grid = safeopt.linearly_spaced_combinations(
        [x_range] * x_dim,
        [discrete_num_per_dim] * x_dim
    )
    print('x_grid constructed!')

    # configurations on black-box functions: index 0 is the objective, the
    # remaining ones are constraints
    black_box_funcs_dim = 3
    num_constraints = black_box_funcs_dim - 1
    GP_kernel_list = [
        GPy.kern.RBF(
            input_dim=x_dim, variance=1.0, lengthscale=0.3, ARD=True
        )
        for _ in range(black_box_funcs_dim)
    ]

    obj_kernel = GPy.kern.RBF(
        input_dim=x_dim, variance=1.0, lengthscale=0.3, ARD=True
    )

    black_box_funcs_list = []
    black_box_funcs_norm_list = []
    y_init_list = []
    for k, kernel in enumerate(GP_kernel_list):
        # try_sample_gp_func returns (norm, function); the meaning of its
        # positional arguments is defined in util -- see that module.
        func_norm, func = try_sample_gp_func(
            True, 5, x_dim, [(-2, 2)] * x_dim, kernel,
            1e-8, 0.0, 3)
        black_box_funcs_list.append(func)
        black_box_funcs_norm_list.append(func_norm)
        y_init_list.append(func(x_init))
        print(f'Black box function {k} sampled!')

    def get_obj(x):
        return black_box_funcs_list[0](x)

    # NOTE(review): evaluated on the scalar x_init[0], whereas y_init_list
    # evaluates the same function on the array x_init; confirm which input
    # shape the sampled functions expect.
    obj_init = get_obj(x_init[0])

    # Ground truth via grid search; infeasible points get a huge penalty so
    # they can only win when nothing on the grid is feasible.
    infeasibility_penalty = 1e20
    constr_func_val_min, constr_func_val_argmin = _grid_constrained_optimum(
        x_grid,
        get_obj,
        black_box_funcs_list[1:1 + num_constraints],
        infeasibility_penalty,
    )
    opt_sol = x_grid[constr_func_val_argmin]

    config['var_dim'] = x_dim
    config['x_grid'] = x_grid

    config['black_box_funcs_list'] = black_box_funcs_list
    config['kernel_list'] = GP_kernel_list

    config['x_init'] = np.expand_dims(x_init, axis=0)
    config['y_init_list'] = y_init_list
    config['obj_init'] = obj_init
    config['ground_truth_opt_val'] = constr_func_val_min
    config['ground_truth_opt_sol'] = opt_sol
    # Constant beta: the largest sampled function norm (computed once).
    beta_value = max(black_box_funcs_norm_list)
    config['beta_func'] = lambda t: beta_value

    # specify the keys for CONFIG purpose
    config['parameter_set'] = x_grid
    config['eval_simu'] = False

    config['obj'] = get_obj
    config['bounds'] = [x_range] * x_dim
    config['discretize_num_list'] = [discrete_num_per_dim] * x_dim
    config['init_points'] = np.expand_dims(x_init, axis=0)
    config['init_safe_points'] = x_init
    config['train_X'] = x_init
    config['noise_level'] = 1e-4
    config['kernel'] = [obj_kernel]
    config['obj_kernel'] = obj_kernel
    config['total_eval_num'] = 30
    return config

def get_config(problem_name):
    """Return the configuration of a named constrained grey-box problem.

    Args:
        problem_name: Problem identifier, either ``'random_agent'`` or
            ``'1d3l'``.

    Returns:
        Configuration of the constrained grey-box problem, including the
        variable dimension, the number of constraints, the objective function
        and the constraint functions.

    Raises:
        ValueError: If ``problem_name`` is not a known problem. Previously
            this case silently returned ``None``.
    """
    if problem_name == 'random_agent':
        return get_random_agent_config()
    if problem_name == '1d3l':
        # NOTE(review): get_one_d_three_layer_config is neither defined nor
        # imported in this module as shown, so this branch raises NameError
        # unless the function is provided elsewhere.
        return get_one_d_three_layer_config()
    raise ValueError(
        f"Unknown problem name {problem_name!r}; "
        "expected 'random_agent' or '1d3l'."
    )


if __name__ == '__main__':
    # Nothing to run directly; this module only defines config builders.
    ...
