import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from scipy.stats import norm

# Seed NumPy's global RNG (the one multi_gauss and GMM.sample draw from) so runs are repeatable.
np.random.seed(42)


def multi_gauss(num, mu, sigma: np.ndarray):
    """Draw ``num`` samples from the multivariate normal ``N(mu, sigma)``.

    Samples are generated as ``L @ eps + mu`` where ``L`` is the lower
    Cholesky factor of ``sigma`` and ``eps`` is standard normal noise taken
    from NumPy's global random state.

    Args:
      num: Number of samples to draw.
      mu: Mean, broadcastable against an array of shape ``(num, d)``.
      sigma: ``(d, d)`` covariance matrix; it must admit a Cholesky
        factorisation.

    Returns:
      Array of shape ``(num, d)`` with one sample per row.

    Raises:
      numpy.linalg.LinAlgError: If ``np.linalg.cholesky`` rejects ``sigma``.
    """
    chol = np.linalg.cholesky(sigma)
    # Size the noise from the covariance instead of assuming 2 dimensions,
    # so the sampler works for any d (the d == 2 random stream is unchanged).
    noise = np.random.randn(chol.shape[0], num)
    return (chol @ noise).T + mu


class GMM:
    """Mixture of multivariate Gaussians supporting sampling and density evaluation.

    The components are described by three parallel lists: the mixing weights,
    the component means and the component covariance matrices.
    """

    def __init__(self, weights=None, mu_list=None, sigma_list=None):
        # ``None`` stands for "no components": each instance gets its own list.
        self.weights = weights if weights is not None else []
        self.mu_list = mu_list if mu_list is not None else []
        self.sigma_list = sigma_list if sigma_list is not None else []

    def sample(self, num):
        """Draw ``num`` points from the mixture.

        The number of points per component is drawn from a multinomial over
        the mixing weights; each component is then sampled with
        ``multi_gauss`` and the draws are concatenated component by component.
        """
        counts = np.random.multinomial(num, self.weights, size=1)[0]
        draws = [
            multi_gauss(count, mean, cov)
            for count, mean, cov in zip(counts, self.mu_list, self.sigma_list)
        ]
        return np.concatenate(draws)

    def pdf(self, xx, dim):
        """Evaluate the mixture density.

        Args:
          xx: A single point, or a batch with one point per row.
          dim: Dimensionality used in the Gaussian normalising constant.

        Returns:
          The weighted sum of the component densities (``0.`` when the
          mixture has no components).
        """
        density = 0.
        for weight, mean, cov in zip(self.weights, self.mu_list, self.sigma_list):
            # Work with points as columns so one solve handles a whole batch.
            centered_t = (xx - mean).T
            mahalanobis_sq = np.sum(centered_t * np.linalg.solve(cov, centered_t), axis=0)
            normaliser = np.sqrt(2 * np.pi) ** dim * np.sqrt(np.linalg.det(cov))
            density += weight * 1 / normaliser * np.exp(-mahalanobis_sq / 2)
        return density


def compute_ppca_mle(covariance, z_dim):
    """Closed-form maximum-likelihood fit of probabilistic PCA (pPCA).

    Args:
      covariance: The data covariance matrix (decomposed with ``eigh``).
      z_dim: The number of hidden dimensions to keep.

    Returns:
      (sigma_sq_mle, W_mle): The noise variance, i.e. the mean of the
      discarded eigenvalues, and the loading matrix, i.e. the leading
      ``z_dim`` eigenvectors scaled by ``sqrt(eigenvalue - sigma_sq_mle)``.
    """
    ascending_vals, ascending_vecs = np.linalg.eigh(covariance)
    # eigh returns eigenvalues in ascending order; flip to largest-first.
    spectrum = ascending_vals[::-1]
    basis = ascending_vecs[:, ::-1]

    n_discarded = spectrum.shape[0] - z_dim
    sigma_sq_mle = spectrum[z_dim:].sum() / n_discarded

    retained = np.diag(spectrum[:z_dim])
    column_scale = (retained - sigma_sq_mle * np.eye(z_dim)) ** 0.5
    W_mle = basis[:, :z_dim].dot(column_scale)
    return sigma_sq_mle, W_mle


def _mesh_points(grid):
    """Stack the nodes of a 2-D ``np.meshgrid`` result into an ``(n_nodes, 2)`` array."""
    return np.stack([grid[0].ravel(), grid[1].ravel()], axis=1)


