import numpy as np
from numpy.linalg import LinAlgError, cholesky, lstsq, pinv, det, inv
from scipy.optimize import minimize
from scipy.stats import norm
from typing import Callable
import logging


from hpbandster.optimizers.config_generators.simulated_annealing import SimulatedAnnealing


# Small constant used in this module as diagonal jitter before Cholesky
# factorizations and as the `rcond` cutoff passed to lstsq.
EPS = 1e-8
logger = logging.getLogger(__name__)

# NOTE(review): executed at import time. It draws one value from NumPy's global
# RNG and prints it; the printed number is a random sample, not the seed itself.
print("random seed:", np.random.rand())


def random_with_bound(low_bounds: np.array, up_bounds: np.array, nb: int) -> np.array:
    """
    Draw random points between per-coordinate low and up bounds.

    Each coordinate is drawn with np.random.rand (NumPy's global RNG) and
    rescaled linearly from [0, 1) onto the span between its two bounds.

    Args:
        low_bounds: array-like (d), low bound of every coordinate.
        up_bounds: array-like (d), up bound of every coordinate; must have the
            same length as low_bounds.
        nb: int (n), number of points to draw.
    Returns:
        sampled_points: array (n, d), random sampled points.
    """
    # Coerce first so that plain lists/tuples support the arithmetic below;
    # ndarray inputs are passed through unchanged.
    low_bounds = np.asarray(low_bounds)
    up_bounds = np.asarray(up_bounds)
    assert len(low_bounds) == len(up_bounds)
    d = len(low_bounds)
    sampled_points = np.random.rand(nb, d) * (up_bounds - low_bounds) + low_bounds
    return sampled_points


def invmat_vec(Kyy: np.ndarray, v: np.array) -> np.array:
    """
    Apply the inverse of a square matrix to a vector or matrix.

    Kyy, with EPS added on its diagonal, is Cholesky-factored as L L^T; the
    system in L and then the system in L^T are each solved with lstsq.

    Args:
        Kyy: array (m, m).
        v: array with one or two dimensions.
    Returns:
        The lstsq solution of the second system, i.e. an approximation of
        Kyy^{-1} v.
    Raises:
        RuntimeError: if the Cholesky factorization fails.
    """
    assert Kyy.ndim == 2 and v.ndim in (1, 2)
    size = Kyy.shape[0]
    assert size == Kyy.shape[1]

    jittered = Kyy + EPS * np.eye(size)
    try:
        lower = cholesky(jittered)
    except LinAlgError:
        raise RuntimeError("cholesky")

    # Solve L z = v first, then L^T x = z.
    forward = lstsq(lower, v, rcond=EPS)[0]
    solution = lstsq(lower.T, forward, rcond=EPS)[0]
    return solution


class Data:
    """
    Store the samples collected during the learning process.

    Note: no de-duplication is done here. Callers should avoid adding identical
    samples, which may make the kernel matrix non-positive definite.

    Properties:
        Xh: array (n, dim), stored inputs.
        Y: array (n,), stored observations.
        dim: int, number of columns of every stored input.
    """

    def __init__(self, dim: int):
        self._dim = dim
        self._collection_X = np.empty(shape=(0, self._dim))
        self._collection_y = np.empty(shape=(0,))

    def _check(self, X: np.ndarray, y: np.ndarray) -> None:
        """Assert that X is (n, dim) and y is (n,) with the same n."""
        assert X.ndim == 2 and y.ndim == 1
        assert X.shape[1] == self._dim
        # Reject mismatched sample counts here, before they are stored, instead
        # of only tripping the consistency assertion in __len__ later on.
        assert X.shape[0] == y.shape[0], "X and y must hold the same number of samples"

    def append(self, X: np.array, y: np.array):
        """Append the samples (X, y) after the ones already stored.

        Args:
            X: array (k, dim).
            y: array (k,).
        """
        self._check(X, y)

        self._collection_X = np.concatenate([self._collection_X, X])
        self._collection_y = np.concatenate([self._collection_y, y])

    def assign(self, X: np.ndarray, y: np.array):
        """Replace the stored samples with (X, y).

        The arrays are stored as given (no copy is made).

        Args:
            X: array (n, dim).
            y: array (n,).
        """
        logger.debug("X shape = %s", X.shape)
        self._check(X, y)

        self._collection_X = X
        self._collection_y = y

    @property
    def Xh(self) -> np.ndarray:
        """Stored inputs, array (n, dim)."""
        return self._collection_X

    @property
    def Y(self) -> np.ndarray:
        """Stored observations, array (n,)."""
        return self._collection_y

    def __len__(self) -> int:
        """Number of stored samples."""
        assert len(self._collection_X) == len(self._collection_y)
        return self._collection_y.size

    @property
    def dim(self) -> int:
        """Number of columns of the stored inputs."""
        return self._dim


