"""
Implement violation-aware Bayesian optimizer.
"""
from scipy.stats import norm
import numpy as np
import GPy

def transform_to_2d(an_array):
    """Return ``an_array`` as an array with at least two dimensions.

    Args:
        an_array: A NumPy array or anything ``np.asanyarray`` can convert
            (list, tuple, Python or NumPy scalar).

    Returns:
        * 0-d input (scalar)  -> array of shape ``(1, 1)``
        * 1-d input, length n -> array of shape ``(1, n)`` (a single row)
        * input with two or more dimensions -> returned unchanged; an
          ndarray is passed through without copying.
    """
    # asanyarray accepts tuples and bare scalars as well as lists, and
    # hands back an existing ndarray (or subclass) without copying it.
    an_array = np.asanyarray(an_array)
    if an_array.ndim == 0:
        return np.atleast_2d(an_array)
    if an_array.ndim == 1:
        # A flat vector is treated as ONE row, not as a column.
        return np.expand_dims(an_array, axis=0)
    return an_array

class Base_BO():
    """Common state and stepping logic for grid-based Bayesian optimizers.

    The constructor builds one GP regression model for the objective and
    one per black-box function. Subclasses implement ``optimize`` to pick
    the next grid point; ``make_step`` evaluates the objective there and
    feeds the result back into the objective GP.

    Keys read from ``config``:
        x_grid: 2-d array of candidate points, shape ``(num_x, dim_x)``.
        black_box_funcs_list: only its length is used here.
        black_box_funcs_noise_vars (optional): per-function noise variance;
            defaults to ``0.02 ** 2`` for every function.
        x_init, y_init_list, obj_init: initial inputs, per-function
            observations and objective observations.
        obj_kernel, kernel_list: kernels for the objective GP and the
            black-box GPs.
        noise_level: noise variance of the objective GP.
        z_prior_means (optional): per-function GP mean functions.
        obj: callable evaluated at the chosen point in ``make_step``.

    Attributes:
        t: number of completed ``make_step`` calls.
        f_min: smallest objective value observed so far.
        obj_gp: GP model of the objective.
        black_box_func_gp_list: GP models of the black-box functions.
    """

    def __init__(self, config):
        """Build the GP surrogates from the initial data in ``config``."""
        self.t = 0
        self.config = config
        self.x_grid = config['x_grid']
        self.num_x, self.dim_x = self.x_grid.shape

        self.num_black_box_funcs = len(config['black_box_funcs_list'])
        if 'black_box_funcs_noise_vars' in config:
            self.black_box_funcs_noise_vars = \
                config['black_box_funcs_noise_vars']
        else:
            self.black_box_funcs_noise_vars = [0.02 ** 2] \
                * self.num_black_box_funcs

        x_init = config['x_init']
        y_init_list = config['y_init_list']

        # The objective observations are stored as a column: a scalar
        # becomes (1, 1) and a flat vector of n values becomes (n, 1).
        obj_init = np.array(config['obj_init'])
        if obj_init.ndim == 0:
            obj_init = np.atleast_2d(obj_init)
        elif obj_init.ndim == 1:
            obj_init = np.expand_dims(obj_init, axis=1)

        self.obj_gp = GPy.models.GPRegression(
            x_init, obj_init,
            config['obj_kernel'],
            noise_var=config['noise_level'],
        )
        # Incumbent = best (smallest) initial observation, matching how
        # make_step() refreshes it; taking only the first entry would be
        # wrong whenever several initial points are supplied.
        self.f_min = np.min(obj_init)

        prior_means = None
        if 'z_prior_means' in config:
            prior_means = config['z_prior_means']

        black_box_func_gp_list = []
        for k in range(self.num_black_box_funcs):
            gp_kwargs = {'noise_var': self.black_box_funcs_noise_vars[k]}
            # Only pass a mean function when the config provides one.
            if prior_means is not None:
                gp_kwargs['mean_function'] = prior_means[k]
            # NOTE(review): transform_to_2d turns a flat y_init_list[k]
            # into a single row (1, n); presumably one initial point is
            # expected here -- confirm against callers.
            black_box_func_gp_list.append(
                GPy.models.GPRegression(
                    x_init, transform_to_2d(y_init_list[k]),
                    config['kernel_list'][k],
                    **gp_kwargs
                )
            )
        self.black_box_func_gp_list = black_box_func_gp_list

    def optimize(self):
        """Return the next point to evaluate; implemented by subclasses."""
        raise NotImplementedError

    def make_step(self):
        """Run one optimization step.

        Picks the next point with ``optimize``, evaluates ``config['obj']``
        there, appends the observation to the objective GP's data,
        refreshes ``f_min`` and increments ``t``.

        Returns:
            The objective value exactly as returned by ``config['obj']``.
        """
        x_next = self.optimize()
        # Get measurements of the black-box function
        obj = self.config['obj'](x_next)

        # Add new measurements to the GP models. atleast_2d accepts a
        # scalar, a length-1 vector or an already (1, 1)-shaped result.
        new_X = np.vstack([self.obj_gp.X, x_next])
        new_Y = np.vstack([self.obj_gp.Y, np.atleast_2d(obj)])
        self.obj_gp.set_XY(new_X, new_Y)
        self.f_min = np.min(new_Y)
        self.t += 1
        return obj


