import numpy as np
from numpy.linalg import LinAlgError, cholesky, lstsq, pinv, det
from scipy.optimize import minimize
from scipy.linalg import cho_solve
from scipy.stats import norm
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel
import copy

import logging

# Numerical tolerance used in this module: it is passed as ``rcond`` to the
# ``lstsq`` solves and is the threshold below which a predictive standard
# deviation is treated as zero.
EPS = 1e-8
logger = logging.getLogger(__name__)

# NOTE(review): despite the "random seed" label, the value printed is a fresh
# sample drawn by np.random.rand(), not a seed, and no seed is set here.
# This runs (and consumes one draw from NumPy's global RNG) on every import.
print("random seed:", np.random.rand())


class Data:
    """Container for the samples (X, y) collected during the learning process.

    Samples are stored exactly as given: rows are concatenated or assigned
    without any de-duplication. Callers that need distinct samples (e.g. to
    keep a kernel matrix positive definite) must filter duplicates themselves.

    Properties:
        X: stored inputs, shape (n, dim).
        Y: stored targets, shape (n,).
        dim: dimension of one input sample.
    """

    def __init__(self, dim: int):
        """Create an empty collection for samples of dimension ``dim``."""
        self._dim = dim
        self._collection_X = np.empty(shape=(0, self._dim))
        self._collection_y = np.empty(shape=(0,))

    def _check_shapes(self, X: np.ndarray, y: np.ndarray) -> None:
        """Assert that ``X`` is (m, dim) and ``y`` is (m,) with the same m."""
        assert X.ndim == 2 and y.ndim == 1
        assert X.shape[1] == self._dim
        # Without this check a mismatched pair would be stored silently and
        # only be detected later, when ``len()`` is called.
        assert X.shape[0] == y.shape[0]

    def append(self, X: np.ndarray, y: np.ndarray):
        """Append new samples after the ones already stored.

        Args:
            X: array (m, dim).
            y: array (m,).

        Raises:
            AssertionError: if the shapes are inconsistent.
        """
        self._check_shapes(X, y)

        self._collection_X = np.concatenate([self._collection_X, X])
        self._collection_y = np.concatenate([self._collection_y, y])

    def assign(self, X: np.ndarray, y: np.ndarray):
        """Replace the stored samples with ``X`` and ``y``.

        The arrays are stored by reference; no copy is made.

        Args:
            X: array (m, dim).
            y: array (m,).

        Raises:
            AssertionError: if the shapes are inconsistent.
        """
        self._check_shapes(X, y)

        self._collection_X = X
        self._collection_y = y

    @property
    def X(self) -> np.ndarray:
        """Stored inputs, shape (n, dim)."""
        return self._collection_X

    @property
    def Y(self) -> np.ndarray:
        """Stored targets, shape (n,)."""
        return self._collection_y

    def __len__(self) -> int:
        """Return the number of stored samples."""
        assert len(self._collection_X) == len(self._collection_y)
        return self._collection_y.size

    @property
    def dim(self) -> int:
        """Dimension of one input sample."""
        return self._dim


