import os
import sys
import time
from abc import ABCMeta, abstractmethod
from typing import Callable, Optional

import numpy as np
import matplotlib
# matplotlib.use("Agg")
import matplotlib.pyplot as plt


from scipy import optimize
from scipy.stats import norm, qmc
import scipy.linalg

from . import myutils

class BO_core(object):
    """Base class for GP-based Bayesian optimisation strategies.

    Wraps a GP regressor and box bounds and provides:
      * multi-start L-BFGS-B minimisation over the box (``minimize_continuous_space``),
      * maximisation of the acquisition function, LCB and posterior mean over a candidate
        pool or the continuous box,
      * posterior sample paths built from random Fourier features plus a data-dependent
        correction (https://proceedings.mlr.press/v119/wilson20a.html) and their maximisation,
      * a few debugging helpers (``_check_*``).

    Subclasses implement ``acq`` and ``_acq_params_val_and_grad`` (acquisition value and its
    gradient w.r.t. the posterior mean and variance) and optionally ``pre_computation_acq``.
    """

    # NOTE: ``__metaclass__`` is a Python 2 idiom. Under Python 3 this assignment does not set the
    # metaclass, so the @abstractmethod decorators below are NOT enforced at instantiation time.
    __metaclass__ = ABCMeta

    def __init__(self, gp_regressor, bounds, rng=None):
        """
        Args:
            gp_regressor: GP model exposing ``X``, ``Y``, ``normalizer``, ``predict``,
                ``predict_noiseless`` and ``predictive_gradients`` (GPy-like interface).
            bounds: array whose row 0 holds the lower and row 1 the upper bound of each of
                the d input dimensions.
            rng: optional ``numpy.random.Generator``; a fresh default generator if None.
        """
        self.GPmodel = gp_regressor
        self._set_normalization_stats()

        self.y_max = np.max(gp_regressor.Y)
        self.unique_X = np.unique(gp_regressor.X, axis=0)
        self.input_dim = np.shape(gp_regressor.X)[1]
        self.bounds = bounds
        # scipy's L-BFGS-B expects a sequence of (low, high) pairs, one per dimension
        self.bounds_list = bounds.T.tolist()
        self.sampling_num = 1
        self.max_inputs = None

        if rng is None:
            self.rng = np.random.default_rng()
        else:
            self.rng = rng

        self.sampler = qmc.Sobol(d=self.input_dim, scramble=True, seed=rng)

    def _set_normalization_stats(self):
        """Cache the GP output normaliser's mean/std (0 and 1 when there is no normaliser)."""
        if self.GPmodel.normalizer is None:
            self.mean = 0.
            self.std = 1.
        else:
            self.mean = self.GPmodel.normalizer.mean
            self.std = self.GPmodel.normalizer.std

    def update(self, X, Y, optimize=False):
        """Append new observations to the GP and refresh the cached statistics.

        Args:
            X: new input(s); a single point may be given as a 1-D array of length d.
            Y: corresponding observed value(s).
            optimize: if True, re-fit the GP hyperparameters with 10 restarts.
        """
        # Use the same 2-D form for both set_XY and unique_X so a single 1-D point works.
        X = np.atleast_2d(X)
        self.GPmodel.set_XY(np.r_[self.GPmodel.X, X], np.r_[self.GPmodel.Y, np.atleast_2d(Y)])
        if optimize:
            self.GPmodel.optimize_restarts(num_restarts=10)

        self.y_max = np.max([self.y_max, np.max(Y)])
        self.unique_X = np.unique(np.r_[self.unique_X, X], axis=0)

        self._set_normalization_stats()

    def _make_start_points(self):
        """Return start points for multi-start gradient-based optimisation.

        Combines ``2 ** min(d + 1, 7)`` scrambled Sobol points scaled into the bounds, the
        (up to) 50 unique training inputs with the largest posterior mean, and, when set,
        ``self.max_inputs`` (sampled maximisers). Duplicate rows are removed.
        Used for the recommended point, acquisition maximisation and sample-path maximisation.
        """
        num_starts = 2 ** np.min([self.input_dim+1, 7])
        x0s = self.sampler.random(n=num_starts) * (self.bounds[1] - self.bounds[0]) + self.bounds[0]

        # add unique training points whose posterior mean is large up to NUM_TOP
        NUM_TOP = np.min([50, np.shape(self.unique_X)[0]])
        mean_train, _ = self.GPmodel.predict(self.unique_X)
        top_idx = np.argpartition(mean_train.ravel(), -NUM_TOP)[-NUM_TOP:]
        x0s = np.unique(np.r_[x0s, self.unique_X[top_idx]], axis=0)

        if self.max_inputs is not None:
            # add sampled maximum inputs
            x0s = np.unique(np.r_[x0s, self.max_inputs], axis=0)

        return x0s

    def minimize_continuous_space(self, objective_function, bounds=None, jac=None, tol=1e-5):
        """Minimise ``objective_function`` with L-BFGS-B from every start point.

        Start points outside ``self.bounds`` are skipped with a warning.

        Returns:
            (x_min, f_min): best point as a (1, d) array and its objective value;
            (zeros, inf) if no start point could be used.
        """
        x0s = self._make_start_points()

        x_min, f_min = np.zeros((1, self.input_dim)), np.inf
        for x0 in x0s:
            if np.any(x0 < self.bounds[0]) or np.any(x0 > self.bounds[1]):
                print("Warning: Initial point {} is out of bounds {}".format(x0, self.bounds))
                continue

            res = optimize.minimize(objective_function, x0, method='L-BFGS-B', bounds=bounds, jac=jac, tol=tol)
            x_min_temp, f_min_temp = res.x, res.fun

            if f_min > f_min_temp:
                x_min, f_min = x_min_temp, f_min_temp
        return np.atleast_2d(x_min), f_min

    @abstractmethod
    def pre_computation_acq(self, X_candidates=None):
        pass

    @abstractmethod
    def acq(self, x):
        return 0

    def _acq_params_val_and_grad(self, params): # params = [mu, var]
        raise NotImplementedError("The function _acq_params_val_and_grad is not implemented yet.")

    def next_input_pool(self, X):
        """Pick the pool row maximising ``acq``; return it as (1, d) and the pool without it."""
        acquisition_values = self.acq(X)
        max_idx = np.argmax(acquisition_values)
        next_input = np.atleast_2d(X[max_idx])

        # delete max_idx from X
        X = X[np.arange(np.shape(X)[0]) != max_idx, :]
        return next_input, X

    def _acq_val_and_grad(self, x):
        """Negated acquisition value and gradient at x, via the chain rule through (mean, var)."""
        x = np.atleast_2d(x)
        mean, var = self.GPmodel.predict_noiseless(x)
        mean_grad, var_grad = self.GPmodel.predictive_gradients(x)
        acq_val, acq_params_grad = self._acq_params_val_and_grad(np.r_[mean.ravel(), var.ravel()])
        acq_grad = acq_params_grad[0] * mean_grad.ravel() + acq_params_grad[1] * var_grad.ravel()
        return (- acq_val, -acq_grad)

    def next_input(self):
        """Maximise the acquisition function over the continuous box and return the maximiser."""
        x_min, f_min = self.minimize_continuous_space(self._acq_val_and_grad, self.bounds_list, jac=True)
        print('optimized acquisition function value:', -1*f_min)
        return x_min

    def LCB_maximizer(self, width_param=1., pool_X=None):
        """Return the maximiser of mean - width_param * std over ``pool_X`` or the box."""
        if pool_X is not None:
            mean, var = self.GPmodel.predict_noiseless(pool_X)
            LCB = mean.ravel() - width_param * np.sqrt(var).ravel()
            max_idx = np.argmax(LCB)
            return np.atleast_2d(pool_X[max_idx])
        else:
            def LCB_val_and_grad(x):
                x = np.atleast_2d(x)
                mean, var = self.GPmodel.predict_noiseless(x)
                std = np.sqrt(var).ravel()
                mean_grad, var_grad = self.GPmodel.predictive_gradients(x)
                LCB_val = mean.ravel() - width_param * std
                # d std / dx = (d var / dx) / (2 std)
                LCB_grad = mean_grad.ravel() - width_param * 0.5 * (var_grad.ravel() / std)
                return (- LCB_val, -LCB_grad)

            x_min, f_min = self.minimize_continuous_space(LCB_val_and_grad, self.bounds_list, jac=True)
            print('Maximized LCB value:', -1*f_min)
            return x_min


    def posterior_maximum(self, pool_X=None):
        """Return (maximiser as (1, d), maximum) of the noiseless posterior mean."""

        if pool_X is not None:
            mean, _ = self.GPmodel.predict_noiseless(pool_X)
            max_idx = np.argmax(mean.ravel())
            return np.atleast_2d(pool_X[max_idx]), mean[max_idx].ravel()
        else:
            def posterior_mean_val_and_grad(x):
                x = np.atleast_2d(x)
                mean, _ = self.GPmodel.predict_noiseless(x)
                mean_grad, _ = self.GPmodel.predictive_gradients(x)
                return (- mean, -mean_grad.ravel())

            x_min, f_min = self.minimize_continuous_space(posterior_mean_val_and_grad, self.bounds_list, jac=True)
            print('Maximized posterior mean value:', -1*f_min)
            return x_min, - f_min

    def generate_sample_path(self, basis_dim=2000):
        """Draw ``self.sampling_num`` posterior sample paths (pathwise conditioning).

        Each path is a random-Fourier-feature prior draw ``phi(x) @ w`` plus the correction
        ``k(x, X) @ K_inv_features`` with ``K_inv_features`` = (GP Woodbury inverse) applied to
        ``Y_normalized - Phi(X) @ w - eps`` and eps drawn with the Gaussian noise variance.
        Paths live in the normalised output scale.
        """
        self.rff_features = myutils.RFF(kernel=self.GPmodel.kern, rng=self.rng, basis_dim=basis_dim)

        self.prior_weights_sample = self.rng.normal(0, 1, size=(basis_dim, self.sampling_num))
        transformed_train_X = self.rff_features.transform(self.GPmodel.X)

        b = (
            self.GPmodel.Y_normalized
            - transformed_train_X @ self.prior_weights_sample
            - self.rng.normal(0, np.sqrt(self.GPmodel['.*Gaussian_noise.variance'].values), size=(self.GPmodel.Y.shape[0], self.sampling_num))
            )
        if self.GPmodel.posterior._woodbury_inv is None:
            # two triangular solves with the Cholesky factor L: L^{-T} (L^{-1} b)
            self.K_inv_features = scipy.linalg.solve_triangular(
                self.GPmodel.posterior._woodbury_chol.T,
                scipy.linalg.solve_triangular(self.GPmodel.posterior._woodbury_chol, b, lower=True, overwrite_b=True),
                overwrite_b=True
                )
        else:
            self.K_inv_features = self.GPmodel.posterior._woodbury_inv @ b # N \times sampling_num


    def get_sample_path_maximizer(self, pool_X=None):
        """Maximise each sample path over ``pool_X`` or the continuous box.

        Returns:
            (max_sample, max_inputs): per-path maxima of shape (sampling_num,) in the original
            output scale, and the stacked maximisers of shape (sampling_num, d).
        """
        max_sample = np.zeros(self.sampling_num)
        max_inputs = list()

        if pool_X is None:
            for j in range(self.sampling_num):
                # See https://proceedings.mlr.press/v119/wilson20a.html
                def normalized_sample_path_val_and_grad(x):
                    x = np.atleast_2d(x)
                    X_features = self.rff_features.transform(x)
                    X_features_grad = self.rff_features.transform_grad(x)
                    sample_path_val = X_features.dot(np.c_[self.prior_weights_sample[:, j]]) + self.GPmodel.kern.K(x, self.GPmodel.X) @ self.K_inv_features[:, j]
                    sample_path_grad = X_features_grad.dot(np.c_[self.prior_weights_sample[:, j]]).ravel() + self.dK_dx(x, self.GPmodel.X).T @ self.K_inv_features[:, j]
                    return (- sample_path_val.ravel(), - sample_path_grad.ravel())

                x_min, f_min = self.minimize_continuous_space(normalized_sample_path_val_and_grad, self.bounds_list, jac=True)
                max_sample[j] = - f_min * self.std + self.mean
                max_inputs.append(x_min)
        else:
            transformed_pool_X = self.rff_features.transform(pool_X)
            sample_path = transformed_pool_X @ self.prior_weights_sample + self.GPmodel.kern.K(pool_X, self.GPmodel.X) @ self.K_inv_features
            sample_path = sample_path * self.std + self.mean
            max_idx = np.argmax(sample_path, axis=0)
            max_sample = sample_path[max_idx, np.arange(self.sampling_num)].ravel()
            max_inputs = pool_X[max_idx, :].tolist()

        return max_sample, np.vstack(max_inputs)

    # See https://proceedings.mlr.press/v119/wilson20a.html
    def sample_path_val(self, x):
        """Evaluate all sample paths at x in the original output scale; shape (n, sampling_num)."""
        x = np.atleast_2d(x)
        X_features = self.rff_features.transform(x)
        sample_path_val = X_features.dot(self.prior_weights_sample) + self.GPmodel.kern.K(x, self.GPmodel.X) @ self.K_inv_features
        return sample_path_val * self.std + self.mean

    ###########################################################
    # Testing modules
    ###########################################################

    def _check_grad(self, func, x):
        """Compare ``func``'s analytic gradient with forward differences (step 1e-10).

        Returns:
            (absolute_error, |analytic gradient|), each of length d.
        """
        x = np.atleast_2d(x)

        # Implemented gradients
        func_val, func_grad = func(x)

        # Numerical approximation of gradients
        func_grad_approx = list()
        for i in range(self.input_dim):
            basis_vec = np.zeros((1, self.input_dim))
            basis_vec[0, i] = 1
            x_temp = x + 1e-10 * basis_vec
            func_val_temp, _ = func(x_temp)
            func_grad_approx.append((func_val_temp - func_val) / 1e-10)

        absolute_error = np.abs(func_grad - np.vstack(func_grad_approx).ravel())
        return absolute_error, np.abs(func_grad)

    def _check_continuous_optimize(self):
        """Debug helper: run the multi-start optimiser and report gradient errors / non-descent."""
        def acq_val_and_grad(x):
            x = np.atleast_2d(x)
            mean, var = self.GPmodel.predict_noiseless(x)
            mean_grad, var_grad = self.GPmodel.predictive_gradients(x)
            acq_val, acq_params_grad = self._acq_params_val_and_grad(np.r_[mean.ravel(), var.ravel()])
            acq_grad = acq_params_grad[0] * mean_grad.ravel() + acq_params_grad[1] * var_grad.ravel()
            return (- acq_val, -acq_grad)

        def LCB_val_and_grad(x):
            x = np.atleast_2d(x)
            mean, var = self.GPmodel.predict_noiseless(x)
            std = np.sqrt(var).ravel()
            mean_grad, var_grad = self.GPmodel.predictive_gradients(x)
            LCB_val = mean.ravel() - 1. * std
            LCB_grad = mean_grad.ravel() - 1. * 0.5 * (var_grad.ravel() / std)
            return (- LCB_val, -LCB_grad)

        def posterior_mean_val_and_grad(x):
            x = np.atleast_2d(x)
            mean, _ = self.GPmodel.predict_noiseless(x)
            mean_grad, _ = self.GPmodel.predictive_gradients(x)
            return (- mean, -mean_grad.ravel())

        np.set_printoptions(precision=4)
        x0s = self._make_start_points()
        # for objective_function in [acq_val_and_grad, LCB_val_and_grad, posterior_mean_val_and_grad]:
        for objective_function in [acq_val_and_grad]:
            print("---------------------------")
            x_min, f_min = np.zeros((1, self.input_dim)), np.inf
            for x0 in x0s:
                if np.any(x0 < self.bounds[0]) or np.any(x0 > self.bounds[1]):
                    print("Warning: Initial point {} is out of bounds {}".format(x0, self.bounds))
                    continue

                res = optimize.minimize(objective_function, x0, method='L-BFGS-B', bounds=self.bounds_list, jac=True)
                x_min_temp, f_min_temp = res.x, res.fun

                abs_error, abs_acq_grad = self._check_grad(objective_function, x0)
                print("numerical gradient error", self.acq(x0), abs_error, abs_acq_grad, x0)

                f0, _ = objective_function(x0)
                # the optimiser should never end noticeably worse than where it started
                if f_min_temp > (f0 + np.abs(f0) * 1e-2):
                    print("x0 and fun(x0):", x0, f0)
                    print("xm and fun(xm):", x_min_temp, f_min_temp)

                if f_min > f_min_temp:
                    x_min, f_min = x_min_temp, f_min_temp
        return 0

    def _check_sample_path(self):
        """Debug helper: compare 10000 sample paths' moments with the GP posterior.

        Saves one histogram per start point under ``test_figures/`` (created if missing) and
        prints mean/variance errors. ``self.sampling_num`` is restored even if an error occurs.
        """
        x0s = self._make_start_points()

        temp_sampling_num = self.sampling_num
        self.sampling_num = 10000
        try:
            self.generate_sample_path()

            mean, var = self.GPmodel.predict_noiseless(x0s)
            sample_path_values = self.sample_path_val(x0s)
            mean_approx, var_approx = np.mean(sample_path_values, axis=1), np.var(sample_path_values, axis=1, ddof=1)

            os.makedirs("test_figures", exist_ok=True)
            for i in range(np.shape(x0s)[0]):
                plot_x = np.linspace(np.min(sample_path_values[i, :]), np.max(sample_path_values[i, :]), 200)
                pdf = norm.pdf(plot_x, loc=mean[i], scale=np.sqrt(var[i]))
                plt.hist(sample_path_values[i, :], bins=100, density=True, label="sample path at {}".format(x0s[i]))
                plt.plot(plot_x, pdf, label="mu={}, var={}".format(mean[i], var[i]))
                plt.legend()
                plt.savefig("test_figures/sample_path_{}.png".format(i))
                plt.close()

            print("mean error", np.abs(mean.ravel() - mean_approx))
            error = np.c_[np.abs(var.ravel() - var_approx)]
            sort_idx = np.argsort(error.ravel())
            print("var error", np.c_[var[sort_idx], var_approx[sort_idx], error[sort_idx]] )
        finally:
            self.sampling_num = temp_sampling_num

    def _check_sample_path_grad(self):
        """Debug helper: check the analytic gradient of the first sample path and its optimisation."""
        np.set_printoptions(precision=4)
        self.generate_sample_path()
        x0s = self._make_start_points()

        def normalized_sample_path_val_and_grad(x):
            x = np.atleast_2d(x)
            X_features = self.rff_features.transform(x)
            X_features_grad = self.rff_features.transform_grad(x)
            sample_path_val = X_features.dot(np.c_[self.prior_weights_sample[:, 0]]) + self.GPmodel.kern.K(x, self.GPmodel.X) @ self.K_inv_features[:, 0]
            sample_path_grad = X_features_grad.dot(np.c_[self.prior_weights_sample[:, 0]]).ravel() + self.dK_dx(x, self.GPmodel.X).T @ self.K_inv_features[:, 0]
            return (- sample_path_val.ravel(), - sample_path_grad.ravel())

        objective_function = normalized_sample_path_val_and_grad

        x_min, f_min = np.zeros((1, self.input_dim)), np.inf
        for x0 in x0s:
            if np.any(x0 < self.bounds[0]) or np.any(x0 > self.bounds[1]):
                print("Warning: Initial point {} is out of bounds {}".format(x0, self.bounds))
                continue

            res = optimize.minimize(objective_function, x0, method='L-BFGS-B', bounds=self.bounds_list, jac=True)
            x_min_temp, f_min_temp = res.x, res.fun

            print("numerical gradient error", self._check_grad(normalized_sample_path_val_and_grad, x0), x0)

            f0, _ = objective_function(x0)
            if f_min_temp > (f0 + np.abs(f0) * 1e-2):
                print("x0 and fun(x0):", x0, f0)
                print("xm and fun(xm):", x_min_temp, f_min_temp)

            if f_min > f_min_temp:
                x_min, f_min = x_min_temp, f_min_temp

    def dK_dx(self, x: np.ndarray, X2: np.ndarray) -> np.ndarray:
        """Gradient of k(x, X2) w.r.t. the single point x for a stationary kernel k(r).

        Uses dK/dx = dK/dr * (x - X2) / (r * lengthscale**2) with r the scaled distance.

        Args:
            x: array of shape (1, d).
            X2: array of shape (n, d).

        Returns:
            array of shape (n, d).

        Raises:
            ValueError: if x is not of shape (1, d).
        """
        if x.shape != (1, self.input_dim):
            raise ValueError("x must be a (1 times d) array, but got shape {} in dK_dx".format(x.shape))

        r = self.GPmodel.kern._scaled_dist(x, X2).T # np.shape(X2)[0] \times 1
        dK_dr = self.GPmodel.kern.dK_dr(r)

        diff_X_X2 = x - X2 # np.shape(X2)[0] \times d
        r_zero_idx = r.ravel() == 0
        # At r = 0 the direction (x - X2) / r is undefined (0/0). Both are replaced by 1 only to
        # avoid the division; the entry is then scaled by dK_dr(0), which is 0 for kernels that are
        # smooth at the origin (e.g. RBF). NOTE(review): not meaningful for kernels with dK_dr(0) != 0.
        diff_X_X2[r_zero_idx, :] = 1.
        r[r_zero_idx, :] = 1. # avoid division by zero.
        dr_dX = diff_X_X2 / (r * self.GPmodel.kern.lengthscale**2) # derivative w.r.t. scaled L2 norm
        return dK_dr * dr_dX # np.shape(X2)[0] \times d


