import numpy as np
import GPy
import safeopt
import gurobipy as grb
from util import sample_unit_vec, inner_product, try_sample_gp_func, \
    try_sample_utility_func

# Large numeric constant. It is not referenced anywhere in the visible part of
# this module.
# NOTE(review): presumably a "big-M" bound for optimisation models (gurobipy is
# imported in this file) -- confirm against the modules that import it.
BIG_M = 1e10


def get_power_allocation_agent_config():
    """Build the configuration dict for the 'power_allocation_agent' problem.

    The function:

    * builds a grid over the input range ``(0, 1)`` with
      ``safeopt.linearly_spaced_combinations`` (50 points per dimension,
      one input dimension);
    * creates one ``GPy.kern.RBF`` kernel per black-box function and obtains
      each black-box function (and its norm) from
      ``util.try_sample_utility_func``;
    * evaluates the objective (the first black-box function) on every grid
      point to derive a constant prior mean (the grid average) and the
      ground-truth optimum (the grid minimum and its argmin);
    * stores everything, plus fixed run settings, in a plain ``dict``.

    Progress messages are printed to stdout while the config is built.

    Returns:
        dict: the configuration, with the keys ``problem_name``,
        ``prior_means``, ``var_dim``, ``x_grid``, ``black_box_funcs_list``,
        ``kernel_list``, ``x_init``, ``y_init_list``, ``obj_init``,
        ``ground_truth_opt_val``, ``ground_truth_opt_sol``, ``beta_func``,
        ``parameter_set``, ``eval_simu``, ``obj``, ``bounds``,
        ``discretize_num_list``, ``init_points``, ``init_safe_points``,
        ``train_X``, ``noise_level``, ``total_eval_num`` and ``eta``.
    """
    config = dict()
    config['problem_name'] = 'power_allocation_agent'
    print('Start constructing power allocation config!')

    # configurations on input x
    x_dim = 1
    x_range = (0, 1)
    x_init = np.array([0.1])

    discrete_num_per_dim = 50
    x_grid = safeopt.linearly_spaced_combinations(
        [x_range] * x_dim,
        [discrete_num_per_dim] * x_dim
    )
    print('x_grid constructed!')

    # configurations on black-box functions
    # Only the objective is modelled here (a single black-box function).
    black_box_funcs_dim = 1
    GP_kernel_list = [
        GPy.kern.RBF(
            input_dim=x_dim, variance=1.0, lengthscale=0.3, ARD=True
        )
        for _ in range(black_box_funcs_dim)
    ]

    black_box_funcs_list = []
    black_box_funcs_norm_list = []
    y_init_list = []
    for k in range(black_box_funcs_dim):
        # The positional arguments are forwarded unchanged; their meaning is
        # defined by util.try_sample_utility_func.
        func_norm, func = try_sample_utility_func(
            True, 5, x_dim, [(-2, 2)] * x_dim, GP_kernel_list[k],
            1e-8, 0.0, 3)
        black_box_funcs_list.append(func)
        black_box_funcs_norm_list.append(func_norm)
        y_init_list.append(func(x_init))
        print(f'Black box function {k} sampled!')

    def get_obj(x):
        """Evaluate the objective, i.e. the first black-box function, at x."""
        # Looked up through the list on every call so that the objective stays
        # in sync with config['black_box_funcs_list'] (same list object).
        return black_box_funcs_list[0](x)

    # NOTE(review): the objective is evaluated at the scalar x_init[0] here,
    # whereas y_init_list above uses the 1-D array x_init -- confirm that the
    # sampled function accepts both and that the resulting shapes are intended.
    obj_init = get_obj(x_init[0])
    obj_grid_vals = [get_obj(x) for x in x_grid]

    # Constant prior mean equal to the average objective value over the grid.
    # Its input dimension follows x_dim so it stays consistent with the
    # kernels and the grid.
    mean_val = np.mean(obj_grid_vals)
    mean_func = GPy.mappings.Constant(
        input_dim=x_dim, output_dim=1, value=mean_val)
    config['prior_means'] = [mean_func]

    # Ground truth on the grid: smallest objective value and where it occurs.
    obj_grid_min = np.min(obj_grid_vals)
    obj_grid_argmin = np.argmin(obj_grid_vals)
    opt_sol = x_grid[obj_grid_argmin]

    config['var_dim'] = x_dim
    config['x_grid'] = x_grid

    config['black_box_funcs_list'] = black_box_funcs_list
    config['kernel_list'] = GP_kernel_list

    config['x_init'] = np.expand_dims(x_init, axis=0)
    config['y_init_list'] = y_init_list
    config['obj_init'] = obj_init
    config['ground_truth_opt_val'] = obj_grid_min
    config['ground_truth_opt_sol'] = opt_sol
    # Independent of t: always the largest of the sampled function norms.
    config['beta_func'] = lambda t: max(black_box_funcs_norm_list)

    # specify the keys for CONFIG purpose
    config['parameter_set'] = x_grid
    config['eval_simu'] = False

    config['obj'] = get_obj
    config['bounds'] = [x_range] * x_dim
    config['discretize_num_list'] = [discrete_num_per_dim] * x_dim
    config['init_points'] = np.expand_dims(x_init, axis=0)
    # NOTE(review): these two entries hold the 1-D x_init, while 'x_init' and
    # 'init_points' hold it with a leading axis added -- confirm consumers
    # expect the different shapes.
    config['init_safe_points'] = x_init
    config['train_X'] = x_init
    config['noise_level'] = 1e-4
    config['total_eval_num'] = 30
    config['eta'] = 0.05
    return config

# Script entry point: running this module directly currently does nothing.
if __name__ == '__main__':
    pass
