import numpy as np
import GPy
import safeopt
import gurobipy as grb
from util import sample_unit_vec, inner_product, try_sample_gp_func

# Module-level "big-M" constant (1e10). The nested check_consistency in
# get_one_d_three_layer_config uses its own local 1e20 penalty instead of
# this value. NOTE(review): presumably read by importing modules -- confirm
# before changing it.
BIG_M = 1e10


def get_one_d_three_layer_config():
    """Build the configuration of the 1-D, three-layer grey-box problem ('1d3l').

    The objective (see ``get_obj``) is a chain of three black-box functions
    plus a known quadratic term::

        f(x) = phi_2(phi_1(phi_0(x))) + alpha * x ** 2,   x in [-1, 1]

    Each ``phi_k`` is produced by ``try_sample_gp_func`` with its own RBF
    kernel (presumably a sample from that GP prior -- see util).

    Returns:
        dict: problem configuration -- grids, kernels, initial data, the
        ground-truth optimum found on ``x_grid`` and the closures
        ``sample_phi``, ``sample_phi_pass_through``, ``check_consistency``,
        ``forward_sample``, ``obj`` and ``beta_func`` used by callers.
    """
    config = {}
    config['problem_name'] = '1d3l'
    alpha = 0.2  # weight of the known quadratic term alpha * x ** 2
    print('Start constructing 1d3l config!')
    # configurations on input x
    x_dim = 1
    x_range = (-1, 1)
    x_init = np.array([0.0])

    discrete_num_per_dim = 50
    # One sample count per input dimension, matching 'discretize_num_list'.
    x_grid = safeopt.linearly_spaced_combinations(
        [x_range] * x_dim,
        [discrete_num_per_dim] * x_dim
    )
    print('x_grid constructed!')

    # configurations on black-box functions
    black_box_funcs_dim = 3
    GP_kernel_list = [
        GPy.kern.RBF(
            input_dim=x_dim, variance=1.0, lengthscale=0.3, ARD=True
        )
        for _ in range(black_box_funcs_dim)
    ]

    obj_kernel = GPy.kern.RBF(
        input_dim=x_dim, variance=3.0, lengthscale=0.3, ARD=True
    )

    black_box_funcs_list = []
    black_box_funcs_norm_list = []
    y_init_list = []
    for k in range(black_box_funcs_dim):
        func_norm, func = try_sample_gp_func(
            True, 5, x_dim, [(-2, 2)] * x_dim, GP_kernel_list[k],
            1e-8, 0.0, 3)
        black_box_funcs_list.append(func)
        black_box_funcs_norm_list.append(func_norm)
        # phi_k evaluated at x_init itself (not at the chained intermediate).
        y_init_list.append(func(x_init))
        print(f'Black box function {k} sampled!')

    # Joint grid: first column spans x_range, the other three span [-2, 2]
    # (presumably the intermediate values z_0, z_1, z_2).
    xz_grid = safeopt.linearly_spaced_combinations(
        [(-1, 1), (-2, 2), (-2, 2), (-2, 2)],
        [discrete_num_per_dim] * 4
    )

    def _append_observation(gp, x_obs, y_obs):
        # Append one (input, output) pair to the GP's training data.
        new_X = np.vstack([gp.X, [[x_obs]]])
        new_Y = np.vstack([gp.Y, [y_obs]])
        gp.set_XY(new_X, new_Y)

    def _predict_mean_std(gp, inputs):
        # Posterior mean and standard deviation. The variance is clipped at 0
        # so round-off can never turn the std into NaN.
        mean, var = gp.predict(inputs)
        return mean, np.sqrt(np.maximum(var, 0.0))

    def sample_phi_one_pass_and_update_gp(x, gp_list):
        """Evaluate the chain phi_0 -> phi_1 -> phi_2 starting at ``x``.

        Each layer's (input, output) pair is appended to the corresponding GP
        in ``gp_list`` (mutated in place via ``set_XY``).

        Returns:
            (phi_list, gp_list): the three layer outputs and the GP list.
        """
        phi_list = []
        for k in range(black_box_funcs_dim):
            x = np.squeeze(x)
            phi_list.append(black_box_funcs_list[k](x))
            _append_observation(gp_list[k], x, phi_list[-1])
            # Feed the value just recorded into the next layer rather than
            # evaluating the black box a second time.
            x = phi_list[-1]
        return phi_list, gp_list

    def sample_phi_and_update_gp(s_list, gp_list):
        """Evaluate each layer k at its own input ``s_list[k]``.

        Each (input, output) pair is appended to the corresponding GP in
        ``gp_list`` (mutated in place via ``set_XY``).

        Returns:
            (phi_list, gp_list): the three layer outputs and the GP list.
        """
        phi_list = []
        for k in range(black_box_funcs_dim):
            s = np.squeeze(s_list[k])
            phi_list.append(black_box_funcs_list[k](s))
            _append_observation(gp_list[k], s, phi_list[-1])
        return phi_list, gp_list

    def get_obj(x):
        """True objective phi_2(phi_1(phi_0(x))) + alpha * x ** 2."""
        z = x
        for k in range(black_box_funcs_dim):
            z = black_box_funcs_list[k](z)
        z = z + alpha * x ** 2
        return z

    obj_init = get_obj(x_init[0])

    def forward_sample(x, gp_list):
        """Forward-sample the grey-box objective at ``x`` from the GP posteriors.

        Layer k's GP is queried at the sample drawn for layer k-1 (layer 0 at
        ``x``), and one standard-normal draw per point and layer perturbs the
        posterior mean.

        Args:
            x: inputs of shape (n,) or (n, 1).
            gp_list: the three layer GPs.

        Returns:
            array of shape (n, 1): sampled objective values.
        """
        if x.ndim == 1:
            x = np.expand_dims(x, axis=1)
        num_x = x.shape[0]

        z_sample = x
        for k in range(black_box_funcs_dim):
            std_normal = np.random.randn(num_x, 1)
            z_mean, z_std = _predict_mean_std(gp_list[k], z_sample)
            z_sample = z_mean + std_normal * z_std

        ghat = z_sample + alpha * x ** 2
        return ghat

    def check_consistency(x, z, gp_list, beta):
        """Score candidate (x, z) pairs against the GP confidence bands.

        A point is consistent when z[:, 0] lies inside GP 0's band at x and
        z[:, 1] lies inside GP 1's band at z[:, 0] (band = mean +/- beta*std).

        Args:
            x: candidate inputs, shape (n,) or (n, 1).
            z: intermediate values, at least two columns.
            gp_list: the three layer GPs.
            beta: confidence multiplier -- one value shared by all layers, or
                a list/tuple with one value per layer.

        Returns:
            array of shape (n,): for consistent points, the lower confidence
            bound of GP 2 at z[:, 1] plus alpha * x ** 2; for inconsistent
            points, a 1e20 penalty plus alpha * x ** 2.
        """
        if isinstance(beta, (list, tuple)):
            beta = list(beta)
        else:
            beta = [beta] * len(gp_list)

        if x.ndim == 1:
            x = np.expand_dims(x, axis=1)
        z_0_mean, z_0_std = _predict_mean_std(gp_list[0], x)
        z_1_mean, z_1_std = _predict_mean_std(gp_list[1], z[:, 0:1])
        z_2_mean, z_2_std = _predict_mean_std(gp_list[2], z[:, 1:2])

        z_0_mean, z_0_std = np.squeeze(z_0_mean), np.squeeze(z_0_std)
        z_1_mean, z_1_std = np.squeeze(z_1_mean), np.squeeze(z_1_std)
        z_2_mean, z_2_std = np.squeeze(z_2_mean), np.squeeze(z_2_std)

        # check if z_0 and z_1 are inside the confidence interval
        z_0_feas = (z[:, 0] >= z_0_mean - beta[0] * z_0_std) & \
            (z[:, 0] <= z_0_mean + beta[0] * z_0_std)

        z_1_feas = (z[:, 1] >= z_1_mean - beta[1] * z_1_std) & \
            (z[:, 1] <= z_1_mean + beta[1] * z_1_std)

        feas = z_0_feas & z_1_feas

        z_2_lw_bound = z_2_mean - beta[2] * z_2_std
        # Penalty used in place of the bound for inconsistent points.
        infeasible_penalty = 1e20
        x = np.squeeze(x)
        opt_obj_val = (1 - feas) * infeasible_penalty + \
            feas * z_2_lw_bound + alpha * x ** 2
        return opt_obj_val

    # Brute-force ground truth: evaluate the true objective on every grid point.
    func_val = [get_obj(x) for x in x_grid]
    func_val_min = np.min(func_val)
    func_val_argmin = np.argmin(func_val)
    opt_sol = x_grid[func_val_argmin]

    config['var_dim'] = x_dim
    config['x_grid'] = x_grid
    config['xz_grid'] = xz_grid
    config['black_box_funcs_list'] = black_box_funcs_list
    config['kernel_list'] = GP_kernel_list
    config['sample_phi'] = sample_phi_and_update_gp
    config['sample_phi_pass_through'] = sample_phi_one_pass_and_update_gp
    config['check_consistency'] = check_consistency
    config['forward_sample'] = forward_sample

    config['x_init'] = np.expand_dims(x_init, axis=0)
    config['y_init_list'] = y_init_list
    config['obj_init'] = obj_init
    config['ground_truth_opt_val'] = func_val_min
    config['ground_truth_opt_sol'] = opt_sol
    # Constant confidence multiplier: the largest norm reported for the
    # sampled black-box functions.
    config['beta_func'] = lambda t: max(black_box_funcs_norm_list)

    # specify the keys for CONFIG purpose
    config['parameter_set'] = x_grid
    config['eval_simu'] = False

    config['obj'] = get_obj
    config['bounds'] = [x_range] * x_dim
    config['discretize_num_list'] = [discrete_num_per_dim] * x_dim
    config['init_points'] = np.expand_dims(x_init, axis=0)
    config['init_safe_points'] = x_init
    config['train_X'] = x_init
    config['noise_level'] = 1e-4
    config['kernel'] = [obj_kernel]
    config['obj_kernel'] = obj_kernel
    config['total_eval_num'] = 30
    return config


def get_config(problem_name):
    """
    Input: problem_name
    Output: configuration dict of the grey-box problem (for '1d3l': variable
    dimension, grids, kernels, objective and GP sampling/consistency
    closures -- see get_one_d_three_layer_config).

    Raises:
        ValueError: if problem_name is not a supported problem.
    """
    if problem_name == '1d3l':
        return get_one_d_three_layer_config()
    # Fail loudly instead of returning None, which would only surface later
    # as a confusing error where the config is first indexed.
    raise ValueError(
        f"Unknown problem_name {problem_name!r}; supported: '1d3l'."
    )


if __name__ == '__main__':
    # Running this module directly does nothing; build a configuration by
    # calling get_config(problem_name) from another module.
    pass