class BayesEQI:
    """GP model with Expected Quantile Improvement.
    """

    def __init__(self, dim: int,
                 xh_low_bounds: np.array,
                 xh_up_bounds: np.array,
                 param_low_bounds: np.array = None,
                 param_up_bounds: np.array = None,
                 alpha: float = 0.5,
                 omega: float = 2.0):
        """
        Args:
            dim: dimension of data, including X and h (h is the last column).
            xh_low_bounds: array-like (dim), low bounds of (X, h).
            xh_up_bounds: array-like (dim), up bounds of (X, h).
            param_low_bounds: array-like (2 * dim), low bounds of the kernel
                parameters, laid out as (theta, ltheta, tau, ltau).
                Defaults to 1e-4 for every parameter.
            param_up_bounds: array-like (2 * dim), up bounds of the kernel
                parameters. Defaults to 1e4 for every parameter.
            alpha: probability level passed to norm.ppf when quantiles are computed.
            omega: exponent applied to min(h1, h2) in the kernel.
        """
        # Accept lists/tuples as well as arrays; ndarrays pass through unchanged.
        xh_low_bounds = np.asarray(xh_low_bounds)
        xh_up_bounds = np.asarray(xh_up_bounds)
        assert xh_low_bounds.size == xh_up_bounds.size == dim

        if param_low_bounds is None:
            param_low_bounds = np.ones(dim * 2) * 1e-4
        if param_up_bounds is None:
            param_up_bounds = np.ones(dim * 2) * 1e4

        self._dim = dim
        self._data = Data(dim)  # store data
        self._xh_low_bounds = xh_low_bounds
        self._xh_up_bounds = xh_up_bounds
        self._param_low_bounds = np.asarray(param_low_bounds)
        self._param_up_bounds = np.asarray(param_up_bounds)
        self._alpha = alpha
        self._omega = omega

        assert self._param_low_bounds.size == self._param_up_bounds.size == 2 * dim

        # Kernel parameters; they stay None until fit() has been called.
        self._theta = None      # dim - 1
        self._ltheta = None     # dim - 1
        self._tau = None        # 1
        self._ltau = None       # 1
        self._beta = None       # dim

        # Values cached by _pre_acquisition() to avoid recomputing them.
        self._t0 = None
        self._t2 = None
        self._min_Qn = None

        self._nll = []
        self._train_kernel = []
        self._stable_noise = 1e-2

    def _kernel(self, Xh1: np.ndarray, Xh2: np.ndarray, params: np.array = None):
        """
        Evaluate the kernel between two sets of points.

        With X the first dim - 1 columns and h the last column of each input:

            k = tau^2 * exp(-sum_k theta_k (x1_k - x2_k)^2)
              + ltau^2 * exp(-sum_k ltheta_k (x1_k - x2_k)^2) * min(h1, h2)^omega

        Args:
            Xh1: array (m, dim).
            Xh2: array (n, dim).
            params: optional array laid out as (theta, ltheta, tau, ltau) with
                theta and ltheta of length dim - 1. When None, the parameters
                stored on the instance are used.
        Returns:
            array (m, n).
        Raises:
            RuntimeError: if params is None and no parameters are stored yet.
        """
        assert Xh1.ndim == 2 and Xh2.ndim == 2
        assert Xh1.shape[1] == Xh2.shape[1] == self._dim

        dim = self._dim - 1

        if params is not None:
            theta = params[0:dim]
            ltheta = params[dim:dim * 2]
            tau = params[dim * 2]
            ltau = params[dim * 2 + 1]
        else:
            theta = self._theta
            ltheta = self._ltheta
            tau = self._tau
            ltau = self._ltau
            if theta is None or ltheta is None or tau is None or ltau is None:
                raise RuntimeError("kernel parameters are not set; call fit() first")

        X1, X2 = Xh1[:, 0:-1], Xh2[:, 0:-1]
        h1, h2 = Xh1[:, -1], Xh2[:, -1]

        def weighted_sq_exp(weights):
            # Weighted squared distance through the expansion
            # |a|^2 + |b|^2 - 2 a.b, evaluated for all (m, n) pairs at once.
            sq_dist = np.sum((X1**2) * weights, axis=1).reshape(-1, 1) \
                + np.sum((X2**2) * weights, axis=1) - 2 * np.dot(X1 * weights, X2.T)
            return np.exp(-sq_dist)

        K1 = weighted_sq_exp(theta)
        K2 = weighted_sq_exp(ltheta)

        # Pairwise minimum of the h columns, (m, n) by broadcasting.
        hmin = np.minimum(h1.reshape(-1, 1), h2.reshape(1, -1))
        return tau ** 2 * K1 + ltau ** 2 * K2 * hmin ** self._omega

    def _kernel_no_noise(self, Xh1: np.ndarray, Xh2: np.ndarray):
        assert Xh1.ndim == 2 and Xh2.ndim == 2
        assert Xh1.shape[1] == Xh2.shape[1] == self._dim

        theta = self._theta
        tau = self._tau

        X1, X2 = Xh1[:, 0:-1], Xh2[:, 0:-1]
        K1 = np.sum((X1**2) * theta[0], axis=1).reshape(-1, 1) + np.sum((X2**2) * theta[0], axis=1) - 2 * np.dot(X1 * theta[0], X2.T)
        K1 = np.exp(-K1)
        return tau ** 2 * K1

        # return self._kernel(Xh1, Xh2)

    def fit(self, Xhn: np.ndarray, Yn: np.array, nb_init_points: int = 50000):
        """
        Fit one Gaussian process model on X and y by simulated annealing.
        1. Store the samples.
        2. Minimize the negative log likelihood over the kernel parameters,
           starting from one point drawn uniformly inside the parameter bounds.
        3. Store the resulting parameters and append their loss to self._nll.

        Args:
            Xhn: array (m x d)
            Yn: array (m)
            nb_init_points: unused; kept so existing callers keep working.
        Raises:
            RuntimeError: if the kernel matrix cannot be Cholesky-factored
                while the loss is evaluated.
        """
        self._data.assign(Xhn, Yn)

        def nll(params: np.array) -> float:
            """
            Negative log likelihood.

            Log likelihood:
                log p(y|X) = log N(y|0,K) = - 0.5 * y^T K^{-1} y - 0.5 * log|K| - N/2 * log(2*pi)
            Negative log likelihood:
                - log p(y|X).
            """
            Kyy = self._kernel(self._data.Xh, self._data.Xh, params) + self._stable_noise * np.eye(len(self._data))
            m = Kyy.shape[0]
            try:
                L = cholesky(Kyy + EPS * np.eye(m))
            except LinAlgError as err:
                raise RuntimeError("error: cholesky") from err
            # With K = L L^T, 0.5 * log|K| is the sum of the logs of diag(L).
            LogDet = np.sum(np.log(np.diagonal(L) + EPS))
            # 0.5 * y^T K^{-1} y, using two solves against the Cholesky factor.
            Sigma = 0.5 * self._data.Y.T.dot(lstsq(L.T, lstsq(L, self._data.Y, rcond=EPS)[0], rcond=EPS)[0])
            return LogDet + Sigma + 0.5 * len(self._data) * np.log(2 * np.pi)

        init_x0 = np.random.uniform(self._param_low_bounds, self._param_up_bounds)
        sa = SimulatedAnnealing(min_f=nll,
                                x0=init_x0,
                                x_low_bounds=self._param_low_bounds,
                                x_up_bounds=self._param_up_bounds)
        params = sa.annealing()
        self._nll.append(nll(params))

        # params layout: (theta, ltheta, tau, ltau), theta/ltheta of length dim - 1.
        dim = self._dim - 1
        self._theta = params[0:dim]
        self._ltheta = params[dim:dim * 2]
        self._tau = params[dim * 2]
        self._ltau = params[dim * 2 + 1]

    def _pre_acquisition(self):
        """
        Cache the quantities shared by every acquisition evaluation.

        Stores on the instance:
            _beta: (dim,) generalized least-squares coefficients.
            _t0: (n, n) pseudo-inverse of the kernel matrix psi.
            _t2: (dim, dim) pseudo-inverse of Xhn^T psi^{-1} Xhn.
            _min_Qn: smallest quantile ln + rn * norm.ppf(alpha) over the stored
                samples, each evaluated with its h column set to 0. Samples
                whose quantile is NaN are ignored; inf if none is usable.
        """
        Xhn = self._data.Xh                     # Xhn: (n, dim)
        Yn = self._data.Y                       # Yn: (n,)
        assert Xhn.ndim == 2 and Yn.ndim == 1
        psi = self._kernel(Xhn, Xhn) + self._stable_noise * np.eye(Xhn.shape[0])    # psi: (n, n)

        t0 = pinv(psi)                          # t0: (n, n)
        t1 = Xhn.T.dot(t0)                      # t1: (dim, n)
        t2 = pinv(t1.dot(Xhn))                  # t2: (dim, dim)
        beta = t2.dot(t1.dot(Yn))               # beta: (dim,)

        self._beta = beta
        self._t0 = t0
        self._t2 = t2
        self._min_Qn = float("inf")

        if Xhn.shape[0] == 0:
            # Nothing to scan; keep _min_Qn at inf.
            return

        # Everything below is identical for every sample, so it is computed
        # once instead of once per loop iteration.
        invpsi_resid = invmat_vec(psi, Yn - Xhn.dot(beta))     # (n,)
        invpsi_Xhn = invmat_vec(psi, Xhn)                      # (n, dim)
        gram = Xhn.T.dot(invpsi_Xhn)                           # (dim, dim)
        z_alpha = norm.ppf(self._alpha)

        for index, row in enumerate(Xhn):
            Xn = row.reshape(1, -1).copy()              # Xn: (1, dim)
            phi = self._kernel_no_noise(Xhn, Xn)        # phi: (n, 1)
            Xn[:, -1] = 0                               # evaluate at h = 0

            ln = Xn.dot(beta) + phi.T.dot(invpsi_resid)         # ln: (1,)

            diff = Xn - phi.T.dot(invpsi_Xhn)                   # diff: (1, dim)
            # phi[:, 0] stays 1-D even when n == 1 (squeeze() would give a
            # 0-d array there, which invmat_vec rejects).
            rn2 = self._tau ** 2 - phi.T.dot(invmat_vec(psi, phi[:, 0])) \
                + diff.dot(invmat_vec(gram, diff.T))            # rn2: (1, 1)
            rn = np.sqrt(np.diag(rn2))                          # rn: (1,)

            Qn = ln + rn * z_alpha                              # Qn: (1,)
            if np.isnan(Qn).any():
                logger.warning("sample %d: quantile is nan (ln=%s, rn2=%s); ignoring it",
                               index, ln, rn2)

            # A NaN quantile becomes +inf so that it can never be the minimum.
            Qn = np.nan_to_num(Qn, nan=np.inf)
            if self._min_Qn > Qn[0]:
                self._min_Qn = Qn[0]

    def _acquisition(self, xh: np.array) -> tuple:
        """Calculate the acquisition value of one candidate point.

        Relies on self._beta, self._t0, self._t2 and self._min_Qn, which
        _pre_acquisition() sets.

        Args:
            xh: array (1, dim), the candidate (x_{n+1}, h_{n+1}); it is indexed
                as xh[0, -1] and concatenated to the stored samples along axis 0.
        Returns:
            (acq, Lambda): the acquisition value and the (1, n+1) weight row
            Lambda computed below.
        """
        Xhn = self._data.Xh                     # Xhn: (n, dim)
        Yn = self._data.Y                       # Yn: (n,)

        # Compute l_{n+1}, r_{n+1}

        gammap = self._kernel(Xhn, xh)   # \gamma_n(x_{n+1}, h_{n+1}), gammap: (n, 1)
        mn = xh.dot(self._beta) + gammap.T.dot(self._t0).dot(Yn - Xhn.dot(self._beta))    # mn: (1,)

        # Augment the data with the candidate, using mn as its observation.
        Xhnp = np.concatenate((Xhn, xh), axis=0)    # Xhnp: (n+1, dim)
        Ynp = np.concatenate((Yn, mn), axis=0)      # Ynp: (n+1,)
        psi_np = self._kernel(Xhnp, Xhnp) + self._stable_noise * np.eye(Xhnp.shape[0])           # psi_np: (n+1, n+1)
        phi_np = self._kernel_no_noise(Xhnp, xh)    # phi_np: (n+1, 1)

        # beta_np, i.e. beta_{n+1}
        t5 = pinv(psi_np)                       # t5: (n+1, n+1)
        # t6 = Xhnp.T.dot(t5)                     # t6: (dim, n+1)
        # t7 = pinv(t6.dot(Xhnp))                 # t7: (dim, dim)   (F_{n+1}^T Psi^{-1} F_{n+1})^{-1}
        # beta_np = t7.dot(t6.dot(Ynp))           # beta_np: (dim,)
        t7 = pinv(Xhnp.T.dot(invmat_vec(psi_np, Xhnp)))

        # l_np, i.e. l_{n+1}
        X_np = xh.copy()
        X_np[0, -1] = 0                         # X_np: (1, dim), candidate with h set to 0
        t8 = phi_np.T.dot(t5)                   # t8: (1, n+1)

        Lambda = ((X_np - t8.dot(Xhnp)).dot(t7).dot(Xhnp.T) + phi_np.T).dot(t5)  # Lambda: (1, n+1)
        l_np = Lambda.dot(Ynp)                  # l_np: (1,)
        # l_np = X_np.dot(beta_np) + t8.dot(Ynp - Xhnp.dot(beta_np))  # l_np: (1,)

        # rnp2, i.e. r_{n+1}^2
        t9 = X_np - t8.dot(Xhnp)                # t9: (1, dim); only used in the diagnostics below
        # rnp2 = self._tau**2 - t8.dot(phi_np) + t9.dot(t7).dot(t9.T)  # rnp2: (1,1)

        # NOTE(review): compared with the commented-out expression above, this
        # appears to omit the `- t8.dot(phi_np)` term and the quadratic term in
        # t8.dot(Xhnp) -- confirm that this is intentional.
        rnp2 = self._tau**2 + X_np.dot(t7).dot(X_np.T) - 2 * X_np.dot(t7).dot(t8.dot(Xhnp).T)
        rnp = np.sqrt(np.diag(rnp2))            # rnp: (1,)

        if np.isnan(rnp):
            print(f"-------rnp is nan! rnp2={rnp2}, term1:{t8.dot(phi_np)}, term2:{t9.dot(t7).dot(t9.T)}")
            print(f"-------rnp is nan! rnp2={rnp2}, term1:{X_np.dot(t7).dot(X_np.T)}, term2:{2 * X_np.dot(t7).dot(t8.dot(Xhnp).T)}")

        # snp2, i.e. s_n(x_{n+1}, h_{n+1})^2
        hnp = xh[0, -1]                         # hnp: ()
        t10 = xh - gammap.T.dot(self._t0).dot(Xhn)    # t10: (1, dim)
        snp2 = self._tau**2 + self._ltau**2 * hnp**self._omega - gammap.T.dot(self._t0).dot(gammap) \
            + t10.dot(self._t2).dot(t10.T)            # snp2: (1,1)
        # NOTE(review): a non-positive snp2 aborts the call with AssertionError
        # (the check disappears under `python -O`).
        assert snp2 > 0, "snp2 is less than zero."
        snp = np.sqrt(np.diag(snp2))            # snp: (1,)

        # Qnp
        Qnp = l_np[0] + rnp[0] * norm.ppf(self._alpha)  # Qnp: ()
        if np.isnan(Qnp):
            print(f"------Qnp is nan! l_np:{l_np[0]}, rnp:{rnp[0]}")

        # acquisition
        # Lambda[0, -1] is the weight Lambda gives to the candidate's own entry.
        t11 = (self._min_Qn - Qnp) / (Lambda[0, -1] * snp[0])   # t11: ()
        if np.isnan(t11):
            print(f"--------------------t11 is nan.! snp:{snp[0]}, lambda:{Lambda[0, -1]}, min_Qn:{self._min_Qn}")
        acq = (self._min_Qn - Qnp) * norm.cdf(t11) + Lambda[0, -1] * snp[0] * norm.pdf(t11)
        if np.isnan(acq):
            print(f"t11:{t11}---------------- acq is nan!")
        return acq, Lambda

    def propose_point(self, nb_init_points: int = 100) -> np.ndarray:
        """
        Propose a point with the maximum expected quantile improvement.

        nb_init_points candidates are drawn at random inside the (X, h) bounds
        and the one with the largest acquisition value is returned.

        Args:
            nb_init_points: number of random candidates to evaluate.
        Returns:
            proposed_point: one point with shape (1, dim).
        Raises:
            RuntimeError: if no candidate has a usable acquisition value, i.e.
                there are no candidates or every value is NaN or -inf.
        """

        self._pre_acquisition()

        min_val = float('inf')
        best_point = None
        initial_points = random_with_bound(self._xh_low_bounds, self._xh_up_bounds, nb_init_points)

        for point in initial_points:
            acq, _ = self._acquisition(point.reshape(1, -1))
            # Minimize the negative acquisition. A NaN value never compares
            # smaller than min_val, so such candidates are skipped.
            obj = -acq
            if obj < min_val:
                min_val = obj
                best_point = point

        if best_point is None:
            raise RuntimeError(
                f"no candidate with a usable acquisition value among {len(initial_points)} sampled points"
            )

        proposed_point = best_point.reshape(1, -1)
        return proposed_point

    def predict(self, xh: list, nb: int = None) -> np.ndarray:
        """Calculate ln(x) at a single point.

        Args:
            xh: sequence of length dim, one point (X, h). A leading axis is
                added internally, so a 2-D input is rejected by the assertion.
            nb: int, use only the first nb stored samples to predict; all
                stored samples when None.
        Returns:
            ln: array (1,), the predicted value at xh.
        """
        xh = np.array(xh)[np.newaxis, :]
        assert xh.ndim == 2          # xh: (1, dim)
        if nb is None:
            Xhn = self._data.Xh         # Xhn: (n, dim)
            Yn = self._data.Y           # Yn: (n,)
        else:
            Xhn = self._data.Xh[:nb, :]
            Yn = self._data.Y[:nb]
        assert Xhn.ndim == 2 and Yn.ndim == 1

        Ky = self._kernel(Xhn, Xhn) + self._stable_noise * np.eye(Xhn.shape[0])           # Ky: (n, n)
        Kf = self._kernel_no_noise(Xhn, xh)    # Kf: (n, 1)

        # beta: generalized least-squares coefficients of the linear trend.
        t0 = pinv(Ky)                           # t0: (n, n)
        t1 = Xhn.T.dot(t0)                      # t1: (dim, n)
        t2 = pinv(t1.dot(Xhn) + EPS * np.eye(self._dim))                  # t2: (dim, dim)
        beta = t2.dot(t1.dot(Yn))               # beta: (dim,)
        assert beta.ndim == 1

        # ln: linear trend plus the kernel-weighted residuals.
        ln = xh.dot(beta) + Kf.T.dot(t0.dot(Yn - Xhn.dot(beta)))
        return ln

    @property
    def params(self):
        """Kernel parameters as one flat array: (theta, ltheta, tau, ltau)."""
        # The two scalars are wrapped in lists so they concatenate as 1-D pieces.
        pieces = (self._theta, self._ltheta, [self._tau], [self._ltau])
        return np.concatenate(pieces, axis=0)