class LCB_BO(Base_BO):
    """Optimizer that picks the grid point with the lowest confidence bound.

    ``config['beta_func']`` (optional) maps the step counter ``t`` to the
    width multiplier of the bound; without it a constant 3 is used.
    """

    def __init__(self, config):
        super().__init__(config)
        if 'beta_func' not in config:
            self.beta_func = lambda t: 3
        else:
            self.beta_func = config['beta_func']

    def optimize(self):
        """Return the row of ``x_grid`` minimising ``mean - beta * std``."""
        mean, var = self.obj_gp.predict(self.x_grid)
        std = np.sqrt(var)
        width = self.beta_func(self.t)
        lower_bound = mean - width * std
        # argmin works on the flattened bound array.
        return self.x_grid[np.argmin(lower_bound)]

class EI_BO(Base_BO):
    """Optimizer that picks the grid point with the largest expected
    improvement over the current best objective value ``f_min``."""

    def __init__(self, config):
        super().__init__(config)

    def optimize(self):
        """Return the row of ``x_grid`` maximising expected improvement."""
        mu, var = self.obj_gp.predict(self.x_grid)
        sigma = np.sqrt(var)
        gap = self.f_min - mu
        # Floor the denominator so a zero predictive std cannot divide by 0.
        tiny = 1e-30
        z_score = gap / np.maximum(sigma, tiny)
        expected_improvement = gap * norm.cdf(z_score) \
            + sigma * norm.pdf(z_score)
        # argmax works on the flattened improvement array.
        return self.x_grid[np.argmax(expected_improvement)]

class EIFN_BO(Base_BO):
    """Expected-improvement optimizer using Monte Carlo samples of the
    objective drawn through the black-box function GPs.

    Extra keys read from ``config``:
        forward_sample: callable ``(x_grid, gp_list) -> ghat`` giving one
            sampled objective realisation over the grid.
        sample_phi_pass_through: callable ``(x, gp_list) -> (phi_list,
            gp_list)``; the returned GP list replaces the stored one.

    Args:
        config: optimizer configuration (see ``Base_BO``).
        MC_sample_size: number of forward samples averaged per step.
    """

    def __init__(self, config, MC_sample_size=30):
        super().__init__(config)
        # Incumbent is the best (smallest) initial objective value, so that
        # several initial observations do not leak into the EI arithmetic.
        self.f_min = np.min(config['obj_init'])
        self.MC_sample_size = MC_sample_size

    def optimize(self):
        """Return the grid row with the largest Monte Carlo EI estimate.

        Also calls ``config['sample_phi_pass_through']`` at the chosen
        point and stores the GP list it returns.

        NOTE(review): assumes ``forward_sample`` returns exactly one value
        per grid row (shape ``(num_x,)`` or ``(num_x, 1)``) -- confirm.
        Any other size raises ``ValueError`` from the reshape below.
        """
        MC_sample_size = self.MC_sample_size
        num_rows = self.x_grid.shape[0]
        f_min = self.f_min

        # One float accumulator per grid ROW. Shaping it like x_grid would
        # inherit dim_x columns and the grid dtype, so for dim_x > 1 the
        # flattened argmax would no longer be a valid row index.
        EI_estim = np.zeros(num_rows)
        for _ in range(MC_sample_size):
            ghat = self.config['forward_sample'](
                self.x_grid, self.black_box_func_gp_list)
            improvement = np.maximum(f_min - ghat, 0)
            EI_estim += np.reshape(improvement, num_rows)

        EI_estim = EI_estim / float(MC_sample_size)

        best_x_id = np.argmax(EI_estim)
        best_x = self.x_grid[best_x_id]

        # Only the returned GP list is kept; the phi values are not used.
        _, gp_list = self.config['sample_phi_pass_through'](
            best_x, self.black_box_func_gp_list)
        self.black_box_func_gp_list = gp_list
        return best_x