def _pdf_on_mesh(mixture, grid):
    """Evaluate ``mixture.pdf`` at every node of a 2-D meshgrid with one batched call.

    The result has the shape of ``grid[0]``: entry ``[r, c]`` is the density at
    ``(grid[0][r, c], grid[1][r, c])``, which is the layout both
    ``plot_surface(grid[0], grid[1], ...)`` and ``contourf(x, y, ...)`` expect.
    """
    values = mixture.pdf(_mesh_points(grid), dim=2)
    return np.asarray(values).reshape(grid[0].shape)


def _pdf_on_line(mixture, z_axis):
    """Evaluate a 1-D mixture density at every entry of ``z_axis`` (shape ``(len(z_axis),)``)."""
    values = mixture.pdf(z_axis[:, None], dim=1)
    return np.asarray(values).reshape(-1)


def _kl_on_grid(q_vals, p_vals, step):
    """Riemann-sum estimate of ``KL[q || p]`` from densities sampled on a uniform grid.

    ``KL = \\int q(z) log(q(z) / p(z)) dz ~= step * sum_k q_k log(q_k / p_k)``.
    Nodes where ``q`` is exactly zero contribute nothing (``0 * log 0 := 0``).
    """
    support = q_vals > 0
    q_s, p_s = q_vals[support], p_vals[support]
    return float(np.sum(q_s * np.log(q_s / p_s)) * step)


def _elbo_on_mesh(grid, E_mle, A, C, sigma_sq_mle):
    """Closed-form linear-VAE ELBO at every node of a 2-D meshgrid.

    With encoder ``q(z|x) = N(Ax, C)``, decoder ``p(x|z) = N(Ez, sigma^2 I)`` and
    prior ``p(z) = N(0, I)``:
    ``ELBO(x) = E_q[log p(x|z)] - KL[q(z|x) || p(z)]``.
    """
    X = _mesh_points(grid).T  # (2, n_nodes): one column per grid node
    d, q = X.shape[0], A.shape[0]
    latent_mean = A.dot(X)
    residual = X - E_mle.dot(latent_mean)
    # E_q ||x - E z||^2 = ||x - E A x||^2 + tr(E C E^T)
    expected_sq_err = np.sum(residual ** 2, axis=0) + np.trace(E_mle.dot(C).dot(E_mle.T))
    expected_log_lik = (-expected_sq_err / (2 * sigma_sq_mle)
                        - (d / 2) * np.log(2 * np.pi * sigma_sq_mle))
    kl_to_prior = 0.5 * (-np.log(np.linalg.det(C)) + np.sum(latent_mean ** 2, axis=0)
                         + np.trace(C) - q)
    return (expected_log_lik - kl_to_prior).reshape(grid[0].shape)


def _save_current_figure(stem, formats=('pdf', 'png')):
    """Save the current pyplot figure as ``<stem>.<ext>`` for every requested format."""
    for ext in formats:
        plt.savefig(f'{stem}.{ext}')


def _show_and_release():
    """Show the pending figure(s), then close them.

    Closing guarantees that the next pyplot call starts on a fresh figure even
    when ``plt.show()`` returns without discarding anything (as happens on
    non-interactive backends).
    """
    plt.show()
    plt.close('all')


def _style_density_axes():
    """Apply the axis labels, legend and light grid shared by the 1-D density plots."""
    plt.xlabel(r'z', fontsize=15)
    plt.ylabel('Probability', fontsize=15)
    plt.legend()
    plt.grid(True)
    plt.grid(color='gray',
             linestyle='-',
             linewidth=1,
             alpha=0.3)


def _surface_figure(grid, values, cmap, zlabel, stem):
    """Draw ``values`` over ``grid`` as a 3-D surface, save it as pdf+png and show it."""
    plt.figure(figsize=(7, 6))
    ax = plt.axes(projection='3d')
    ax.plot_surface(grid[0], grid[1], values, cmap=cmap)
    ax.set_xlabel(r'$x_1$', fontsize=9)
    ax.set_ylabel(r'$x_2$', fontsize=9)
    ax.set_zlabel(zlabel)
    plt.tight_layout()
    plt.xticks(fontsize=9)
    plt.yticks(fontsize=9)
    _save_current_figure(stem)
    _show_and_release()


def _contour_figure(x_axis, y_axis, values, stem, title):
    """Draw a filled contour plot, save it as pdf+png, then title and show it.

    The title is added after saving, so it appears on screen but not in the files.
    """
    cset = plt.contourf(x_axis, y_axis, values, 20)
    plt.colorbar(cset)
    _save_current_figure(stem)
    plt.title(title)
    _show_and_release()


