import logging
import random
from typing import Callable

import numpy as np
from numpy.linalg import LinAlgError, cholesky, lstsq, pinv, det
from scipy.linalg import cho_solve, solve_triangular
from scipy.optimize import minimize
from scipy.stats import norm

from hpbandster.optimizers.config_generators.simulated_annealing import SimulatedAnnealing

# Small numerical tolerance shared by the routines in this module
# (diagonal jitter, log() guards, near-zero cut-offs).
EPS = 1.0e-8

logger = logging.getLogger(name=__name__)

# NOTE(review): this prints a single draw from NumPy's global RNG at import
# time; it neither sets nor reveals the RNG seed despite the label.
print("random seed:", np.random.random_sample())


def lhs(low_bounds: np.ndarray, up_bounds: np.ndarray, nb: int = 6) -> np.ndarray:
    """Draw a Latin hypercube sample of ``nb`` points inside an axis-aligned box.

    Each dimension ``[low, up]`` is cut into ``nb`` equal-width strata. Every
    stratum of every dimension receives exactly one point, the assignment of
    strata to samples is shuffled independently per dimension, and each point
    is placed uniformly at random inside its stratum.

    All randomness comes from NumPy's global RNG, so seeding ``np.random``
    alone makes the result reproducible.

    Args:
        low_bounds: (d,) lower bounds.
        up_bounds: (d,) upper bounds.
        nb: number of samples, which is also the number of strata per dimension.

    Returns:
        (nb, d) array of samples.
    """
    assert low_bounds.ndim == up_bounds.ndim == 1
    assert low_bounds.size == up_bounds.size

    low = np.asarray(low_bounds, dtype=float)
    width = np.asarray(up_bounds, dtype=float) - low
    d = low.size

    # argsort of i.i.d. uniforms is an independent random permutation of
    # 0..nb-1 in every column: strata[i, j] is the stratum of sample i in dim j.
    strata = np.argsort(np.random.rand(nb, d), axis=0)
    # Uniform position of each sample inside its stratum, in [0, 1).
    offsets = np.random.rand(nb, d)
    return low + width * (strata + offsets) / nb


class Data:
    """
    Store the (X, y) samples collected during the learning process.

    Note: rows are stored exactly as given; duplicate inputs are NOT filtered
    here. Callers that need a strictly positive-definite kernel matrix must
    avoid passing repeated points.

    Properties:
        X: (n, dim) array of inputs.
        Y: (n,) array of targets.
        dim: input dimension.
    """

    def __init__(self, dim: int):
        self._dim = dim
        self._collection_X = np.empty(shape=(0, self._dim))
        self._collection_y = np.empty(shape=(0,))

    def _check(self, X: np.ndarray, y: np.ndarray):
        """Validate that X is (m, dim), y is (m,) and both have the same row count."""
        assert X.ndim == 2 and y.ndim == 1
        assert X.shape[1] == self._dim
        # Reject mismatched batches up front instead of failing later in
        # __len__ or in the kernel computation.
        assert X.shape[0] == y.shape[0], (X.shape, y.shape)

    def append(self, X: np.ndarray, y: np.ndarray):
        """Concatenate a batch (X: (m, dim), y: (m,)) to the stored samples."""
        self._check(X, y)

        self._collection_X = np.concatenate([self._collection_X, X])
        self._collection_y = np.concatenate([self._collection_y, y])

    def assign(self, X: np.ndarray, y: np.ndarray):
        """Replace all stored samples with (X: (m, dim), y: (m,)); arrays are not copied."""
        logger.debug("X shape = %s", X.shape)
        self._check(X, y)

        self._collection_X = X
        self._collection_y = y

    @property
    def X(self) -> np.ndarray:
        """(n, dim) stored inputs."""
        return self._collection_X

    @property
    def Y(self) -> np.ndarray:
        """(n,) stored targets."""
        return self._collection_y

    def __len__(self) -> int:
        assert len(self._collection_X) == len(self._collection_y)
        return self._collection_y.size

    @property
    def dim(self) -> int:
        """Input dimension fixed at construction."""
        return self._dim


