import numpy as np
import GPy
import safeopt
import gurobipy as grb
from util import sample_unit_vec, inner_product, try_sample_gp_func

BIG_M = 1e10


def get_random_agent_config():
    """Build the configuration dict for the 'random_agent' test problem.

    Three black-box functions are drawn via ``try_sample_gp_func``, each with
    its own RBF kernel over a 1-D input. Function 0 is the objective; the
    remaining ones are constraints, where a point is feasible iff every
    constraint value is ``<= 0``.

    The ground-truth optimum is found by brute force over ``x_grid``. The
    objective is minimised and infeasible grid points are penalised by
    ``+BIG_M``.

    Returns:
        dict: problem configuration. Notable keys include ``'x_grid'`` /
        ``'parameter_set'`` (candidate points), ``'black_box_funcs_list'``,
        ``'kernel_list'``, ``'x_init'``, ``'y_init_list'``, ``'obj_init'``,
        ``'ground_truth_opt_val'``, ``'ground_truth_opt_sol'`` and
        ``'beta_func'`` (a constant function returning the largest sampled
        function norm).
    """
    config = dict()
    config['problem_name'] = 'random_agent'
    print('Start constructing random_agent config!')

    # configurations on input x
    x_dim = 1
    x_range = (-1, 1)
    x_init = np.array([0.0])

    discrete_num_per_dim = 50
    x_grid = safeopt.linearly_spaced_combinations(
        [x_range] * x_dim,
        [discrete_num_per_dim] * x_dim
    )
    print('x_grid constructed!')

    # configurations on black-box functions: index 0 is the objective, the
    # remaining num_constraints entries are constraints (feasible iff <= 0).
    black_box_funcs_dim = 3
    num_constraints = black_box_funcs_dim - 1
    GP_kernel_list = [
        GPy.kern.RBF(
            input_dim=x_dim, variance=1.0, lengthscale=0.3, ARD=True
        )
        for _ in range(black_box_funcs_dim)
    ]

    black_box_funcs_list = []
    black_box_funcs_norm_list = []
    y_init_list = []
    for k in range(black_box_funcs_dim):
        # Functions are sampled on (-2, 2) per dimension, which is wider
        # than x_range; the meaning of the remaining positional arguments
        # is defined by util.try_sample_gp_func.
        func_norm, func = try_sample_gp_func(
            True, 5, x_dim, [(-2, 2)] * x_dim, GP_kernel_list[k],
            1e-8, 0.0, 3)
        black_box_funcs_list.append(func)
        black_box_funcs_norm_list.append(func_norm)
        y_init_list.append(func(x_init))
        print(f'Black box function {k} sampled!')

    def get_obj(x):
        """Evaluate the objective (black-box function 0) at ``x``."""
        return black_box_funcs_list[0](x)

    # NOTE(review): y_init_list and the grid search evaluate functions on a
    # full point (array of length x_dim), whereas this passes the scalar
    # x_init[0]. That is equivalent only if the sampled functions accept
    # both forms and only while x_dim == 1 -- confirm.
    obj_init = get_obj(x_init[0])

    constraint_funcs = black_box_funcs_list[1:1 + num_constraints]

    # Evaluate the (potentially expensive) constraint functions exactly
    # once per grid point and reuse the result for the penalty below.
    feasibility_eval = [
        np.all([g(x) <= 0 for g in constraint_funcs])
        for x in x_grid
    ]
    if not any(feasibility_eval):
        print('Warning: no feasible point on x_grid; '
              'ground truth optimum is not meaningful!')

    # Penalty formulation: infeasible points get +BIG_M so that argmin
    # prefers any feasible point over every infeasible one.
    constr_func_val = [
        get_obj(x) + BIG_M * (1 - feasible)
        for x, feasible in zip(x_grid, feasibility_eval)
    ]

    constr_func_val_min = np.min(constr_func_val)
    constr_func_val_argmin = np.argmin(constr_func_val)
    opt_sol = x_grid[constr_func_val_argmin]

    config['var_dim'] = x_dim
    config['x_grid'] = x_grid

    config['black_box_funcs_list'] = black_box_funcs_list
    config['kernel_list'] = GP_kernel_list

    config['x_init'] = np.expand_dims(x_init, axis=0)
    config['y_init_list'] = y_init_list
    config['obj_init'] = obj_init
    config['ground_truth_opt_val'] = constr_func_val_min
    config['ground_truth_opt_sol'] = opt_sol
    # The norm list is fixed after sampling, so compute its max once.
    max_func_norm = max(black_box_funcs_norm_list)
    config['beta_func'] = lambda t: max_func_norm

    # specify the keys for CONFIG purpose
    config['parameter_set'] = x_grid
    config['eval_simu'] = False

    config['obj'] = get_obj
    config['bounds'] = [x_range] * x_dim
    config['discretize_num_list'] = [discrete_num_per_dim] * x_dim
    config['init_points'] = np.expand_dims(x_init, axis=0)
    config['init_safe_points'] = x_init
    config['train_X'] = x_init
    config['noise_level'] = 1e-4
    config['total_eval_num'] = 30
    config['eta'] = 0.2
    return config

if __name__ == '__main__':
    # Intentionally a no-op when this module is executed as a script.
    pass