class GPEI:
    """Gaussian-process surrogate with an Expected Improvement acquisition.

    The kernel is ``ConstantKernel * RBF`` and its hyperparameters are fitted
    by a scikit-learn ``GaussianProcessRegressor``. The posterior used by
    ``predict`` and by the acquisition is computed in this class from the
    fitted kernel and the stored training data. The acquisition is written
    for minimisation: improvement is measured below the smallest observed
    target.
    """

    def __init__(self, dim: int,
                 x_low_bounds: np.ndarray,
                 x_up_bounds: np.ndarray,
                 param_low_bounds: np.ndarray = None,
                 param_up_bounds: np.ndarray = None,
                 nb_restart_trial: int = 20,            # fit model by several restarts.
                 delta: float = 0.01,
                 noise=1e-5,
                 ):
        """
        Args:
            dim: dimension of X.
            x_low_bounds: (dim,) lower bounds of the search box.
            x_up_bounds: (dim,) upper bounds of the search box.
            param_low_bounds: (dim+1,) lower bounds of the kernel
                hyperparameters; the first ``dim`` entries bound the RBF
                length scales, the last one bounds the constant kernel.
                Defaults to 1e-3 everywhere.
            param_up_bounds: (dim+1,) upper bounds, same layout.
                Defaults to 1e3 everywhere.
            nb_restart_trial: number of optimizer restarts used by the
                regressor when fitting the hyperparameters.
            delta: expected improvement tolerance.
            noise: noise in gaussian process model.
        """
        assert x_low_bounds.size == x_up_bounds.size == dim
        self._dim = dim
        self._x_low_bounds = x_low_bounds
        self._x_up_bounds = x_up_bounds
        self._nb_restart_trial = nb_restart_trial
        self._data = Data(dim)
        self._noise = noise
        self._delta = delta

        if param_low_bounds is None:
            param_low_bounds = np.ones(dim + 1) * 1e-3
        if param_up_bounds is None:
            param_up_bounds = np.ones(dim + 1) * 1e3
        # Checked after the defaults are filled in, so a single user-supplied
        # bound is validated too (not only when both are given).
        assert np.size(param_low_bounds) == np.size(param_up_bounds) == dim + 1

        tau = ConstantKernel(1.0, (param_low_bounds[-1], param_up_bounds[-1]))
        length_scale_bounds = [(low, up) for low, up in zip(param_low_bounds[:-1], param_up_bounds[:-1])]
        rbf = RBF(np.ones(dim) * 10, length_scale_bounds)
        self._kernel = tau * rbf
        self._gp = GaussianProcessRegressor(kernel=self._kernel, alpha=noise, n_restarts_optimizer=self._nb_restart_trial)

    def kernel(self, X1: np.ndarray, X2: np.ndarray, params: np.ndarray = None):
        """Evaluate the kernel matrix between ``X1`` and ``X2``.

        Args:
            X1: array (n1, dim).
            X2: array (n2, dim).
            params: optional hyperparameter vector. When omitted, the kernel
                fitted by the regressor is used (so ``fit`` must have been
                called). When given, it is assigned to ``theta`` of the
                kernel object created in ``__init__`` -- which is mutated in
                place -- and that kernel is evaluated.
                NOTE(review): scikit-learn documents ``theta`` as
                log-transformed hyperparameters -- confirm callers pass
                values on that scale.

        Returns:
            Kernel matrix; presumably (n1, n2) -- shape is determined by the
            scikit-learn kernel.
        """
        assert X1.ndim == 2 and X2.ndim == 2
        assert X1.shape[1] == X2.shape[1] == self._dim

        if params is None:
            # After fit: evaluate the kernel with the latest fitted parameters.
            return self._gp.kernel_(X1, X2)
        else:
            # Assign the given parameters, then evaluate the kernel.
            self._kernel.theta = params
            return self._kernel(X1, X2)

    def fit(self, Xn: np.ndarray, Yn: np.ndarray):
        """Store the training data and fit the Gaussian process regressor.

        The stored data replaces any data from a previous call.

        Args:
            Xn: array (m, dim)
            Yn: array (m,)
        """
        assert Xn.ndim == 2 and Yn.ndim == 1
        assert Xn.shape[0] == Yn.size and Xn.shape[1] == self._dim

        self._data.assign(Xn, Yn)
        self._gp.fit(Xn, Yn)

    def _posterior(self, Xn_new: np.ndarray):
        """Posterior mean and covariance at ``Xn_new`` given the stored data.

        Args:
            Xn_new: array (nx, dim).

        Returns:
            mu_new: (nx,) posterior mean.
            Sigma_new: (nx, nx) posterior covariance; the noise term is
                included on its diagonal.
        """
        Xn = self._data.X                       # Xn: (n, dim)
        Yn = self._data.Y                       # Yn: (n,)

        Kyy = self.kernel(Xn, Xn) + self._noise * np.eye(Xn.shape[0])               # Kyy: (n, n)
        Kyf = self.kernel(Xn, Xn_new)                                               # Kyf: (n, nx)
        Kff = self.kernel(Xn_new, Xn_new) + self._noise * np.eye(Xn_new.shape[0])   # Kff: (nx, nx)

        # Kyy = L L^T, so Kyy^{-1} b is obtained by two successive solves.
        L = cholesky(Kyy)                                                           # L: (n, n)
        Kyy_inv_y = lstsq(L.T, lstsq(L, Yn, rcond=EPS)[0], rcond=EPS)[0]            # (n,)
        Kyy_inv_Kyf = lstsq(L.T, lstsq(L, Kyf, rcond=EPS)[0], rcond=EPS)[0]         # (n, nx)

        mu_new = Kyf.T.dot(Kyy_inv_y)                                               # (nx,)
        Sigma_new = Kff - Kyf.T.dot(Kyy_inv_Kyf)                                    # (nx, nx)
        return mu_new, Sigma_new

    def _acquisition(self, Xn_new: np.ndarray) -> float:
        """Calculate the expected-improvement acquisition at one point.

        Args:
            Xn_new: (1, dim). new point.

        Returns:
            Acquisition value as a Python float; 0.0 when the predictive
            standard deviation is below ``EPS``.
        """
        assert Xn_new.ndim == 2 and Xn_new.shape[0] == 1
        min_f = self._data.Y.min()              # min_f: ()

        mu, Sigma = self._posterior(Xn_new)
        mu_new = float(mu[0])
        # Round-off can make the variance slightly negative near training
        # points; clamp it so the square root never yields NaN.
        var_new = max(float(Sigma[0, 0]), 0.0)
        std_new = np.sqrt(var_new)

        if std_new < EPS:
            return 0.0

        # NOTE(review): delta is *added* to the improvement here; common EI
        # formulations subtract an exploration margin -- confirm intended.
        improvement = min_f - mu_new + self._delta
        Z = improvement / std_new
        return float(improvement * norm.cdf(Z) + std_new * norm.pdf(Z))

    def propose_point(self, nb_init_points: int = 50) -> np.ndarray:
        """
        Propose the point with the maximum expected improvement.

        The negative acquisition is minimised from ``nb_init_points`` start
        points drawn uniformly inside the search box, and the best result
        is kept.

        Args:
            nb_init_points: number of random restarts; must be >= 1.

        Returns:
            proposed_point: one point with shape (1, dim).

        Raises:
            ValueError: if ``nb_init_points`` is smaller than 1.
            RuntimeError: if no restart produced a comparable objective
                value (e.g. every value was NaN).
        """
        if nb_init_points < 1:
            raise ValueError("nb_init_points must be at least 1")

        def min_obj(xh: np.ndarray) -> float:
            """
            Args:
                xh: one point.

            Returns:
                minus_acq: negative acquisition.
            """
            return -self._acquisition(xh.reshape(1, -1))

        min_val = float('inf')
        best_point = None
        bounds = list(zip(self._x_low_bounds, self._x_up_bounds))

        initial_points = np.random.uniform(low=self._x_low_bounds, high=self._x_up_bounds, size=(nb_init_points, self._dim))
        for x0 in initial_points:
            res = minimize(min_obj, x0=x0, bounds=bounds)

            if min_val > res.fun:
                min_val = res.fun
                best_point = res.x

        if best_point is None:
            # Comparisons with NaN are always False, so nothing was selected.
            raise RuntimeError("acquisition optimisation did not return a usable value")

        return best_point.reshape(1, -1)

    def predict(self, Xn_new: np.ndarray) -> tuple:
        """Calculate mu_new, std_new.

        Args:
            Xn_new: array (nx, dim).

        Returns:
            mu_new: (nx,) posterior mean.
            std_new: (nx,) posterior standard deviation.
        """
        assert Xn_new.ndim == 2          # x: (nx, dim)

        mu_new, Sigma_new = self._posterior(Xn_new)
        # Clamp tiny negative variances caused by round-off before the sqrt.
        std_new = np.sqrt(np.maximum(np.diag(Sigma_new), 0.0))                  # (nx,)

        return mu_new, std_new

    @property
    def theta(self):
        """Hyperparameter vector of the fitted kernel (available after fit)."""
        return self._gp.kernel_.theta