###########################################################
class Random(BO_core):
    """Baseline strategy: choose the next query uniformly at random."""

    def __init__(self, gp_regressor, bounds, rng=None):
        super().__init__(gp_regressor, bounds, rng=rng)

    def next_input_pool(self, X):
        """Draw one pool row at random; return it as (1, d) together with the remaining rows."""
        n_pool = np.shape(X)[0]
        chosen = self.rng.integers(0, n_pool, size=1)
        picked = np.atleast_2d(X[chosen])
        keep_mask = np.arange(n_pool) != chosen
        return picked, X[keep_mask, :]

    def next_input(self):
        """Draw a uniform point inside the box bounds, shape (1, d)."""
        low, high = self.bounds[0], self.bounds[1]
        return self.rng.random((1, self.input_dim)) * (high - low) + low


class ProbabilityImprovement(BO_core):
    """Probability of improvement, scored by its monotone argument (mu - y_max - xi) / sigma."""

    def __init__(self, gp_regressor, bounds, rng=None, xi=0):
        super().__init__(gp_regressor, bounds, rng=rng)
        self.xi = xi

    def acq(self, x):
        mu, variance = self.GPmodel.predict_noiseless(np.atleast_2d(x))
        return ((mu - self.y_max - self.xi) / np.sqrt(variance)).ravel()

    def _acq_params_val_and_grad(self, params):
        """Return z = (mu - y_max - xi) / sigma and [dz/dmu, dz/dvar]."""
        mu, variance = params
        sigma = np.sqrt(variance)

        z = ((mu - self.y_max - self.xi) / sigma).ravel()
        dz_dmu = 1. / sigma
        dz_dvar = - z / (2. * variance)
        return z, np.r_[dz_dmu.ravel(), dz_dvar.ravel()]


