import numpy as np
import GPy
import safeopt
import gurobipy as grb
from util import sample_unit_vec, inner_product, try_sample_gp_func

# Sentinel "infinite" objective value: the starting incumbent of the grid
# search and the value reported when the auxiliary LP yields no solution.
BIG_M = 1.0e10


def get_constrained_opt_val_sol(x_grid, sample_h, get_obj_constr):
    """Brute-force the best feasible point of a constrained problem on a grid.

    For each candidate ``x`` in ``x_grid`` the black-box values are obtained
    with ``sample_h(x)`` and then passed, together with ``x``, to
    ``get_obj_constr``. A candidate is feasible when every constraint value
    is ``<= 0``. Among feasible candidates, the first one whose objective is
    strictly smaller than the current incumbent (initially ``BIG_M``) wins.

    Args:
        x_grid: iterable of candidate inputs.
        sample_h: callable ``x -> h`` returning the black-box values at ``x``.
        get_obj_constr: callable ``(x, h) -> (obj, constr_vals)``.

    Returns:
        tuple: ``(opt_val, opt_sol)``. This is ``(BIG_M, None)`` when no
        feasible candidate has an objective below ``BIG_M``.
    """
    best_val, best_x = BIG_M, None
    for candidate in x_grid:
        objective, constraints = get_obj_constr(candidate, sample_h(candidate))
        # Skip infeasible points before comparing objectives.
        if not np.all(np.array(constraints) <= 0):
            continue
        if objective < best_val:
            best_val, best_x = objective, candidate
    return best_val, best_x


