import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils
import tqdm
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class Decoder(nn.Module):
    """MLP that maps a latent vector to 784 output logits.

    The network is ``z_ndim -> 1024 -> 1024 -> 784`` with ReLU activations
    after the two hidden layers and no activation on the output.
    """

    def __init__(self, z_ndim=50):
        """Build the decoder network.

        Args:
            z_ndim: Size of the latent vector fed to the first layer.
        """
        super().__init__()
        self.z_ndim = z_ndim

        hidden_width = 1024
        output_width = 784
        # Layers are created in input-to-output order; the Sequential indices
        # (0, 2, 4 for the Linear layers) determine the state_dict keys.
        layers = [
            nn.Linear(z_ndim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, output_width),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Return the raw (pre-sigmoid) network outputs for latents ``z``."""
        logits = self.net(z)
        return logits

    def cond_ll(self, x, z):
        """Return ``stable_cross_entropy`` of the decoded logits against ``x``.

        NOTE(review): the name suggests a conditional log-likelihood, but the
        value is returned without negation, i.e. it is the summed
        cross-entropy produced by ``stable_cross_entropy``. Confirm which sign
        callers expect before relying on it.
        """
        logits = self.forward(z)
        return stable_cross_entropy(logits, x)


class BinaryVAE(nn.Module):
    """Variational autoencoder for binary 784-dimensional inputs.

    The encoder is an MLP producing ``2 * z_ndim`` values per example: the
    first half is used as the posterior mean, and the absolute value of the
    second half (plus a small epsilon) as the posterior standard deviation.
    The decoder maps a latent vector to 784 logits that are scored against
    the input with a sigmoid cross-entropy.
    """

    def __init__(self, z_ndim):
        """Build the encoder and decoder.

        Args:
            z_ndim: Dimensionality of the latent space.
        """
        super().__init__()
        self.z_ndim = z_ndim
        self.decoder = Decoder(z_ndim)
        self.encoder = nn.Sequential(
            nn.Linear(784, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, z_ndim * 2),
        )

    def forward(self, x, noise):
        """Encode ``x``, draw a reparameterized latent, and decode it.

        Args:
            x: Input batch with 784 features per row.
            noise: Standard-normal noise with ``z_ndim`` columns per row; the
                latent sample is ``mean + std * noise``.

        Returns:
            Tuple ``(recon_x_logit, mean, std)``.
        """
        mean_std = self.encoder(x)
        mean = mean_std[:, : self.z_ndim]
        # abs() keeps the std non-negative; the epsilon keeps log(std) finite
        # in the KL term.
        std = th.abs(mean_std[:, self.z_ndim :]) + 1e-6
        sampled_z = mean + std * noise
        recon_x_logit = self.decoder(sampled_z)
        return recon_x_logit, mean, std

    def loss(self, x):
        """Return the per-example negative ELBO (KL term + reconstruction)."""
        # `.to(x)` matches both the device and dtype of the input batch.
        noise = th.randn((x.shape[0], self.z_ndim)).to(x)
        recon_x_logit, mean, std = self.forward(x, noise)
        reg = kl_gaussian(mean, std)
        log_loss = stable_cross_entropy(recon_x_logit, x)
        return reg + log_loss

    def sample(self, num_image):
        """Decode ``num_image`` latents drawn from the standard-normal prior.

        The noise is moved to the device/dtype of the decoder parameters, so
        sampling works wherever the model lives (CPU or any GPU) instead of
        requiring CUDA.

        Returns:
            Tensor with ``num_image`` rows of sigmoid outputs in ``[0, 1]``.
        """
        reference = next(self.decoder.parameters())
        noise = th.randn((num_image, self.z_ndim)).to(reference)
        return th.sigmoid(self.decoder(noise))

    # Open question from the original author: do we need IWAE?
    def log_pdf(self, z, x):
        """Return the per-example joint log-density ``log p(z) + log p(x | z)``.

        ``log p(z)`` is the standard-normal log-density. ``log p(x | z)`` is
        the Bernoulli log-likelihood of ``x`` under the decoder logits, which
        is the *negative* of the summed cross-entropy.
        """
        log_norm_gaussian = -0.5 * self.z_ndim * np.log(2 * np.pi)
        z_term = -0.5 * th.sum(th.square(z), dim=1)
        log_prior = z_term + log_norm_gaussian

        recon_x_logits = self.decoder(z)
        # stable_cross_entropy is a positive loss; negate it to obtain a
        # log-likelihood before adding it to the log-prior.
        log_ll = -stable_cross_entropy(recon_x_logits, x)

        return log_prior + log_ll


def stable_cross_entropy(logits, labels):
    """Sigmoid cross-entropy between ``logits`` and ``labels``, summed over dim 1.

    Uses the overflow-safe form ``max(l, 0) - l * y + log(1 + exp(-|l|))`` so
    that large-magnitude logits never pass through ``exp`` with a positive
    argument.

    Args:
        logits: Pre-sigmoid scores, at least 2-D.
        labels: Targets with the same shape as ``logits``.

    Returns:
        Tensor with dim 1 reduced by summation.
    """
    positive_part = th.relu(logits)
    log_correction = F.softplus(-logits.abs())
    per_element = positive_part - logits * labels + log_correction
    return per_element.sum(dim=1)


def kl_gaussian(mean, std):
    """KL divergence from a diagonal Gaussian to the standard normal.

    Computes ``0.5 * (std^2 + mean^2 - 1 - 2 * log(std))`` element-wise and
    sums the result over dim 1.

    Args:
        mean: Means of the diagonal Gaussian, at least 2-D.
        std: Strictly positive standard deviations, same shape as ``mean``.

    Returns:
        Tensor with dim 1 reduced by summation.
    """
    variance = std.square()
    mean_sq = mean.square()
    log_variance = 2.0 * std.log()
    per_dim = 0.5 * (variance + mean_sq - 1.0 - log_variance)
    return per_dim.sum(dim=1)


def _binarized_mnist_loader(train, batch_size, shuffle):
    """Build a DataLoader over MNIST with every pixel binarized to 0.0 / 1.0."""
    binarize = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
            # After normalization a value is positive exactly when the raw
            # pixel exceeds the normalization mean, so this thresholds there.
            lambda x: (x > 0).float(),
        ]
    )
    dataset = datasets.MNIST(
        "../data", train=train, download=True, transform=binarize
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def _average_loss(model, loader, device):
    """Return ``model.loss`` averaged per example over everything in ``loader``."""
    total = 0.0
    with th.no_grad():
        for data, _ in loader:
            data = data.to(device).flatten(start_dim=1)
            # Sum per-example losses so the final division by the dataset
            # size yields a true per-example mean, even with a short last batch.
            total += model.loss(data).sum().item()
    return total / len(loader.dataset)


def train_vae(model_file):
    """Train a ``BinaryVAE`` on binarized MNIST and save its weights.

    Every ``log_interval`` batches the current batch loss is printed and a
    10x10 grid of prior samples is written to ``../logs``. After each epoch
    the per-example average train and test losses are printed.

    Args:
        model_file: Path the trained ``state_dict`` is saved to.
    """
    import os

    z_ndim = 50
    batch_size = 128
    lr = 1e-3
    n_epochs = 40
    log_interval = 100
    log_dir = "../logs"

    # Use the GPU when one is present instead of requiring CUDA outright.
    device = th.device("cuda" if th.cuda.is_available() else "cpu")
    # save_image does not create missing directories.
    os.makedirs(log_dir, exist_ok=True)

    model = BinaryVAE(z_ndim).to(device)
    train_loader = _binarized_mnist_loader(
        train=True, batch_size=batch_size, shuffle=True
    )
    # Evaluation covers the whole test set, so shuffling is unnecessary.
    test_loader = _binarized_mnist_loader(
        train=False, batch_size=batch_size, shuffle=False
    )

    optim = th.optim.Adam(model.parameters(), lr=lr)
    for epoch in tqdm.tqdm(range(1, n_epochs + 1)):
        train_loss = 0.0
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.to(device).flatten(start_dim=1)
            loss = model.loss(data).mean()
            optim.zero_grad()
            loss.backward()
            optim.step()

            batch_loss = loss.item()
            # Weight by batch size so the epoch average is per example.
            train_loss += batch_loss * len(data)
            if batch_idx % log_interval == 0:
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(data),
                        len(train_loader.dataset),
                        100.0 * batch_idx / len(train_loader),
                        batch_loss,
                    )
                )
                with th.no_grad():
                    samples = model.sample(100).view(100, 1, 28, 28)
                    vutils.save_image(
                        samples,
                        f"{log_dir}/{epoch:03d}-{batch_idx:03d}.png",
                        normalize=True,
                        nrow=10,
                    )

        print(
            "====> Epoch: {} Average train loss: {:.4f}".format(
                epoch, train_loss / len(train_loader.dataset)
            )
        )
        test_loss = _average_loss(model, test_loader, device)
        print("====> Test set loss: {:.4f}".format(test_loss))

    th.save(model.state_dict(), model_file)


if __name__ == "__main__":
    # Script entry point: train with the built-in settings and store the
    # resulting weights next to the downloaded dataset.
    train_vae(model_file="../data/vae.pth")