class ExpectedImprovement(BO_core):
    """Expected improvement over the best observation ``y_max`` plus margin ``xi``.

    ``acq`` returns the raw EI; the gradient-based optimiser works on log EI and, for
    Z < -8, replaces pdf(Z)/cdf(Z) by an analytic bound to avoid underflow.
    """

    def __init__(self, gp_regressor, bounds, rng=None, xi=0):
        super().__init__(gp_regressor, bounds, rng=rng)
        self.xi = xi

    def acq(self, x):
        mu, variance = self.GPmodel.predict_noiseless(np.atleast_2d(x))
        sigma = np.sqrt(variance)
        z = (mu - self.y_max - self.xi) / sigma
        improvement = (z * sigma) * norm.cdf(z) + sigma * norm.pdf(z)
        return improvement.ravel()

    # Analytic approximation of pdf(z) / cdf(z) for z << 0. See https://arxiv.org/pdf/1212.4899
    def _approx_phi_cdf(self, z):
        return (np.sqrt(z**2 + 8) - 3. * z) / 4.

    def _log_ei_tail(self, z, sigma, variance):
        """log EI and its (mean, var) gradient for z < -8, computed in log space."""
        root = np.sqrt(z**2 + 8)
        log_ei = np.log(sigma) + norm.logcdf(z) + np.log((root + z) / 4.)

        dlog_dz = self._approx_phi_cdf(z) + 1. / root
        grad_mean = dlog_dz / sigma
        grad_var = 1. / (2. * variance) + dlog_dz * (-0.5 * z / variance)
        return log_ei, np.r_[grad_mean.ravel(), grad_var.ravel()]

    def _acq_params_val_and_grad(self, params):
        """Return log EI and its gradient w.r.t. (posterior mean, posterior variance)."""
        mu, variance = params
        sigma = np.sqrt(variance)
        z = (mu - self.y_max - self.xi) / sigma

        if z < -8:
            return self._log_ei_tail(z, sigma, variance)

        density = norm.pdf(z)
        mass = norm.cdf(z)
        ei = (z * sigma) * mass + sigma * density
        grad_mean = mass / ei
        grad_var = density / (2 * sigma) / ei
        return np.log(ei).ravel(), np.r_[grad_mean.ravel(), grad_var.ravel()]