def get_GPLP_config():
    """Build the configuration of the synthetic ``'GP_LP'`` grey-box problem.

    The problem has a 2-D decision variable ``x`` in ``[-2, 2]^2`` and two
    black-box functions ``h_k(x)`` obtained from ``try_sample_gp_func`` with
    RBF kernels (variance 0.5, lengthscale 1.0, ARD). The explicit outer
    model is linear in ``(x, h(x))``::

        obj(x)      = inner_product(c_x, x) + inner_product(c_h, h(x))
        constr_k(x) = inner_product(a_x_k, x) + inner_product(a_z_k, h(x)) + b_k
                      (feasible iff constr_k(x) <= 0)

    The coefficients come from ``sample_unit_vec``. The ground-truth optimum
    is computed by brute force over a 100 x 100 grid on the input box.

    NOTE(review): the black-box functions and coefficients appear to be
    random draws. Each call therefore presumably yields a different problem
    unless the caller seeds the RNG(s) used by ``util``. Confirm this.

    Returns:
        dict: the problem configuration. Besides metadata it contains:

        - the grid (``'x_grid'`` / ``'parameter_set'``);
        - the sampled black-box functions and their kernels;
        - ``'get_obj_constr'`` and the Gurobi-based ``'aux_prob_solver'``;
        - the initial point and observations;
        - the brute-force ground truth and a constant ``'beta_func'``;
        - the per-function callables ``'obj'`` / ``'constrs_list'`` together
          with their kernels (``'kernel'``).
    """
    config = dict()
    config['problem_name'] = 'GP_LP'
    print('Start constructing GPLP config!')

    # ---- Decision variable x ----
    x_dim = 2
    x_range = (-2, 2)
    x_init = np.array([0.0, 0.0])

    discrete_num_per_dim = 100
    x_grid = safeopt.linearly_spaced_combinations(
        [x_range] * x_dim,
        [discrete_num_per_dim] * x_dim
    )
    print('x_grid constructed!')

    def new_rbf_kernel():
        # Every GP in this problem uses the same RBF hyperparameters; a fresh
        # kernel object is created on each call.
        return GPy.kern.RBF(
            input_dim=x_dim, variance=0.5, lengthscale=1.0, ARD=True
        )

    # ---- Black-box functions h_k(x) ----
    black_box_funcs_dim = 2
    GP_kernel_list = [new_rbf_kernel() for _ in range(black_box_funcs_dim)]
    black_box_funcs_list = []
    black_box_funcs_norm_list = []
    y_init_list = []
    for k in range(black_box_funcs_dim):
        # Positional arguments are passed through unchanged; see
        # util.try_sample_gp_func for their meaning.
        func_norm, func = try_sample_gp_func(
            False, 3, x_dim, [x_range] * x_dim, GP_kernel_list[k],
            1e-8, 0.0, 3)
        black_box_funcs_list.append(func)
        black_box_funcs_norm_list.append(func_norm)
        y_init_list.append(func(x_init))
        print(f'Black box function {k} sampled!')

    def sample_h(x):
        """Evaluate every black-box function at ``x`` (squeezed first)."""
        x = np.squeeze(x)
        return [h_func(x) for h_func in black_box_funcs_list]

    # ---- Explicit outer model: linear objective and constraints ----
    c_x = sample_unit_vec(x_dim)
    c_h = sample_unit_vec(black_box_funcs_dim)
    constraints_num = 2
    # Each entry is [a_x, a_z, b] for the constraint
    # inner_product(a_x, x) + inner_product(a_z, h) + b <= 0.
    constraints_vec_list = []
    for _ in range(constraints_num):
        constraints_vec_list.append(
            [sample_unit_vec(x_dim),
             sample_unit_vec(black_box_funcs_dim),
             0.5 * sample_unit_vec(1)
             ]
        )

    def get_obj_constr(x, h_x):
        """Return ``(obj, [constr_1, ..., constr_K])`` given ``x`` and ``h(x)``."""
        x = np.squeeze(x)
        obj_val = inner_product(c_x, x) + inner_product(c_h, h_x)
        constraint_val_list = [
            inner_product(a_x, x) + inner_product(a_z, h_x) + b
            for a_x, a_z, b in constraints_vec_list
        ]
        return obj_val, constraint_val_list

    obj_kernel = new_rbf_kernel()
    constr_kernel_list = [new_rbf_kernel() for _ in range(constraints_num)]

    def solve_aux_prob(x, z_lb, z_ub):
        """Minimise the outer objective at fixed ``x`` over boxed ``z``.

        The black-box values ``h(x)`` are replaced by decision variables
        ``z`` with bounds ``[z_lb, z_ub]``. The linear objective is then
        minimised subject to the linear constraints.

        Returns:
            The optimal objective value, or ``BIG_M`` when Gurobi finds no
            feasible solution.
        """
        m = grb.Model('aux_prob')
        try:
            m.Params.LogToConsole = 0
            z = m.addVars(black_box_funcs_dim, lb=z_lb, ub=z_ub)

            obj = inner_product(c_x, x) + inner_product(c_h, z)
            m.setObjective(obj, grb.GRB.MINIMIZE)

            for a_x, a_z, b in constraints_vec_list:
                m.addConstr(
                    inner_product(a_x, x) + inner_product(a_z, z) + b <= 0
                )

            m.optimize()

            # ObjVal is only retrievable when a solution exists. Check for
            # that explicitly instead of swallowing arbitrary exceptions.
            if m.SolCount > 0:
                return m.ObjVal
            return BIG_M
        finally:
            # Free the Gurobi model's native resources deterministically.
            m.dispose()

    ground_truth_opt_val, ground_truth_opt_sol = get_constrained_opt_val_sol(
        x_grid, sample_h, get_obj_constr
    )
    config['var_dim'] = x_dim
    config['num_constrs'] = constraints_num
    config['x_grid'] = x_grid

    config['black_box_funcs_list'] = black_box_funcs_list
    config['kernel_list'] = GP_kernel_list
    config['get_obj_constr'] = get_obj_constr
    config['aux_prob_solver'] = solve_aux_prob

    config['x_init'] = np.expand_dims(x_init, axis=0)
    config['y_init_list'] = y_init_list
    config['sample_h'] = sample_h

    config['ground_truth_opt_val'] = ground_truth_opt_val
    config['ground_truth_opt_sol'] = ground_truth_opt_sol
    # The beta value does not depend on t, so compute it once.
    beta = max(black_box_funcs_norm_list)
    config['beta_func'] = lambda t: beta

    # specify the keys for CONFIG purpose
    config['parameter_set'] = x_grid
    config['eval_simu'] = False

    def map_x_to_obj_constr(x):
        """Evaluate the objective and all constraint values at ``x``."""
        h = sample_h(x)
        return get_obj_constr(x, h)

    def get_obj(x):
        """Objective at ``x`` as an array of at least one dimension."""
        obj, _ = map_x_to_obj_constr(x)
        return np.atleast_1d(obj)

    def make_constr_getter(idx):
        """Return a callable evaluating the ``idx``-th constraint at ``x``.

        A factory binds ``idx`` eagerly, so it avoids the late-binding
        pitfall of closures created in a loop.
        """
        def get_constr(x):
            _, constr_list = map_x_to_obj_constr(x)
            return constr_list[idx]
        return get_constr

    config['obj'] = get_obj
    # One getter per constraint, kept in sync with constraints_num.
    config['constrs_list'] = [
        make_constr_getter(k) for k in range(constraints_num)
    ]
    config['bounds'] = [x_range] * x_dim
    config['discretize_num_list'] = [discrete_num_per_dim] * x_dim
    config['init_points'] = np.expand_dims(x_init, axis=0)
    config['init_safe_points'] = x_init
    config['train_X'] = x_init
    config['noise_level'] = 1e-4
    config['kernel'] = [obj_kernel] + constr_kernel_list
    config['total_eval_num'] = 30
    return config


def get_config(problem_name):
    """Return the configuration of a constrained grey-box test problem.

    Args:
        problem_name: name of the problem. Only ``'GP_LP'`` is currently
            supported.

    Returns:
        dict: configuration of the problem, including the variable dimension,
        the number of constraints, and the objective and constraint
        functions.

    Raises:
        ValueError: if ``problem_name`` is not a supported problem. Failing
            here avoids silently returning ``None``.
    """
    if problem_name == 'GP_LP':
        return get_GPLP_config()
    raise ValueError(
        f"Unknown problem_name {problem_name!r}; supported problems: 'GP_LP'."
    )


if __name__ == '__main__':
    # This module only provides problem configurations (use get_config);
    # running it directly intentionally does nothing.
    ...