def f(x):
    """Demo objective: return ``x * sin(x)``, evaluated with NumPy's sine."""
    sine = np.sin(x)
    return x * sine


if __name__ == "__main__":
    # Demo: fit the model on a few samples of f on [1, 10], propose the next
    # point to evaluate, then plot the posterior against the true function.
    model = GPEI(dim=1,
                 x_low_bounds=np.array([1.0]),
                 x_up_bounds=np.array([10.0]),
                 param_low_bounds=np.array([1e-3, 1e-3]),
                 param_up_bounds=np.array([1e3, 1e3]))

    X = np.array([1.0, 3.0, 5.0, 6.0, 7.0, 8.0]).reshape(-1, 1)
    y = f(X).ravel()

    model.fit(X, y)

    newx = model.propose_point()
    print(newx)

    # Imported here so that matplotlib is only needed when running the demo.
    from matplotlib import pyplot as plt
    x = np.atleast_2d(np.linspace(0, 10, 1000)).T
    y_pred, sigma = model.predict(x)
    plt.plot(x, f(x), 'r:', label=r'$f(x) = x\,\sin(x)$')
    plt.plot(X, y, "r.", label="observation")
    plt.plot(x, y_pred, "b-", label="prediction")
    # Shaded band: prediction +/- 1.96 standard deviations.
    plt.fill_between(x.ravel(), y_pred - sigma * 1.96, y_pred + sigma * 1.96, alpha=0.3,
                     label="95% confidence interval")
    # The labels above are only rendered once a legend is requested.
    plt.legend()
    plt.show()