class UncertaintySampling(BO_core):
    """Pure exploration: query where the noiseless posterior variance is largest."""

    def __init__(self, gp_regressor, bounds, rng=None):
        super().__init__(gp_regressor, bounds, rng=rng)

    def acq(self, x):
        """Posterior variance at x (returned with the model's shape, not flattened)."""
        return self.GPmodel.predict_noiseless(np.atleast_2d(x))[1]

    def _acq_params_val_and_grad(self, params):
        """Optimise the posterior std: value sigma, gradient [0, 1 / (2 sigma)]."""
        _, variance = params
        sigma = np.sqrt(variance)
        gradient = np.r_[[0], [1. / (2 * sigma)]]
        return sigma, gradient


class GP_UCB(BO_core):
    """GP-UCB: maximise mu + root_beta * sigma with root_beta = root_beta_func(iteration)."""

    def __init__(self, gp_regressor, bounds, root_beta_func: Callable, rng=None, iteration=1):
        super().__init__(gp_regressor, bounds, rng=rng)
        self.iteration = iteration
        self.root_beta_func = root_beta_func
        self.root_beta = root_beta_func(iteration)
        print(f"root beta is {self.root_beta}")

    def update(self, X, Y, optimize=False):
        """Advance the iteration counter, refresh root_beta, then update the GP."""
        self.iteration += 1
        self.root_beta = self.root_beta_func(self.iteration)
        print(f"root beta in {self.iteration}-th iteration is {self.root_beta}")
        super().update(X, Y, optimize=optimize)

    def acq(self, x):
        mu, variance = self.GPmodel.predict_noiseless(np.atleast_2d(x))
        return (mu + self.root_beta * np.sqrt(variance)).ravel()

    def _acq_params_val_and_grad(self, params):
        """Return UCB value and [dUCB/dmu, dUCB/dvar] = [1, root_beta / (2 sigma)]."""
        mu, variance = params
        sigma = np.sqrt(variance)
        value = mu + self.root_beta * sigma
        gradient = np.r_[[1.], [self.root_beta / (2 * sigma)]]
        return value, gradient