if __name__ == "__main__":
    # ======================================================== #
    # ------------------- Hyper-parameters ------------------- #
    px_type = 'multi-mode'  # 'multi-mode' or 'single-mode'
    plot_px = True
    plot_qz_pz = True
    plot_px_from_qz_pz = True
    # ======================================================= #

    # NOTE(review): `device` and `kl_weights` are not referenced anywhere in the
    # visible part of this script.
    device = 'cuda:0'
    kl_weights = [1.]

    # Ground-truth data distribution p(x): a Gaussian mixture with identity covariances.
    if px_type == 'multi-mode':
        mu_i = 3.
        weights = [0.5, 0.5]
        mode_means = [np.array([-mu_i, -mu_i]), np.array([mu_i, mu_i])]
        px_half_width, est_half_width = 6, 8
    elif px_type == 'single-mode':
        weights = [1.0]
        mode_means = [np.array([0., 0.])]
        px_half_width, est_half_width = 4, 4
    else:
        raise ValueError(f"px_type must be 'multi-mode' or 'single-mode', got {px_type!r}")
    mode_covs = [np.eye(2) for _ in mode_means]
    gmm = GMM(weights=weights, mu_list=mode_means, sigma_list=mode_covs)

    # Every figure is written below ./figs; create it up front so the first
    # savefig cannot fail after the expensive work has been done.
    os.makedirs('./figs', exist_ok=True)

    # ============= data prepare and visualization =========================== #
    train_x = gmm.sample(10000)
    np.save(f'./{px_type}_trainset.npy', train_x)

    test_x = gmm.sample(1000)
    np.save(f'./{px_type}_testset.npy', test_x)

    if plot_px:
        px_axis = np.arange(-px_half_width, px_half_width, 0.02)
        grid = np.meshgrid(px_axis, px_axis)
        _surface_figure(grid, _pdf_on_mesh(gmm, grid), cm.viridis,
                        r'$p(x)$', f'./figs/{px_type}_px_3d')

    # ============================== compute analytic solutions ================================================== #
    # Linear VAE with a 1-D latent: decoder weights E and noise sigma^2 from the
    # pPCA MLE; the matching encoder is q(z|x) = N(A x, C).
    covariance = np.cov(train_x, rowvar=False)
    sigma_sq_mle, E_mle = compute_ppca_mle(covariance, 1)
    m = np.eye(1) + E_mle.T.dot(E_mle) / sigma_sq_mle
    A = np.linalg.inv(m).dot(E_mle.T) / sigma_sq_mle
    C = np.linalg.inv(m.T)

    # Aggregate posterior, derived from the mixture's own parameters:
    # q(z) = \int q(z|x)p(x) dx = \sum_i \pi_i N(A \mu_i, A \Sigma_i A^T + C)
    latent_means = [A.dot(mu.reshape(-1, 1)) for mu in mode_means]
    latent_covs = [A.dot(cov).dot(A.T) + C for cov in mode_covs]

    if plot_qz_pz:
        qz = GMM(weights=weights, mu_list=latent_means, sigma_list=latent_covs)

        z_step = 0.01
        z_axis = np.arange(-3, 3, z_step)
        qzs = _pdf_on_line(qz, z_axis)
        pzs = norm.pdf(z_axis)

        if px_type == 'multi-mode':
            # simulated OOD data for multi-modal case: p_ood (x) = N([0, 0], ood_Sigma_xi)
            ood_Sigma_xi = np.array([[1.0, 0.0], [0.0, 1.0]])
            ood_Sigma_zi = A.dot(ood_Sigma_xi).dot(A.T) + C
            ood_qz = GMM(weights=[1.0],
                         mu_list=[0.],
                         sigma_list=[ood_Sigma_zi])
            ood_qzs = _pdf_on_line(ood_qz, z_axis)

            # KL divergences as Riemann sums over z_axis. The sum must be scaled
            # by the grid step; without it the values are inflated by a factor
            # of about len(z_axis).
            kl_qz_pz_id = _kl_on_grid(qzs, pzs, z_step)
            kl_qz_pz_ood = _kl_on_grid(ood_qzs, pzs, z_step)
            kl_q_ood_q_id = _kl_on_grid(ood_qzs, qzs, z_step)
            print(f'KL[q_id(z)||p(z)]={kl_qz_pz_id}  KL[q_ood(z)||p(z)]={kl_qz_pz_ood}\n'
                  f'==> Gap_1 = KL[q_ood(z)||p(z)] - KL[q_id(z)||p(z)] = {kl_qz_pz_ood-kl_qz_pz_id}\n\n'
                  f'KL[q_id(z)||q_id(z)]=0  KL[q_ood(z)||q_id(z)]={kl_q_ood_q_id}\n'
                  f'==> Gap_2 = KL[q_ood(z)||q_id(z)] - KL[q_id(z)||q_id(z)] = {kl_q_ood_q_id}\n\n'
                  f'As shown in the results, Gap_2 is much larger than Gap_1, which is the insight of PHP method.')

        # --------- q(z) alone (title is added after saving: screen only) --------- #
        plt.plot(z_axis, qzs, color='orange', label=r'q(z)')
        _style_density_axes()
        plt.tight_layout()
        _save_current_figure(f'./figs/Analytical_qz_{px_type}')
        # Raw f-string: '\i' is not a valid escape sequence in a normal string.
        plt.title(
            rf'{px_type}: Analytical $q(z)=\int_z q(z|x)p(x)$ by Linear VAE')
        _show_and_release()

        # --------- p(z) alone --------- #
        plt.plot(z_axis, pzs, color='green', label=r'p(z)')
        _style_density_axes()
        plt.tight_layout()
        _save_current_figure('./figs/Pz')
        plt.title('p(z) = N(0,I)')
        _show_and_release()

        # --------- q(z) against p(z) --------- #
        plt.plot(z_axis, qzs, color='orange', label='q(z) of ID')
        plt.plot(z_axis, pzs, color='green', label=r'p(z)')
        _style_density_axes()
        _save_current_figure(f'./figs/qz_pz_{px_type}')
        _show_and_release()

        # A sampling-based cross-check of q(z) is a normalised histogram of
        # A.dot(x) taken over the rows of train_x.

    # ===================== estimated p(x) from q(z) and p(z) ===================================================== #
    # ---------------- distribution of estimated p(x)=\int_z p(x|z)p(z) ------------------- #
    # Notes: \int p(x|z)p(z) dz = N(0, E.dot(E.T)+sigma_sq*I)
    Sigma_M = E_mle.dot(E_mle.T) + sigma_sq_mle * np.eye(2)
    px_from_pz = GMM(weights=[1.0],
                     mu_list=[np.array([0., 0.])],
                     sigma_list=[Sigma_M])

    # ---------------- distribution of estimated p(x)=\int_z p(x|z)q(z) ------------------ #
    # Pushing each component of q(z) through the decoder gives
    # N(E mu_z, E Sigma_z E^T + sigma_sq*I).
    decoded_means = [E_mle.dot(mu_z).reshape(-1) for mu_z in latent_means]
    decoded_covs = [E_mle.dot(cov_z).dot(E_mle.T) + sigma_sq_mle * np.eye(2)
                    for cov_z in latent_covs]
    px_from_qz = GMM(weights=weights, mu_list=decoded_means, sigma_list=decoded_covs)

    if plot_px_from_qz_pz:
        x_axis = np.arange(-est_half_width, est_half_width, 0.02)
        y_axis = np.arange(-est_half_width, est_half_width, 0.02)
        grid = np.meshgrid(x_axis, y_axis)

        elbos = _elbo_on_mesh(grid, E_mle, A, C, sigma_sq_mle)
        px_from_pz_s = _pdf_on_mesh(px_from_pz, grid)
        px_from_qz_s = _pdf_on_mesh(px_from_qz, grid)

        # ELBO: titled before saving (so the title is in the file) and saved as pdf only.
        cset = plt.contourf(x_axis, y_axis, elbos, 20)
        plt.colorbar(cset)
        plt.title(
            f'{px_type}: Analytical ELBO by Linear VAE')
        _save_current_figure(f'./figs/Analytical_ELBO_{px_type}', formats=('pdf',))
        _show_and_release()

        _contour_figure(
            x_axis, y_axis, px_from_pz_s,
            f'./figs/Analytical_px_from_pz_{px_type}',
            rf'{px_type}: Analytical (from p(z)) estimated p(x)=$\int_z q(x|z)p(z)$ by Linear VAE')
        _surface_figure(grid, px_from_pz_s, cm.coolwarm,
                        r'Estimated $p(x)$ with $p(z)$',
                        f'./figs/Analytical_px_from_pz_{px_type}_3d')

        _contour_figure(
            x_axis, y_axis, px_from_qz_s,
            f'./figs/Analytical_px_from_qz_{px_type}',
            rf'{px_type}: Analytical (from q(z)) estimated p(x)=$\int_z q(x|z)q(z)$ by Linear VAE')
        _surface_figure(grid, px_from_qz_s, cm.coolwarm,
                        r'Estimated $p(x)$ with $q(z)$',
                        f'./figs/Analytical_px_from_qz_{px_type}_3d')