class BayesEI:
    """Gaussian-process surrogate with an Expected Improvement (EI) acquisition.

    The GP uses a squared-exponential kernel with one weight per input
    dimension (``theta``, shape ``(dim,)``) and an amplitude ``tau``:

        k(x, x') = tau**2 * exp(-sum_k theta_k * (x_k - x'_k)**2)

    Hyper-parameters are fitted by minimising the negative log marginal
    likelihood (``fit``); new query points are proposed by maximising EI with
    simulated annealing (``propose_point``).
    """

    def __init__(self, dim: int,
                 x_low_bounds: np.ndarray,
                 x_up_bounds: np.ndarray,
                 param_low_bounds: np.ndarray = None,
                 param_up_bounds: np.ndarray = None,
                 delta: float = 0.01
                 ):
        """
        Args:
            dim: dimension of X.
            x_low_bounds: (dim,) lower bounds of the search space.
            x_up_bounds: (dim,) upper bounds of the search space.
            param_low_bounds: (dim+1,) lower bounds of [theta_1..theta_dim, tau];
                defaults to 1e-4 everywhere.
            param_up_bounds: (dim+1,) upper bounds of [theta_1..theta_dim, tau];
                defaults to 1e4 everywhere.
            delta: expected improvement tolerance, added to the improvement
                ``min_f - mu``.
        """
        assert x_low_bounds.size == x_up_bounds.size == dim

        self._dim = dim
        self._data = Data(dim)  # observed samples; replaced on every fit()
        self._x_low_bounds = x_low_bounds
        self._x_up_bounds = x_up_bounds
        self._param_low_bounds = np.ones(dim + 1) * 1e-4 if param_low_bounds is None else param_low_bounds
        self._param_up_bounds = np.ones(dim + 1) * 1e4 if param_up_bounds is None else param_up_bounds
        self._delta = delta

        assert self._param_low_bounds.size == self._param_up_bounds.size == dim + 1

        self._theta = None      # (dim,) kernel weights, set by fit()
        self._tau = None        # scalar kernel amplitude, set by fit()

        self._min_Qn = None     # cache slot to avoid repeated computation (unused in this class)

        self._nll = []          # best negative log likelihood of every fit() call
        self._train_kernel = [] # training kernel matrix at the best hyper-parameters of every fit()
        self._stable_noise = 1e-2  # added to the training-kernel diagonal for stability

    def _kernel(self, X1: np.ndarray, X2: np.ndarray, params: np.ndarray = None) -> np.ndarray:
        """Squared-exponential kernel matrix between the rows of X1 and X2.

        Args:
            X1: (n1, dim).
            X2: (n2, dim).
            params: optional (dim+1,) vector [theta_1..theta_dim, tau]; when None
                the fitted ``self._theta`` / ``self._tau`` are used.

        Returns:
            (n1, n2) matrix ``tau**2 * exp(-sum_k theta_k (x1_k - x2_k)**2)``.
        """
        # Check rank before touching shape[1] so bad input fails with a clear message.
        assert X1.ndim == 2 and X2.ndim == 2, (X1.shape, X2.shape)
        assert X1.shape[1] == X2.shape[1] == self._dim, (X1.shape, X2.shape, self._dim)

        dim = self._dim
        if params is not None:
            theta = params[0:dim]
            tau = params[dim]
        else:
            theta = self._theta
            tau = self._tau

        # Weighted squared distance via |a|^2 + |b|^2 - 2 a.b, with each
        # dimension k scaled by its own weight theta_k.
        sq_dist = (np.sum((X1 ** 2) * theta, axis=1)[:, np.newaxis]
                   + np.sum((X2 ** 2) * theta, axis=1)[np.newaxis, :]
                   - 2.0 * np.dot(X1 * theta, X2.T))
        # Round-off in the expansion can go slightly negative; distances are >= 0.
        sq_dist = np.maximum(sq_dist, 0.0)
        return tau ** 2 * np.exp(-sq_dist)

    def fit(self, Xn: np.ndarray, Yn: np.ndarray, nb_init_points: int = 50):
        """
        Fit one Gaussian process model on X and y:
        1. Compute the negative log likelihood loss.
        2. Minimise it with L-BFGS-B from ``nb_init_points`` Latin-hypercube starts.

        Args:
            Xn: array (m x d)
            Yn: array (m)
            nb_init_points: number of optimiser restarts.

        Raises:
            RuntimeError: if a Cholesky factorisation fails during optimisation,
                or if no restart produced a finite objective value.
        """
        self._data.assign(Xn, Yn)
        X, Y = self._data.X, self._data.Y
        n = len(self._data)
        eye_n = np.eye(n)
        const_term = 0.5 * n * np.log(2 * np.pi)

        def nll(params: np.ndarray) -> float:
            """
            Negative log likelihood:
                -log N(y | 0, K) = 0.5 * y^T K^{-1} y + 0.5 * log|K| + n/2 * log(2*pi)
            """
            Kyy = self._kernel(X, X, params) + self._stable_noise * eye_n
            try:
                L = cholesky(Kyy + EPS * eye_n)   # lower triangular, K = L L^T
            except LinAlgError as err:
                raise RuntimeError("cholesky") from err
            # 0.5 * log|K| = sum(log(diag(L)))
            half_log_det = np.sum(np.log(np.diagonal(L) + EPS))
            data_fit = 0.5 * Y.dot(cho_solve((L, True), Y))
            return half_log_det + data_fit + const_term

        # Restart from many initial points to reduce the risk of a poor local minimum.
        bounds = tuple(zip(self._param_low_bounds, self._param_up_bounds))
        min_val = float('inf')
        best_res = None
        for point in lhs(self._param_low_bounds, self._param_up_bounds, nb_init_points):
            res = minimize(nll, x0=point, bounds=bounds, method="l-bfgs-b")
            if res.fun < min_val:
                min_val = res.fun
                best_res = res

        if best_res is None:
            raise RuntimeError("hyper-parameter optimisation produced no finite candidate")

        self._nll.append(min_val)
        params = best_res.x
        dim = self._dim
        self._theta = params[0:dim]
        self._tau = params[dim]
        # Keep only the kernel at the selected hyper-parameters (one matrix per
        # fit) rather than one per objective evaluation.
        self._train_kernel.append(self._kernel(X, X) + self._stable_noise * eye_n)

    def _posterior(self, Xn_new: np.ndarray):
        """GP posterior mean and standard deviation at the rows of Xn_new.

        Args:
            Xn_new: (nx, dim).

        Returns:
            mu_new: (nx,), std_new: (nx,).
        """
        Xn = self._data.X                       # Xn: (n, dim)
        Yn = self._data.Y                       # Yn: (n,)

        Kyy = self._kernel(Xn, Xn) + self._stable_noise * np.eye(Xn.shape[0])   # (n, n)
        Kyf = self._kernel(Xn, Xn_new)                                          # (n, nx)

        L = cholesky(Kyy)                                   # (n, n), Kyy = L L^T
        mu_new = Kyf.T.dot(cho_solve((L, True), Yn))        # (nx,)
        V = solve_triangular(L, Kyf, lower=True)            # (n, nx) = L^{-1} Kyf
        # Only the diagonal of Kff - Kyf^T Kyy^{-1} Kyf is needed, and this
        # kernel has k(x, x) = tau**2.
        var_new = self._tau ** 2 - np.sum(V ** 2, axis=0)   # (nx,)
        # Clip round-off negatives so sqrt never yields NaN.
        std_new = np.sqrt(np.maximum(var_new, 0.0))
        return mu_new, std_new

    def _acquisition(self, Xn_new: np.ndarray) -> float:
        """Expected improvement (with tolerance ``delta``) at the first row of Xn_new.

        Args:
            Xn_new: (nx, dim) new points. EI is evaluated for every row, but only
                the value for row 0 is returned.

        Returns:
            EI of the first point, as a float.
        """
        assert Xn_new.ndim == 2
        min_f = self._data.Y.min()

        mu_new, std_new = self._posterior(Xn_new)          # (nx,), (nx,)
        improvement = min_f - mu_new + self._delta         # (nx,)

        # Elementwise: a plain `if` on the std array would fail for nx > 1.
        informative = std_new > EPS
        safe_std = np.where(informative, std_new, 1.0)     # avoid division by ~0
        Z = improvement / safe_std
        ei = improvement * norm.cdf(Z) + std_new * norm.pdf(Z)
        # As std -> 0, EI tends to max(improvement, 0).
        acq = np.where(informative, ei, np.maximum(improvement, 0.0))
        return float(acq[0])

    def propose_point(self, nb_init_points: int = 100) -> np.ndarray:
        """
        Propose a point with the maximum expected improvement.

        Args:
            nb_init_points: currently unused; kept for API compatibility.

        Returns:
            proposed_point: one point with shape (1, dim).
        """

        def min_obj(xh: np.ndarray) -> float:
            """Negative acquisition at a single point ``xh`` (reshaped to (1, dim))."""
            return -self._acquisition(xh.reshape(1, -1))

        # Random start inside the search box; the annealer is assumed to keep
        # its iterates within x_low_bounds/x_up_bounds (project helper — confirm).
        init_x0 = np.random.uniform(self._x_low_bounds, self._x_up_bounds)
        sa = SimulatedAnnealing(min_f=min_obj,
                                x0=init_x0,
                                x_low_bounds=self._x_low_bounds,
                                x_up_bounds=self._x_up_bounds)
        best_point = sa.annealing()

        return best_point.reshape(1, -1)

    def predict(self, Xn_new: np.ndarray) -> "tuple[np.ndarray, np.ndarray]":
        """Posterior mean and standard deviation of the fitted GP.

        Args:
            Xn_new: array (nx, dim), or a single point of shape (dim,).

        Returns:
            mu_new: (nx,)
            std_new: (nx,)
        """
        # A single 1-D point becomes (1, dim); a 2-D batch is used as-is.
        Xn_new = np.atleast_2d(np.asarray(Xn_new, dtype=float))
        assert Xn_new.ndim == 2          # x: (nx, dim)
        return self._posterior(Xn_new)

    @property
    def params(self):
        """Fitted hyper-parameters as a (dim+1,) array [theta_1..theta_dim, tau]."""
        return np.concatenate((self._theta, [self._tau]), axis=0)


if __name__ == "__main__":
    # Smoke test: construct a 2-D model.
    # The hyper-parameter bounds must have dim + 1 entries (theta for each of
    # the dim inputs, followed by tau); BayesEI asserts this in __init__.
    M = BayesEI(dim=2,
                x_low_bounds=np.array([1e-3, 1e-3]),
                x_up_bounds=np.array([1e3, 1e3]),
                param_low_bounds=np.array([1e-3, 1e-3, 1e-3]),
                param_up_bounds=np.array([1e3, 1e3, 1e3]),
                delta=0.01,
                )