class IRGP_UCB(BO_core):
    """Randomised GP-UCB: root_beta = sqrt(s(iteration) + Exp(scale=2)), redrawn each iteration."""

    def __init__(self, gp_regressor, bounds, s:Callable, rng=None, iteration=1):
        super().__init__(gp_regressor, bounds, rng=rng)
        self.iteration = iteration
        self.s = s
        self._resample_root_beta()

    def _resample_root_beta(self):
        """Draw root_beta for the current iteration from ``self.rng`` and report it."""
        self.root_beta = np.sqrt(self.s(self.iteration) + self.rng.exponential(scale=2))
        print(f"root beta in {self.iteration}-th iteration is {self.root_beta}")

    def update(self, X, Y, optimize=False):
        """Advance the iteration counter, redraw root_beta, then update the GP."""
        self.iteration += 1
        self._resample_root_beta()
        super().update(X, Y, optimize=optimize)

    def acq(self, x):
        mu, variance = self.GPmodel.predict_noiseless(np.atleast_2d(x))
        return (mu + self.root_beta * np.sqrt(variance)).ravel()

    def _acq_params_val_and_grad(self, params):
        """Return UCB value and [1, root_beta / (2 sigma)]."""
        mu, variance = params
        sigma = np.sqrt(variance)
        value = mu + self.root_beta * sigma
        return value, np.r_[[1.], [self.root_beta / (2 * sigma)]]


