"""
Implement violation-aware Bayesian optimizer.
"""
import numpy as np
import GPy


def transform_to_2d(an_array):
    """Return *an_array* with at least two dimensions.

    Objects without an ``ndim`` attribute (lists, tuples, Python scalars)
    are first converted with ``np.array``. A 0-d input becomes shape
    ``(1, 1)``. A 1-d input of length ``n`` becomes a single *row* of shape
    ``(1, n)``. Inputs that are already 2-d or higher are returned unchanged
    (the very same object).

    Args:
        an_array: array-like value (``np.ndarray``, list, tuple or scalar).

    Returns:
        A 2-d (or higher) array.
    """
    # Previously only exact ``list`` instances were converted, so tuples,
    # list subclasses and plain floats/ints crashed on ``.ndim``.
    if not hasattr(an_array, 'ndim'):
        an_array = np.array(an_array)
    if an_array.ndim == 0:
        return np.atleast_2d(an_array)
    if an_array.ndim == 1:
        # A 1-d vector is turned into a row, not a column.
        return np.expand_dims(an_array, axis=0)
    return an_array


class UN_GREY_BOX_BO():
    """Grid-based Bayesian-optimisation loop with one GP per black-box function.

    The class keeps a list of ``GPy`` regression models, one per entry of
    ``config['black_box_funcs_list']``. Each step picks a point from
    ``config['xz_grid']`` using a user-supplied ``check_consistency`` scoring
    function and hands that point to user-supplied ``obj`` and
    ``sample_phi`` callables.

    Required ``config`` keys read here: ``black_box_funcs_list``, ``x_grid``,
    ``xz_grid``, ``x_init``, ``y_init_list``, ``kernel_list`` and, during
    stepping, ``check_consistency``, ``obj`` and ``sample_phi``.
    Optional keys: ``beta_func`` (defaults to a constant 3),
    ``black_box_funcs_noise_vars`` (defaults to ``0.02 ** 2`` per function)
    and ``z_prior_means`` (per-GP ``mean_function``).
    """

    def __init__(self, config):
        """Store the configuration and build the initial GP models.

        Args:
            config: mapping of settings; see the class docstring for keys.
        """
        self.beta_func = config.get('beta_func', lambda t: 3)

        self.t = 0  # iteration counter, passed to ``beta_func``
        self.config = config
        self.black_box_funcs_list = config['black_box_funcs_list']
        self.num_black_box_funcs = len(self.black_box_funcs_list)

        self.x_grid = config['x_grid']
        self.xz_grid = config['xz_grid']
        self.num_x, self.dim_x = self.x_grid.shape
        self.num_xz, self.dim_xz = self.xz_grid.shape

        if 'black_box_funcs_noise_vars' in config:
            self.black_box_funcs_noise_vars = \
                config['black_box_funcs_noise_vars']
        else:
            self.black_box_funcs_noise_vars = \
                [0.02 ** 2] * self.num_black_box_funcs

        x_init = config['x_init']
        y_init_list = config['y_init_list']
        # ``mean_function`` is only forwarded when the key is present, so GPy
        # keeps its own default otherwise.
        has_prior_means = 'z_prior_means' in config

        gp_list = []
        for k in range(self.num_black_box_funcs):
            gp_kwargs = {'noise_var': self.black_box_funcs_noise_vars[k]}
            if has_prior_means:
                gp_kwargs['mean_function'] = config['z_prior_means'][k]
            gp_list.append(
                GPy.models.GPRegression(
                    x_init, transform_to_2d(y_init_list[k]),
                    config['kernel_list'][k],
                    **gp_kwargs
                )
            )
        self.black_box_func_gp_list = gp_list

    def optimize(self):
        """Pick the next grid point from ``self.xz_grid``.

        Every row of ``xz_grid`` is split into its first ``dim_x`` columns
        and the remaining columns. Both parts go to
        ``config['check_consistency']`` together with the GP list and the
        current ``beta``. The row with the smallest returned score is
        chosen. When all scores are (numerically) equal, or the score range
        is NaN, the scoring is treated as having failed and a uniformly
        random row is drawn instead.

        Returns:
            One row of ``self.xz_grid``.
        """
        xz = self.xz_grid
        scores = np.squeeze(self.config['check_consistency'](
            xz[:, :self.dim_x], xz[:, self.dim_x:],
            self.black_box_func_gp_list,
            self.beta_func(self.t),
        ))

        if np.max(scores) - np.min(scores) > 1e-20:
            return xz[np.argmin(scores)]

        # Random draw when the auxiliary problem fails (flat scores).
        print('Random draw!')
        return xz[np.random.choice(self.num_xz)]

    def make_step(self):
        """Run one iteration: choose a point, evaluate it and update the GPs.

        Returns:
            Whatever ``config['obj']`` returns for the chosen x-part.
        """
        xz_next = self.optimize()
        x_next = xz_next[:self.dim_x]
        obj = self.config['obj'](x_next)

        # ``sample_phi`` returns a pair; only the second element (the
        # updated GP list) is kept. Presumably it adds the new measurements
        # at ``xz_next`` to the models — confirm against its definition.
        _, updated_gp_list = self.config['sample_phi'](
            xz_next, self.black_box_func_gp_list)
        self.black_box_func_gp_list = updated_gp_list

        self.t += 1
        return obj
