"""
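Bayesian optimization (BO) agent that models each black-box function
(objective and constraints) with a GPy Gaussian-process regression model.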
"""
from util import transform_to_2d
from scipy.stats import norm
import numpy as np
import GPy


class BO_AGENT():
    """Bayesian-optimization agent over a finite grid of candidate points.

    One GPy regression model is kept per black-box function: index 0 is the
    objective, indices 1.. are constraints. Each ``local_*_primal_update``
    method selects the next grid point with a different acquisition rule and,
    optionally, evaluates the black boxes there and adds the observation to
    the GP models.
    """

    def __init__(self, bo_agent_config):
        """Build the agent from a configuration dict.

        Required keys: ``black_box_funcs_list`` (objective first, then
        constraints), ``x_grid`` (2-D array, one candidate point per row),
        ``x_init`` and ``y_init_list`` (initial data, one ``y`` per function),
        ``kernel_list`` (one kernel per function), ``eta``, ``local_A`` and
        ``acq_type``.

        Optional keys: ``beta_func`` (callable ``t -> beta``, default constant
        3), ``black_box_funcs_noise_vars`` (default ``0.02 ** 2`` for every
        function) and ``prior_means`` (one mean function per function).
        """
        self.beta_func = bo_agent_config.get('beta_func', lambda t: 3)

        # NOTE(review): self.t is read by local_pd_primal_update but never
        # advanced inside this class; presumably the caller increments it.
        self.t = 0
        self.bo_agent_config = bo_agent_config
        # black-box functions list, with the first element as the objective
        self.black_box_funcs_list = bo_agent_config['black_box_funcs_list']
        self.num_black_box_funcs = len(self.black_box_funcs_list)
        self.num_black_box_constrs = self.num_black_box_funcs - 1
        self.x_grid = bo_agent_config['x_grid']
        self.num_x, self.dim_x = self.x_grid.shape

        self.black_box_funcs_noise_vars = bo_agent_config.get(
            'black_box_funcs_noise_vars',
            [0.02 ** 2] * self.num_black_box_funcs,
        )

        x_init = bo_agent_config['x_init']
        y_init_list = bo_agent_config['y_init_list']
        kernel_list = bo_agent_config['kernel_list']
        prior_means = bo_agent_config.get('prior_means')

        # One GP per black-box function, objective first. GPRegression's
        # mean_function defaults to None, so passing None when no prior means
        # are configured is the same as omitting the argument.
        self.black_box_func_gp_list = [
            GPy.models.GPRegression(
                x_init,
                transform_to_2d(y_init_list[k]),
                kernel_list[k],
                noise_var=self.black_box_funcs_noise_vars[k],
                mean_function=None if prior_means is None else prior_means[k],
            )
            for k in range(self.num_black_box_funcs)
        ]

        self.eta = bo_agent_config['eta']
        self.local_A = bo_agent_config['local_A']
        self.acq_type = bo_agent_config['acq_type']
        if self.acq_type == 'penalty':
            # state used by local_penalty_primal_update
            self.mu_dual = 0
            self.rho = 5
            self.old_x = 0  # not read inside this class

    # ------------------------------------------------------------------ #
    # private helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _predict_flat(gp, x):
        """Return the posterior mean and variance of ``gp`` at ``x`` as 1-D arrays."""
        mean, var = gp.predict(x)
        # ravel (not squeeze) keeps the point axis when x has a single row
        return np.ravel(mean), np.ravel(var)

    def _predict_constraints(self):
        """Return constraint posterior means and variances on ``x_grid``.

        Both arrays have shape ``(num_grid_points, num_black_box_constrs)``;
        with no constraints they have zero columns.
        """
        means, variances = [], []
        for i in range(self.num_black_box_constrs):
            mean, var = self._predict_flat(
                self.black_box_func_gp_list[i + 1], self.x_grid)
            means.append(mean)
            variances.append(var)
        if not means:
            empty = np.empty((self.x_grid.shape[0], 0))
            return empty, empty.copy()
        return np.array(means).T, np.array(variances).T

    def _expected_improvement(self, obj_mean, obj_var):
        """Expected improvement for minimisation, relative to min of observed Y."""
        f_min = np.min(self.black_box_func_gp_list[0].Y)
        obj_std = np.sqrt(obj_var)
        improvement = f_min - obj_mean
        z = improvement / np.maximum(obj_std, 1e-20)
        return improvement * norm.cdf(z) + obj_std * norm.pdf(z)

    def _evaluate_and_update(self, x):
        """Evaluate every black box at ``x`` and add the results to the GPs."""
        eval_result = self.local_evaluate(x)
        self.update_local_gp(x, eval_result)

    # ------------------------------------------------------------------ #
    # acquisition / primal updates
    # ------------------------------------------------------------------ #
    def local_penalty_primal_update(self, old_x, eval_and_udt=True):
        """Choose the grid point maximising EI minus a penalty around ``old_x``.

        The auxiliary objective at candidate ``x`` is
        ``EI(x) - mu_dual . (x - old_x) - rho / 2 * ||x - old_x||^2``.
        After choosing ``opt_x``, ``self.mu_dual`` is increased by
        ``rho * (opt_x - old_x)``. Requires ``self.mu_dual`` and ``self.rho``,
        which ``__init__`` sets only when ``acq_type == 'penalty'``.

        Args:
            old_x: point the penalty is centred on (scalar, or array
                broadcastable to one row of ``x_grid``).
            eval_and_udt: if True, evaluate the black boxes at the chosen
                point and add the observation to the GP models.
        """
        obj_mean, obj_var = self._predict_flat(
            self.black_box_func_gp_list[0], self.x_grid)
        EI = self._expected_improvement(obj_mean, obj_var)

        # (num_x, dim_x) displacement of each candidate; the penalty terms are
        # reduced over the input dimension so each grid point gets one score
        # (the previous element-wise form broke for dim_x > 1).
        diff = self.x_grid - np.ravel(old_x)
        mu = np.broadcast_to(np.ravel(self.mu_dual), (diff.shape[1],))
        aux_obj = EI - diff @ mu - 0.5 * self.rho * np.sum(diff ** 2, axis=1)
        # optimize the local aux_obj
        aux_opt_id = np.argmax(aux_obj)
        opt_x = self.x_grid[aux_opt_id]

        # dual ascent step on the residual opt_x - old_x
        self.mu_dual += self.rho * (opt_x - old_x)
        if eval_and_udt:
            self._evaluate_and_update(opt_x)

    def local_cei_primal_update(self, constraints, eval_and_udt=True):
        """Choose the grid point maximising constrained EI (EI * Pr[feasible]).

        For constraint ``i`` the feasibility probability at ``x`` is
        ``Pr[g_i(x) + constraints[i] - y_i_last <= 0]`` under the GP posterior
        of ``g_i``, where ``y_i_last`` is the most recent observation stored
        in that constraint's GP. Per-constraint probabilities are multiplied.

        Args:
            constraints: per-constraint offsets, one per black-box constraint
                (a scalar is broadcast to all constraints).
            eval_and_udt: if True, evaluate the black boxes at the chosen
                point and add the observation to the GP models.
        """
        obj_mean, obj_var = self._predict_flat(
            self.black_box_func_gp_list[0], self.x_grid)
        constr_mean, constr_var = self._predict_constraints()

        # most recent observation of each constraint, shape (n_constr,)
        last_constr = np.array(
            [self.black_box_func_gp_list[i + 1].Y[-1, 0]
             for i in range(self.num_black_box_constrs)],
            dtype=float,
        )
        shifted_mean = (constr_mean
                        + np.ravel(np.asarray(constraints, dtype=float))
                        - last_constr)
        # norm.cdf's third argument is the *standard deviation* (scale), so
        # the predictive variance must be square-rooted.
        constr_std = np.maximum(np.sqrt(constr_var), 1e-20)
        # calculate Pr(g_i(x) <= 0)
        prob_negative = norm.cdf(0, shifted_mean, constr_std)
        # product over constraints; with no constraints the empty product is 1
        prob_feasible = np.prod(prob_negative, axis=1)

        EIc = prob_feasible * self._expected_improvement(obj_mean, obj_var)
        # optimize the local aux_obj
        aux_opt_id = np.argmax(EIc)
        opt_x = self.x_grid[aux_opt_id]

        if eval_and_udt:
            self._evaluate_and_update(opt_x)

    def local_pd_primal_update(self, inq_dual, eq_dual, eval_and_udt=True):
        """Primal step using lower confidence bounds (LCB) of the GPs.

        Minimises over the grid
        ``LCB_f(x) + eta * inq_dual . LCB_g(x) + eta * (eq_dual @ local_A) . x``
        with ``LCB = mean - beta * std`` and ``beta = beta_func(self.t)``.

        Args:
            inq_dual: one multiplier per black-box constraint.
            eq_dual: multipliers applied through ``self.local_A``.
            eval_and_udt: if True, evaluate the black boxes at the chosen
                point and add the observation to the GP models.

        Returns:
            ``(local_A @ opt_x, lowerg)``, where ``lowerg`` is the constraint
            LCB vector at ``opt_x`` (``0`` when there are no constraints).
        """
        obj_mean, obj_var = self._predict_flat(
            self.black_box_func_gp_list[0], self.x_grid)

        beta = self.beta_func(self.t)
        obj_lcb = obj_mean - beta * np.sqrt(obj_var)

        eta = self.eta
        # asarray guards against list inputs (e.g. int * list repeats a list)
        dualA = np.atleast_1d(np.asarray(eq_dual) @ self.local_A)

        if self.num_black_box_constrs > 0:
            constr_mean, constr_var = self._predict_constraints()
            constrain_lcb_arr = constr_mean - beta * np.sqrt(constr_var)
            aux_obj = obj_lcb + eta * np.asarray(inq_dual) @ constrain_lcb_arr.T + \
                eta * dualA @ self.x_grid.T
        else:
            aux_obj = obj_lcb + eta * dualA @ self.x_grid.T

        # optimize the local aux_obj
        aux_opt_id = np.argmin(aux_obj)
        opt_x = self.x_grid[aux_opt_id]

        if eval_and_udt:
            self._evaluate_and_update(opt_x)

        local_Ax = self.local_A @ opt_x
        if self.num_black_box_constrs > 0:
            lowerg = constrain_lcb_arr[aux_opt_id]
        else:
            lowerg = 0
        return local_Ax, lowerg

    # ------------------------------------------------------------------ #
    # evaluation and GP bookkeeping
    # ------------------------------------------------------------------ #
    def local_evaluate(self, x):
        """Return ``[f(x) for f in black_box_funcs_list]`` (objective first)."""
        return [self.black_box_funcs_list[k](x)
                for k in range(self.num_black_box_funcs)]

    def update_local_gp(self, x, eval_result):
        """Append the observation ``(x, eval_result[k])`` to the k-th GP."""
        for k in range(self.num_black_box_funcs):
            self.black_box_func_gp_list[k] = self.update_single_gp(
                self.black_box_func_gp_list[k], x, eval_result[k])

    def update_single_gp(self, gp, x, y):
        """Append one row ``x`` / value ``y`` to ``gp``'s data via ``set_XY``.

        Returns the same (mutated) ``gp`` object.
        """
        new_X = np.vstack([gp.X, x])
        new_Y = np.vstack([gp.Y, y])
        gp.set_XY(new_X, new_Y)
        return gp