class ThompsonSampling(BO_core):
    """Thompson sampling: query the maximiser of a posterior sample path."""

    def __init__(self, gp_regressor, bounds, rng=None):
        super().__init__(gp_regressor, bounds, rng=rng)
        self.generate_sample_path()

    def update(self, X, Y, optimize=False):
        super().update(X, Y, optimize=optimize)
        # the posterior changed, so draw a fresh sample path
        self.generate_sample_path()

    def acq(self, x):
        return self.sample_path_val(np.atleast_2d(x))

    def next_input(self):
        self.maximums, self.max_inputs = self.get_sample_path_maximizer()
        return np.atleast_2d(self.max_inputs[0])

    def next_input_pool(self, X):
        """Pick the sample-path maximiser in X; return it and X without every matching row."""
        self.maximums, self.max_inputs = self.get_sample_path_maximizer(X)
        chosen = np.atleast_2d(self.max_inputs[0])
        is_chosen_row = np.all(X == chosen, axis=1)
        return chosen, X[~is_chosen_row, :]


class MaxValueEntropySearch(BO_core):
    """Max-value entropy search averaged over ``sampling_num`` sampled maxima f*.

    For gamma = (f* - mu) / sigma the per-sample score is
    gamma * pdf(gamma) / (2 cdf(gamma)) - log cdf(gamma).
    """

    def __init__(self, gp_regressor, bounds, rng=None, sampling_num=10):
        super().__init__(gp_regressor, bounds, rng=rng)
        self.sampling_num = sampling_num

    def pre_computation_acq(self, X_candidates=None):
        """Draw sample paths and store their maxima (``self.maximums``) and maximisers."""
        self.generate_sample_path()
        self.maximums, self.max_inputs = self.get_sample_path_maximizer(X_candidates)
        print('sampled maximums:', self.maximums)

    @staticmethod
    def _inverse_mills_ratio(Z):
        """Return pdf(Z) / cdf(Z) via log space, so it stays finite when cdf(Z) underflows."""
        return np.exp(norm.logpdf(Z) - norm.logcdf(Z))

    def acq(self, x):
        x = np.atleast_2d(x)
        mean, var = self.GPmodel.predict_noiseless(x)
        std = np.sqrt(var)
        normalized_max = (self.maximums - mean) / std
        # log-space evaluation: the direct pdf / cdf and log(cdf) give nan/inf once cdf underflows,
        # and np.argmax would then select those nan entries.
        pdf_divided_cdf = self._inverse_mills_ratio(normalized_max)
        per_sample = 0.5 * normalized_max * pdf_divided_cdf - norm.logcdf(normalized_max)
        return np.mean(per_sample, axis=1).ravel()

    def _acq_params_val_and_grad(self, params):
        """Return the MES value and its gradient w.r.t. (posterior mean, posterior variance)."""
        mean, var = params
        std = np.sqrt(var)
        Z = (self.maximums - mean) / std
        logcdf = norm.logcdf(Z)
        pdf_divided_cdf = self._inverse_mills_ratio(Z)

        acq_val = np.mean(0.5 * Z * pdf_divided_cdf - logcdf).ravel()

        # d/dZ [0.5 Z lambda - log cdf] with lambda = pdf/cdf and dlambda/dZ = -Z lambda - lambda^2
        acq_grad_wrt_z = - 0.5 * pdf_divided_cdf * (1 + Z**2 + Z * pdf_divided_cdf)
        z_grad_mean = - 1. / std
        z_grad_var = - Z / (2. * var)
        acq_grad_mean = np.mean(acq_grad_wrt_z * z_grad_mean).ravel()
        acq_grad_var = np.mean(acq_grad_wrt_z * z_grad_var).ravel()
        return acq_val, np.r_[acq_grad_mean, acq_grad_var]


class JointEntropySearch(BO_core):
    """Joint entropy search over sampled (maximiser x*_k, maximum f*_k) pairs.

    For every sample k, f(x) is conditioned on f(x*_k) = f*_k and truncated above at f*_k;
    the acquisition is log(predictive var + noise) minus the sample-average log of the
    truncated conditional variance (+ noise).
    """

    def __init__(self, gp_regressor, bounds, rng=None, sampling_num=10):
        super().__init__(gp_regressor, bounds, rng=rng)
        self.sampling_num = sampling_num

    def pre_computation_acq(self, X_candidates=None):
        """Sample maxima/maximisers and cache the per-sample conditioning terms."""
        self.generate_sample_path()
        self.maximums, self.max_inputs = self.get_sample_path_maximizer(X_candidates)
        print('sampled maximums:', self.maximums)
        self.mean_max_samples, self.var_max_samples = self.GPmodel.predict_noiseless(self.max_inputs)
        # (f*_k - mu(x*_k)) / sigma^2(x*_k), one row per sample. ``maximums`` is 1-D (K,), so it is
        # turned into a column first; subtracting the column-shaped posterior mean directly would
        # broadcast to a (K, K) matrix.
        self.centerized_maximums_divide_var = (np.c_[self.maximums] - self.mean_max_samples) / (self.var_max_samples)


    def acq(self, x):
        x = np.atleast_2d(x)
        mean, var = self.GPmodel.predict_noiseless(x)
        noise_var = self.GPmodel['.*Gaussian_noise.variance'].values

        # ignore the noise variance by setting include_likelihood=False
        cov_x_maximum_samples = self.GPmodel.posterior_covariance_between_points(x, self.max_inputs, include_likelihood=False)
        # Condition on each sample k separately: element-wise product, column k uses only sample k.
        conditional_mean = mean + cov_x_maximum_samples * self.centerized_maximums_divide_var.T
        conditional_var = var - cov_x_maximum_samples**2 / (self.var_max_samples.T)
        conditional_var[conditional_var <= 0] = 1e-10
        conditional_std = np.sqrt(conditional_var)

        standardized_maximum = (self.maximums.T - conditional_mean) / conditional_std
        cdf = norm.cdf(standardized_maximum)
        # pdf/cdf ~ -z for z << 0 (asymptotic inverse Mills ratio) where cdf underflows to 0
        inv_mills_ratio = - standardized_maximum
        cdf_nonzero_idx = cdf > 0
        inv_mills_ratio[cdf_nonzero_idx] = norm.pdf(standardized_maximum[cdf_nonzero_idx]) / cdf[cdf_nonzero_idx]

        # variance of a normal truncated from above at f*
        conditional_var_truncated_normal = conditional_var * (1 - standardized_maximum * inv_mills_ratio - inv_mills_ratio**2 )

        acq_val = np.log(var + noise_var).ravel() - np.mean(np.log(conditional_var_truncated_normal + noise_var), axis=1)

        return acq_val.ravel()

    def _acq_val_and_grad(self, x):
        """Return (-acq(x), -gradient) with the gradient from forward differences.

        The analytic gradient is not implemented; each coordinate is perturbed by 1e-10, so one
        batched call evaluates the acquisition at d + 1 points.
        """
        x = np.atleast_2d(x)

        shifted_xs = np.r_[x, x + 1e-10 * np.eye(self.input_dim)]
        acq_vals = self.acq(shifted_xs)
        return (- acq_vals[0], - (acq_vals[1:] - acq_vals[0]) / 1e-10)


