import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.distributions.normal import Normal
from torch.utils.data import Dataset
from matplotlib import cm

# Seed NumPy's global RNG so the GMM train/test samples drawn below are reproducible.
# NOTE(review): torch is not seeded, so weight init and z draws still vary between runs.
np.random.seed(seed=42)

class VanillaVAE(nn.Module):
    """Fully-connected VAE with a Gaussian encoder q(z|x) and a Gaussian decoder p(x|z).

    The encoder and the decoder trunk are each ``Linear -> LeakyReLU -> Linear -> LeakyReLU``
    MLPs of width ``hidden_dims``. Each is followed by two linear heads: a mean and a
    log-variance.

    Args:
        in_channels: dimensionality of the input x.
        latent_dim: dimensionality of the latent z.
        hidden_dims: width of every hidden layer (a single int, despite the plural name).
        fix_encode_logvar: if True, ``encode`` discards the ``fc_var`` output and returns a
            constant log-variance of -9.21 (variance ~= 1e-4) for q(z|x).
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 in_channels,
                 latent_dim,
                 hidden_dims,
                 fix_encode_logvar,
                 **kwargs) -> None:
        super().__init__()

        self.latent_dim = latent_dim

        # Encoder trunk and the mean / log-variance heads of q(z|x).
        # (Modules are created in the original order so parameter init consumes the
        # RNG identically and saved state_dicts keep loading.)
        self.encoder = nn.Sequential(nn.Linear(in_channels, hidden_dims), nn.LeakyReLU(),
                                     nn.Linear(hidden_dims, hidden_dims), nn.LeakyReLU())
        self.fc_mu = nn.Linear(hidden_dims, latent_dim)
        self.fc_var = nn.Linear(hidden_dims, latent_dim)

        # Decoder trunk and the mean / log-variance heads of p(x|z).
        # NOTE: despite its name, ``decoder_output_logstd`` is used as a log-VARIANCE:
        # ``loss_function`` turns it into a std via exp(0.5 * output). The attribute
        # name is kept for state_dict compatibility.
        self.decoder_input = nn.Sequential(nn.Linear(latent_dim, hidden_dims), nn.LeakyReLU(),
                                           nn.Linear(hidden_dims, hidden_dims), nn.LeakyReLU())
        self.decoder_output_mu = nn.Linear(hidden_dims, in_channels)
        self.decoder_output_logstd = nn.Linear(hidden_dims, in_channels)
        self.fix_encode_logvar = fix_encode_logvar

    def encode(self, x):
        """Return ``[mu, log_var]`` of q(z|x) for the batch ``x``.

        With ``fix_encode_logvar`` set, ``log_var`` is a constant -9.21 tensor (no gradient
        reaches ``fc_var``).
        """
        hidden = self.encoder(x)
        mu = self.fc_mu(hidden)
        log_var = self.fc_var(hidden)
        if self.fix_encode_logvar:
            # -9.21 ~= log(1e-4): an almost deterministic encoder.
            log_var = torch.full_like(log_var, -9.21)
        return [mu, log_var]

    def decode(self, z):
        """Return ``[mean, log_var]`` of p(x|z) for the latent batch ``z``."""
        hidden = self.decoder_input(z)
        return [self.decoder_output_mu(hidden), self.decoder_output_logstd(hidden)]

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input):
        """Encode, sample one z per example, and decode.

        Returns:
            ``[recon, input, mu, log_var, z]`` where ``recon`` is ``[mean, log_var]`` of p(x|z).
        """
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        recon = self.decode(z)
        return [recon, input, mu, log_var, z]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """Compute the weighted negative-ELBO style training loss.

        Positional args (as returned by ``forward``): ``recons`` (``[mean, log_var]`` of
        p(x|z)), ``input``, ``mu`` and ``log_var`` of q(z|x). Extra positionals are ignored.

        Keyword args:
            M_N: weight applied to the KL term.
            loss_fn: ``'mse'`` (mean squared error on the decoder mean) or ``'nll'``
                (mean Gaussian negative log-likelihood using the decoder log-variance).

        Returns:
            dict with ``'loss'``, ``'Reconstruction_Loss'`` (detached) and ``'KLD'``.
            ``'KLD'`` holds the *negated* KL divergence, detached.

        Raises:
            KeyError: if ``M_N`` or ``loss_fn`` is missing.
            ValueError: if ``loss_fn`` is neither ``'mse'`` nor ``'nll'``.
        """
        recons = args[0]
        target = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        loss_fn = kwargs["loss_fn"]

        # Analytic KL(q(z|x) || N(0, I)), summed over latent dims, averaged over the batch.
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        if loss_fn == 'mse':
            recons_loss = F.mse_loss(recons[0], target)
        elif loss_fn == 'nll':
            recons_loss = -Normal(recons[0], torch.exp(0.5 * recons[1])).log_prob(target).mean()
        else:
            # Previously fell through and returned None, which only failed later at the caller.
            raise ValueError(f"loss_fn must be 'mse' or 'nll', got {loss_fn!r}")

        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss': recons_loss.detach(), 'KLD': -kld_loss.detach()}

    def sample(self,
               num_samples,
               current_device, **kwargs):
        """Decode ``num_samples`` draws from the prior N(0, I); returns ``[mean, log_var]``.

        z is drawn on the CPU generator and then moved to ``current_device``.
        """
        z = torch.randn(num_samples, self.latent_dim)
        z = z.to(current_device)
        return self.decode(z)

    def generate(self, x, **kwargs):
        """Return the reconstruction ``[mean, log_var]`` of ``x`` (one stochastic pass)."""
        return self.forward(x)[0]


def multi_gauss(num, mu, sigma: np.ndarray):
    """Draw ``num`` samples from the multivariate normal N(mu, sigma).

    Standard-normal draws from NumPy's global RNG are coloured by the Cholesky factor
    L of ``sigma`` (sigma = L @ L.T) and shifted by ``mu``.

    Args:
        num: number of samples to draw.
        mu: mean vector of length d (broadcast onto every sample).
        sigma: (d, d) symmetric positive-definite covariance matrix.

    Returns:
        Array of shape (num, d).

    Raises:
        numpy.linalg.LinAlgError: if ``sigma`` is not positive definite.
    """
    chol = np.linalg.cholesky(sigma)
    # The sample dimension follows the covariance instead of being fixed at 2; for
    # 2-D inputs this draws exactly the same randn(2, num) as before.
    dim = chol.shape[0]
    return (chol @ np.random.randn(dim, num)).T + mu


class GMM:
    """Gaussian mixture model defined by explicit weights, means and covariances.

    Args:
        weights: mixture weights, used as the ``pvals`` of ``np.random.multinomial``.
        mu_list: per-component mean vectors.
        sigma_list: per-component covariance matrices.
    """

    def __init__(self, weights=None, mu_list=None, sigma_list=None):
        self.weights = [] if weights is None else weights
        self.mu_list = [] if mu_list is None else mu_list
        self.sigma_list = [] if sigma_list is None else sigma_list

    def sample(self, num):
        """Draw ``num`` points from the mixture using NumPy's global RNG.

        Per-component counts come from a multinomial over ``weights``; each component's
        points are then drawn with ``multi_gauss``. The result is grouped by component
        (not shuffled).
        """
        counts = np.random.multinomial(num, self.weights, size=1)[0]
        points = [multi_gauss(count, mu, sigma)
                  for count, mu, sigma in zip(counts, self.mu_list, self.sigma_list)]
        return np.concatenate(points)

    def pdf(self, xx, dim=None):
        """Evaluate the mixture density at ``xx``.

        Args:
            xx: a single point (length-d sequence or array) or an (N, d) batch of points.
            dim: dimensionality used in the Gaussian normalising constant. Defaults to
                ``len(mu)`` of each component; passing it explicitly is still supported.

        Returns:
            A scalar for a single point, or a length-N array for a batch.
        """
        total = 0.
        for weight, mu, sigma in zip(self.weights, self.mu_list, self.sigma_list):
            d = len(mu) if dim is None else dim
            # Columns of ``diff`` are the centred points, so ``solve`` handles a whole batch.
            diff = (np.asarray(xx) - mu).T
            norm = np.sqrt(2 * np.pi) ** d * np.sqrt(np.linalg.det(sigma))
            mahalanobis = np.sum(diff * np.linalg.solve(sigma, diff), axis=0)
            total += weight / norm * np.exp(-mahalanobis / 2)
        return total


class gmm_dataset(Dataset):
    """Points loaded from a ``.npy`` file, each paired with its density under ``gmm``.

    Args:
        root: path passed to ``np.load``; each ``self.set[i]`` is one data point.
        gmm: object exposing ``pdf(xx, dim)`` (e.g. ``GMM``), evaluated with ``dim=2``.

    Items are ``(point as float32 tensor, density)``.
    """

    def __init__(self, root, gmm):
        self.set = np.load(root)
        self.gmm = gmm
        # A point's density never changes, so compute it once here rather than on every
        # __getitem__ (a training loop revisits every point each epoch). Same per-point
        # call as before, so the cached values are identical.
        self.pdfs = np.array([gmm.pdf(point, 2) for point in self.set], dtype=np.float64)

    def __getitem__(self, index):
        return torch.tensor(self.set[index], dtype=torch.float), self.pdfs[index]

    def __len__(self):
        return len(self.set)


def single_sample_elbo(result, x, loss_fn):
    """Return a one-sample, per-example ELBO estimate from a ``VanillaVAE.forward`` output.

    Args:
        result: the list ``[[recon_mu, recon_log_var], input, z_mu, z_log_var, z]``
            returned by ``VanillaVAE.forward``.
        x: the points the reconstruction is scored against (same shape as ``recon_mu``).
        loss_fn: ``'mse'`` replaces the decoder log-variance with zeros (unit-variance
            Gaussian, mirroring the MSE training loss). Any other value uses the decoder's
            log-variance head.

    Returns:
        Tensor with one value per row of ``x``.

    NOTE(review): the reconstruction term *averages* log p(x|z) over the data dimensions,
    matching the training loss; a textbook ELBO sums over them. Confirm this is intended
    before reading ``exp(ELBO)`` as a lower bound on p(x).
    """
    recon_mu, recon_log_var = result[0][0], result[0][1]
    if loss_fn == 'mse':
        recon_log_var = torch.zeros_like(recon_log_var)
    z_mu, z_log_var = result[2], result[3]
    nll = -torch.mean(Normal(recon_mu, torch.exp(0.5 * recon_log_var)).log_prob(x), dim=-1)
    kld = -0.5 * torch.sum(1 + z_log_var - z_mu ** 2 - z_log_var.exp(), dim=1)
    return -(nll + kld)


if __name__ == "__main__":
    import os  # only the script entry point needs filesystem helpers

    # ======================================================== #
    # ------------------- Hyper-parameters ------------------- #
    # Fall back to CPU so the script also runs on machines without CUDA.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    loss_fn = 'mse'  # 'mse' or 'nll'
    kl_weights = [1.0]
    learning_rate = 1e-5
    n_epoch = 1000
    hidden_dim = 10
    # Only appears in plot titles / file names below; no importance sampling is run.
    important_sampling_times = 1
    # ======================================================= #

    # torch.save and plt.savefig do not create missing directories.
    os.makedirs('./results', exist_ok=True)
    os.makedirs('./figs', exist_ok=True)

    for kl_weight in kl_weights:
        px_type = 'gmm'  # 'gmm' or 'unimodal_long'
        if px_type == 'gmm':
            mu_i = 3.
            gmm = GMM(weights=[0.5, 0.5],
                      mu_list=[np.array([-mu_i, -mu_i]), np.array([mu_i, mu_i])],
                      sigma_list=[np.array([[1.0, 0], [0., 1.0]]),
                                  np.array([[1.0, 0], [0., 1.0]])])
        elif px_type == 'unimodal_long':
            gmm = GMM(weights=[1.0],
                      mu_list=[np.array([-1., -1.])],
                      sigma_list=[np.array([[0.8, 0.7], [0.7, 1.2]])])
        else:
            raise ValueError(f'unknown px_type: {px_type!r}')

        # ============= data prepare and visualization =========================== #
        train_x = gmm.sample(10000)
        np.save(f'./{px_type}_trainset.npy', train_x)

        test_x = gmm.sample(2000)
        np.save(f'./{px_type}_testset.npy', test_x)

        # True probability density of the data, p(x), on a 2-D grid. GMM.pdf accepts an
        # (N, 2) batch, so the whole grid is evaluated in one call; reshaping to the
        # meshgrid shape gives the (len(y), len(x)) layout contourf expects.
        x_axis = np.arange(-6, 6, 0.1)
        y_axis = np.arange(-6, 6, 0.1)
        grid = np.meshgrid(x_axis, y_axis)
        pdf_points = np.stack([grid[0].ravel(), grid[1].ravel()], axis=1)
        pdfs = gmm.pdf(pdf_points, dim=2).reshape(grid[0].shape)

        cset = plt.contourf(x_axis, y_axis, pdfs, 20)
        plt.colorbar(cset)
        plt.title(f'{px_type}_True Probability Density of data: $p(x)$')
        plt.show()
        plt.clf()
        # ======================================================================= #

        # ================== Model Training and Testing ===================================== #
        vae = VanillaVAE(in_channels=2, latent_dim=1, hidden_dims=hidden_dim, fix_encode_logvar=False).to(device=device)
        optimizer = optim.Adam(vae.parameters(), lr=learning_rate)

        train_set = gmm_dataset(f'./{px_type}_trainset.npy', gmm)
        train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
        test_set = gmm_dataset(f'./{px_type}_testset.npy', gmm)
        test_loader = DataLoader(test_set, batch_size=32, shuffle=True)

        train_losses = []
        mse_losses = []
        train_counter = []
        gaps = []
        for epoch in range(n_epoch):
            # train
            vae.train()
            for batch_idx, (x, px) in enumerate(train_loader):
                x = x.to(device=device)
                result = vae.forward(x)
                loss = vae.loss_function(*result, M_N=kl_weight, loss_fn=loss_fn)  # loss_fn: mse or nll
                optimizer.zero_grad()
                loss["loss"].backward()
                optimizer.step()

                if batch_idx % 100 == 0:
                    print(f'Train Epoch: {epoch} [{batch_idx * len(x)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss["loss"].cpu().item():.6f}')
                    # Compare the single-sample ELBO with the true density p(x) of the batch.
                    with torch.no_grad():
                        recon_mu = result[0][0]
                        elbo = single_sample_elbo(result, x, loss_fn)

                        mse_gap = F.mse_loss(recon_mu, x).cpu()
                        print(f'\n===================== kl_weight: {kl_weight} lr_{learning_rate} =======================================\n'
                              f'the mse gap between input and recon_mu: {mse_gap.cpu()} Details are:\nrecon_mu: {recon_mu[:min(len(x),10)]}'
                              f'\ninput: {x[:min(len(x),10)]}\n--------------------------------------------------------------------\n')

                        gap = abs((elbo.exp().cpu() - px)).mean()
                        print(f'the mean gap (ELBO.exp() - Px): {gap} \nDetails elbo.exp(): {elbo.exp()[:min(len(x),10)]}, \npx: {px[:min(len(x),10)]}\n')

                    train_losses.append(loss["loss"].item())
                    train_counter.append(
                        (batch_idx * len(x)) + (epoch * len(train_loader.dataset)))
                    # Store plain floats so the plotted history holds no tensors.
                    mse_losses.append(mse_gap.item())
                    gaps.append(gap.item())
                    torch.save(vae.state_dict(), f'./results/{px_type}_model_kl_{kl_weight}_lr_{learning_rate}_{n_epoch}.pth')
                    torch.save(optimizer.state_dict(), f'./results/{px_type}_optimizer_kl_{kl_weight}_lr_{learning_rate}_{n_epoch}.pth')

            # Training curve of the ELBO / p(x) gap.
            if epoch % 100 == 0:
                plt.plot(train_counter, gaps, label='gap (elbo.exp() - px)')
                plt.legend()
                plt.title(f'{px_type}_kl_weight: {kl_weight} Losses _lr_{learning_rate}')
                plt.savefig(f'./results/{px_type}_Training_loss_kl_{kl_weight}_lr_{learning_rate}.png')
                plt.show()
                plt.clf()

            # 1) inference, reconstruction and generation tests
            if epoch % 50 == 0:

                with torch.no_grad():
                    # 1.1) histogram of z ~ q(z|x) over roughly the first sample_num training points
                    sample_num = 5000
                    hist_zs = []
                    for batch_idx, (x, px) in enumerate(train_loader):
                        x = x.to(device=device)
                        encoded_z = vae.forward(x)[4]
                        hist_zs.append(encoded_z.cpu().numpy())
                        if len(x) * batch_idx >= sample_num:
                            break
                    # np.concatenate (unlike np.stack) also copes with a short final batch.
                    hist_zs = np.concatenate(hist_zs, axis=0)
                    plt.hist(hist_zs, bins=50, color='orange', label=r'$z$ ~ $q(z|x)$')
                    plt.xlabel(r'z', fontsize=15)
                    plt.ylabel('Density', fontsize=15)
                    plt.legend()

                    plt.grid(True)
                    plt.grid(color='gray',
                             linestyle='-',
                             linewidth=1,
                             alpha=0.3)
                    plt.tight_layout()
                    plt.savefig(f'./figs/deep_vae_posterior_z_kl_{kl_weight}_lr_{learning_rate}.pdf')
                    plt.savefig(f'./figs/deep_vae_posterior_z_kl_{kl_weight}_lr_{learning_rate}.png')

                    plt.show()
                    plt.clf()

                    # 1.2) reconstruction on test set
                    for batch_idx, (x, px) in enumerate(test_loader):
                        x = x.to(device=device)
                        recon_mu = vae.forward(x)[0][0]
                        # Label only the first batch so the legend has one entry per series.
                        first_batch = batch_idx == 0
                        plt.scatter(recon_mu[:, 0].cpu(), recon_mu[:, 1].cpu(), s=10, c='blue', alpha=0.7,
                                    label='reconstructed test points' if first_batch else '_nolegend_')
                        plt.scatter(x[:, 0].cpu(), x[:, 1].cpu(), s=10, c='brown', alpha=0.7,
                                    label='true test points' if first_batch else '_nolegend_')
                        if len(x) * batch_idx >= sample_num:
                            break
                    plt.legend()
                    plt.title(f'{px_type}_kl_weight: {kl_weight} Reconstruction _lr_{learning_rate}')
                    plt.savefig(f'./results/{px_type}_Reconstruction_kl_{kl_weight}_lr_{learning_rate}.pdf')
                    plt.savefig(f'./results/{px_type}_Reconstruction_kl_{kl_weight}_lr_{learning_rate}.png')
                    plt.show()
                    # Clear so the generation plot below does not draw on top of this one.
                    plt.clf()

                    # 1.3) generation from the prior
                    sample_batch = 500
                    generated = vae.sample(sample_batch, device)
                    plt.scatter(generated[0][:, 0].cpu(), generated[0][:, 1].cpu(), s=10, c='blue', alpha=0.7, label='generated points')
                    plt.scatter(test_x[:, 0], test_x[:, 1], s=10, c='green', alpha=0.7, label='true data points')
                    plt.legend()
                    plt.title(f'{px_type}_kl_weight: {kl_weight} Generation _lr_{learning_rate}')
                    plt.savefig(f'./results/{px_type}_Generation_kl_{kl_weight}_lr_{learning_rate}.pdf')
                    plt.savefig(f'./results/{px_type}_Generation_kl_{kl_weight}_lr_{learning_rate}.png')

                    plt.show()
                    plt.clf()

            # 2) grid test: exp(ELBO) as an estimate of p(x) on a dense grid
            if epoch % 100 == 0:
                vae.eval()

                x_axis = np.arange(-6, 6, 0.02)
                y_axis = np.arange(-6, 6, 0.02)
                grid = np.meshgrid(x_axis, y_axis)
                # One batched forward pass over all grid points instead of one call per
                # point. The model has no batch-dependent layers, so only the random z
                # draws differ from evaluating points one at a time.
                grid_points = torch.tensor(np.stack([grid[0].ravel(), grid[1].ravel()], axis=1),
                                           dtype=torch.float, device=device)
                with torch.no_grad():
                    elbo = single_sample_elbo(vae.forward(grid_points), grid_points, loss_fn)
                elbo_exps = elbo.exp().cpu().numpy().reshape(grid[0].shape)

                cset = plt.contourf(x_axis, y_axis, elbo_exps, 20)
                plt.colorbar(cset)
                plt.tight_layout()
                # The ./figs copies are saved before the title is set, i.e. untitled.
                plt.savefig(f'./figs/deep_vae_elbo_kl_{kl_weight}_lr_{learning_rate}.pdf')
                plt.savefig(f'./figs/deep_vae_elbo_kl_{kl_weight}_lr_{learning_rate}.png')
                plt.title(
                    f'{px_type}_kl_weight: {kl_weight} _lr_{learning_rate} Estimated px with importance sample {important_sampling_times}')
                plt.savefig(
                    f'./results/{px_type}_Estimated_px_sampling_{important_sampling_times}_kl_{kl_weight}_lr_{learning_rate}.pdf')
                plt.savefig(
                    f'./results/{px_type}_Estimated_px_sampling_{important_sampling_times}_kl_{kl_weight}_lr_{learning_rate}.png')
                plt.show()
                plt.clf()

                fig = plt.figure(figsize=(7, 6))
                ax = plt.axes(projection='3d')
                ax.plot_surface(grid[0], grid[1], elbo_exps, cmap=cm.coolwarm)  # ocean  PuBu  rainbow  coolwarm
                ax.set_xlabel(r'$x_1$', fontsize=9)
                ax.set_ylabel(r'$x_2$', fontsize=9)
                ax.set_zlabel(r'$p(x)$')
                plt.tight_layout()
                plt.xticks(fontsize=9)
                plt.yticks(fontsize=9)
                plt.savefig(f'./figs/deep_vae_px_3d_kl_{kl_weight}_lr_{learning_rate}.pdf')
                plt.savefig(f'./figs/deep_vae_px_3d_kl_{kl_weight}_lr_{learning_rate}.png')
                plt.show()
                # Release the 3-D figure so later pyplot calls do not draw onto its 3-D axes.
                plt.close(fig)