class PI_from_MaxSample(BO_core):
    """PI-style score against a sampled maximum f*: (mu - f*) / sigma."""

    def __init__(self, gp_regressor, bounds, rng=None):
        super().__init__(gp_regressor, bounds, rng=rng)

    def pre_computation_acq(self, X_candidates=None):
        """Sample a path, store its maximum/maximiser and report diagnostics there."""
        self.generate_sample_path()
        self.maximums, self.max_inputs = self.get_sample_path_maximizer(X_candidates)
        flat_max_inputs = np.atleast_2d(self.max_inputs.ravel())
        print('sampled maximums and point:', self.maximums, flat_max_inputs)
        print("corresponding acq, mean and var:", self.acq(flat_max_inputs), self.GPmodel.predict_noiseless(flat_max_inputs))

    def acq(self, x):
        mu, variance = self.GPmodel.predict_noiseless(np.atleast_2d(x))
        return ((mu - self.maximums) / np.sqrt(variance)).ravel()

    def _acq_params_val_and_grad(self, params):
        """Return z = (mu - f*) / sigma and [dz/dmu, dz/dvar]."""
        mu, variance = params
        sigma = np.sqrt(variance)

        z = ((mu - self.maximums) / sigma).ravel()
        dz_dmu = 1. / sigma
        dz_dvar = - z / (2. * variance)
        return z, np.r_[dz_dmu.ravel(), dz_dvar.ravel()]

class EI_from_MaxSample(BO_core):
    """Expected improvement over a sampled maximum f* instead of the best observation.

    ``acq`` returns raw EI; the optimiser uses log EI with an analytic tail for Z < -8.
    """

    def __init__(self, gp_regressor, bounds, rng=None):
        super().__init__(gp_regressor, bounds, rng=rng)

    def pre_computation_acq(self, X_candidates=None):
        """Sample a path, store its maximum/maximiser and report the posterior there."""
        self.generate_sample_path()
        self.maximums, self.max_inputs = self.get_sample_path_maximizer(X_candidates)
        flat_max_inputs = np.atleast_2d(self.max_inputs.ravel())
        print('sampled maximums and point:', self.maximums, flat_max_inputs)
        print("corresponding mean and var:", self.GPmodel.predict_noiseless(flat_max_inputs))

    def acq(self, x):
        mu, variance = self.GPmodel.predict_noiseless(np.atleast_2d(x))
        sigma = np.sqrt(variance)
        z = (mu - self.maximums) / sigma
        improvement = (z * sigma) * norm.cdf(z) + sigma * norm.pdf(z)
        return improvement.ravel()


    def _acq_params_val_and_grad(self, params):
        """Return log EI and its gradient w.r.t. (posterior mean, posterior variance)."""
        mu, variance = params
        sigma = np.sqrt(variance)
        z = (mu - self.maximums) / sigma

        if z < -8:
            # log-space tail using the bound pdf/cdf ~ (sqrt(z^2 + 8) - 3z) / 4 (arXiv:1212.4899)
            root = np.sqrt(z**2 + 8)
            log_ei = np.log(sigma) + norm.logcdf(z) + np.log((root + z) / 4.)
            dlog_dz = (root - 3. * z) / 4. + 1. / root
            grad_mean = dlog_dz / sigma
            grad_var = 1. / (2. * variance) + dlog_dz * (-0.5 * z / variance)
        else:
            density = norm.pdf(z)
            mass = norm.cdf(z)
            ei = (z * sigma) * mass + sigma * density
            log_ei = np.log(ei).ravel()
            grad_mean = mass / ei
            grad_var = density / (2 * sigma) / ei

        return log_ei, np.r_[grad_mean.ravel(), grad_var.ravel()]



class EEEI(BO_core):
    """Expected improvement averaged over sampled maxima: acq(x) = mean_k EI(x; f*_k)."""

    def __init__(self, gp_regressor, bounds, rng=None, sampling_num=10):
        super().__init__(gp_regressor, bounds, rng=rng)
        self.sampling_num = sampling_num

    def pre_computation_acq(self, X_candidates=None):
        """Draw ``sampling_num`` sample paths and store their maxima and maximisers."""
        self.generate_sample_path()
        self.maximums, self.max_inputs = self.get_sample_path_maximizer(X_candidates)

    def acq(self, x):
        x = np.atleast_2d(x)
        mean, var = self.GPmodel.predict_noiseless(x)
        std = np.sqrt(var)
        Z = (mean - np.c_[self.maximums].T) / std
        return np.mean((Z * std)*norm.cdf(Z) + std*norm.pdf(Z), axis=1)


    def _acq_params_val_and_grad(self, params):
        """Return log(mean_k EI_k) and its gradient w.r.t. (posterior mean, posterior variance).

        Per-sample log EI_k and d log EI_k are exact for Z_k >= -8 and use the analytic
        inverse-Mills-ratio bound (https://arxiv.org/pdf/1212.4899) for Z_k < -8. They are
        combined with log-sum-exp, so the value is always log of the averaged EI (as ``acq``)
        and the gradient is its true derivative: the softmax(log EI)-weighted average of the
        per-sample gradients. (Previously one branch returned log(mean EI) with the gradient
        of mean(log EI), and the other returned mean(log EI).)
        """
        mean, var = params
        std = np.sqrt(var)
        Z = (mean - self.maximums.ravel()) / std

        log_ei = np.empty_like(Z, dtype=float)
        dlog_dmean = np.empty_like(log_ei)
        dlog_dvar = np.empty_like(log_ei)

        Z_smaller_idx = Z < -8
        Z_larger_idx = np.logical_not(Z_smaller_idx)

        if np.any(Z_smaller_idx):
            smaller_Z = Z[Z_smaller_idx]
            Z_squared_8_root = np.sqrt(smaller_Z**2 + 8)
            approx_phi_divided_cdf = (Z_squared_8_root - 3. * smaller_Z) / 4.
            dlog_dZ = approx_phi_divided_cdf + 1. / Z_squared_8_root
            log_ei[Z_smaller_idx] = np.log(std) + norm.logcdf(smaller_Z) + np.log((Z_squared_8_root + smaller_Z) / 4.)
            dlog_dmean[Z_smaller_idx] = dlog_dZ / std
            dlog_dvar[Z_smaller_idx] = 1. / (2. * var) + dlog_dZ * (-0.5 * smaller_Z / var)

        if np.any(Z_larger_idx):
            larger_Z = Z[Z_larger_idx]
            pdf = norm.pdf(larger_Z)
            cdf = norm.cdf(larger_Z)
            raw_acq_val = (larger_Z * std) * cdf + std * pdf
            log_ei[Z_larger_idx] = np.log(raw_acq_val)
            dlog_dmean[Z_larger_idx] = cdf / raw_acq_val
            dlog_dvar[Z_larger_idx] = pdf / (2 * std) / raw_acq_val

        # log(mean_k EI_k) = logsumexp(log EI_k) - log K, shifted by the max for stability.
        shift = np.max(log_ei)
        weights = np.exp(log_ei - shift)
        weight_sum = np.sum(weights)
        acq_val = np.atleast_1d(shift + np.log(weight_sum / Z.size))
        weights = weights / weight_sum
        return acq_val, np.r_[np.sum(weights * dlog_dmean), np.sum(weights * dlog_dvar)]



class EI_from_MaxMean(BO_core):
    """EI against the posterior-mean maximum, with the predictive std scaled by root_nu.

    root_nu = root_nu_func(iteration) is refreshed after each GP update. ``acq`` returns raw EI;
    the optimiser uses log EI with an analytic tail for Z < -8.
    """

    def __init__(self, gp_regressor, bounds, root_nu_func:Callable, rng=None, iteration=1):
        super().__init__(gp_regressor, bounds, rng=rng)
        self.iteration = iteration
        self.root_nu_func = root_nu_func
        self._refresh_root_nu()

    def _refresh_root_nu(self):
        """Recompute root_nu for the current iteration and report it."""
        self.root_nu = self.root_nu_func(self.iteration)
        print(f"root nu in {self.iteration}-th iteration is {self.root_nu}")

    def pre_computation_acq(self, X_candidates=None):
        self.mean_max = self.posterior_maximum(X_candidates)[1]

    def update(self, X, Y, optimize=False):
        """Update the GP first, then advance the iteration counter and root_nu."""
        super().update(X, Y, optimize=optimize)
        self.iteration += 1
        self._refresh_root_nu()

    def acq(self, x):
        mu, variance = self.GPmodel.predict_noiseless(np.atleast_2d(x))
        scale = np.sqrt(variance) * self.root_nu
        z = (mu - self.mean_max) / (scale)
        improvement = (z * scale) * norm.cdf(z) + scale * norm.pdf(z)
        return improvement.ravel()

    def _acq_params_val_and_grad(self, params):
        """Return log EI and its gradient w.r.t. (posterior mean, posterior variance)."""
        mu, variance = params
        sigma = np.sqrt(variance)
        scale = sigma * self.root_nu
        z = (mu - self.mean_max) / scale

        if z < -8:
            # log-space tail using the bound pdf/cdf ~ (sqrt(z^2 + 8) - 3z) / 4 (arXiv:1212.4899)
            root = np.sqrt(z**2 + 8)
            log_ei = np.log(scale) + norm.logcdf(z) + np.log((root + z) / 4.)
            dlog_dz = (root - 3. * z) / 4. + 1. / root
            grad_mean = dlog_dz / scale
            grad_var = 1. / (2. * variance) + dlog_dz * (-0.5 * z / variance)
        else:
            density = norm.pdf(z)
            mass = norm.cdf(z)
            ei = (z * scale) * mass + scale * density
            log_ei = np.log(ei).ravel()
            grad_mean = mass / ei
            # d scale / d var = root_nu / (2 sigma)
            grad_var = density * self.root_nu / (2 * sigma) / ei

        return log_ei, np.r_[grad_mean.ravel(), grad_var.ravel()]