import math

import torch
from torch import distributions as D
from torch import nn
from torch.nn import functional as F

import utils


class VAE_MNIST(nn.Module):
    """Small VAE for single-channel 32x32 images (presumably MNIST resized/padded to 32x32).

    The encoder q(z|x) and decoder p(x|z) are either two strided (transposed)
    convolutions or, with ``fc=True``, two-hidden-layer MLPs. The latent space
    has 16 dimensions and a standard-normal prior.

    Args:
        px_y_family_ll: Likelihood family of p(x|z). One of "Bernoulli",
            "GaussianFixedSigma", "GaussianLearnedSigma" or "MoL" (mixture of
            discretized logistics, computed by the ``utils`` helpers).
        device: Device on which prior samples in ``sample_x`` are placed.
        grayscale: Forwarded to the ``utils`` MoL helpers. Together with "MoL",
            it also selects the ``3 * n_mix``-channel decoder head.
        sigma: Decoder standard deviation used by "GaussianFixedSigma".
        n_mix: Number of mixture components for "MoL".
        fc: Use fully-connected layers instead of convolutions.
    """

    def __init__(self, px_y_family_ll, device, grayscale=False, sigma=0.03, n_mix=10, fc=False):
        super().__init__()

        self.device = device
        self.c = 16  # base number of convolution channels
        self.z_dims = 16  # latent dimensionality
        self.img_channels = 1
        self.img_HW = 32
        self.use_fc = fc

        self.likelihood_family = px_y_family_ll
        self.grayscale = grayscale

        # "MoL" + grayscale uses a head with 3 values per mixture component
        # (presumably logit, mean and log-scale, as the utils helpers expect).
        # NOTE(review): "MoL" with grayscale=False keeps the 1-channel head,
        # which presumably does not match what the utils MoL helpers expect — confirm.
        mol_head = self.likelihood_family == "MoL" and self.grayscale

        # Layers for q(z|x):
        if self.use_fc:
            self.qz_fc1 = nn.Linear(in_features=32 * 32, out_features=512)
            self.qz_fc2 = nn.Linear(in_features=512, out_features=512)
            self.qz_mu = nn.Linear(in_features=512, out_features=self.z_dims)
            self.qz_pre_sp = nn.Linear(in_features=512, out_features=self.z_dims)
        else:
            self.qz_conv1 = nn.Conv2d(
                in_channels=1, out_channels=self.c, kernel_size=4, stride=2, padding=1
            )  # out: c x 16 x 16
            self.qz_conv2 = nn.Conv2d(
                in_channels=self.c,
                out_channels=self.c * 2,
                kernel_size=4,
                stride=2,
                padding=1,
            )  # out: 2c x 8 x 8
            self.qz_mu = nn.Linear(in_features=self.c * 2 * 8 * 8, out_features=self.z_dims)
            self.qz_pre_sp = nn.Linear(
                in_features=self.c * 2 * 8 * 8, out_features=self.z_dims
            )

        # Layers for p(x|z):
        if self.use_fc:
            self.px_l1 = nn.Linear(in_features=self.z_dims, out_features=512)
            self.px_fc1 = nn.Linear(in_features=512, out_features=512)
            if mol_head:
                self.px_fc2 = nn.Linear(in_features=512, out_features=32 * 32 * 3 * n_mix)
            else:
                self.px_fc2 = nn.Linear(in_features=512, out_features=32 * 32)
        else:
            self.px_l1 = nn.Linear(in_features=self.z_dims, out_features=self.c * 2 * 8 * 8)
            self.px_conv1 = nn.ConvTranspose2d(
                in_channels=self.c * 2,
                out_channels=self.c,
                kernel_size=4,
                stride=2,
                padding=1,
            )  # out: c x 16 x 16
            self.px_conv2 = nn.ConvTranspose2d(
                in_channels=self.c,
                out_channels=3 * n_mix if mol_head else 1,
                kernel_size=4,
                stride=2,
                padding=1,
            )  # out: (3 * n_mix or 1) x 32 x 32

        if self.likelihood_family == "MoL":
            # Number of logistic mixture components.
            self.n_components = n_mix

        # Decoder log standard deviation. It is a plain (non-learned) tensor for the
        # fixed-sigma case and is replaced by a learnable scalar for the Sigma-VAE.
        self.log_sigma = torch.tensor(sigma).log()
        if self.likelihood_family == "GaussianLearnedSigma":
            # Sigma VAE: a single shared, learned log-sigma initialised at 0.
            self.log_sigma = nn.Parameter(torch.zeros((), dtype=torch.float32))

    def q_z(self, x):
        """Encode ``x`` into q(z|x) and draw one reparameterised sample.

        Returns:
            Tuple ``(z_sample, z_mu, z_std)``. ``z_std`` is made positive with a softplus.
        """
        if self.use_fc:
            # Flatten each image to a feature vector. flatten, unlike view, also
            # accepts non-contiguous input.
            h = torch.flatten(x, start_dim=1)
            h = F.relu(self.qz_fc1(h))
            h = F.relu(self.qz_fc2(h))
        else:
            h = F.relu(self.qz_conv1(x))
            h = F.relu(self.qz_conv2(h))
            # flatten batch of multi-channel feature maps to a batch of feature vectors
            h = torch.flatten(h, start_dim=1)
        z_mu = self.qz_mu(h)
        z_std = F.softplus(self.qz_pre_sp(h))
        return self.reparameterize(z_mu, z_std), z_mu, z_std

    def p_x(self, z):
        """Decode latents ``z`` of shape (batch, z_dims).

        Returns raw (unbounded) decoder outputs for "MoL". For every other family
        it returns sigmoid pixel means in (0, 1) with shape (batch, 1, 32, 32).
        """
        h = F.relu(self.px_l1(z))
        if self.use_fc:
            h = F.relu(self.px_fc1(h))
            h = self.px_fc2(h)
            if self.likelihood_family == "MoL":
                x = h.view(h.size(0), -1, 32, 32)
            else:
                # unflatten batch of feature vectors to a batch of 1-channel images
                x = torch.sigmoid(h).view(h.size(0), 1, 32, 32)
        else:
            # unflatten batch of feature vectors to a batch of multi-channel feature maps
            h = h.view(h.size(0), self.c * 2, 8, 8)
            h = F.relu(self.px_conv1(h))
            h = self.px_conv2(h)
            x = h if self.likelihood_family == "MoL" else torch.sigmoid(h)
        return x

    def reparameterize(self, mu, std):
        """Return ``mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterisation trick)."""
        # randn_like matches mu's shape, dtype and device. This avoids a float32
        # CPU allocation followed by a host-to-device copy.
        eps = torch.randn_like(mu)
        return mu + eps * std

    def _sample_mol_image(self, params):
        """Sample pixels from MoL decoder outputs and map them from [-1, 1] to [0, 1]."""
        sample = utils.sample_from_discretized_mix_logistic(params, grayscale=self.grayscale)
        return ((sample + 1) / 2).clamp(min=0.0, max=1.0)

    def sample_x(self, num=10):
        """Decode ``num`` latents drawn from the N(0, I) prior.

        For "MoL", pixels are sampled from the mixture. For the other families, the
        decoder means are returned.
        """
        z = torch.randn(num, self.z_dims).to(self.device)
        fz = self.p_x(z)
        if self.likelihood_family == "MoL":
            fz = self._sample_mol_image(fz)
        return fz

    def reconstruction(self, x, use_sample=False):
        """Reconstruct ``x`` via the encoder mean, or a posterior sample if ``use_sample``."""
        with torch.no_grad():
            z_sample, z_mean, _ = self.q_z(x)
            fz = self.p_x(z_sample if use_sample else z_mean)

        if self.likelihood_family == "MoL":
            fz = self._sample_mol_image(fz)
        return fz

    def loglikelihood_x_y(self, x, fz):
        """Compute the per-example log-likelihood log p(x|z) given decoder output ``fz``.

        - For MNIST, we use a Bernoulli p(x|z).
        - For colour images, we can try out:
        1. N(f(z), (c I)^2): a Gaussian with constant variance.
        2. N(f(z), (sigma I)^2): a Gaussian with a shared, learnt variance.
        3. Mixture of logistics:
                The input data is assumed to be originally uint8 (0, ..., 255) and
            rescaled by 1/255, i.e. discrete values in {0, 1/255, ..., 255/255}.
            The original discretized logistic mixture log-prob implementation
            expects this data rescaled to [-1, 1], which is done here.

        See the paper 'Simple and Effective VAE Training with Calibrated Decoders'
            by Oleh Rybkin, Kostas Daniilidis, Sergey Levine
        https://arxiv.org/pdf/2006.13202.pdf

        code : https://github.com/orybkin/sigma-vae-pytorch/blob/master/model.py

        Returns:
            Tensor of shape (batch,) holding the log-likelihood summed over pixels.

        Raises:
            NotImplementedError: for an unknown likelihood family.
        """
        if self.likelihood_family == "GaussianFixedSigma":
            # Constant variance: use the configured self.log_sigma.
            log_sigma = self.log_sigma
        elif self.likelihood_family == "GaussianLearnedSigma":
            # Sigma VAE learns the decoder variance as another parameter. Learning the
            # variance can become unstable in some cases. Softly limiting log_sigma
            # to a minimum of -6 keeps training stable.
            min_log_sigma = -6
            log_sigma = min_log_sigma + F.softplus(self.log_sigma - min_log_sigma)
        elif self.likelihood_family == "MoL":
            x = x * 2 - 1  # Transform from [0, 1] to [-1, 1]
        elif self.likelihood_family == "Bernoulli":
            # reshape (not view) so non-contiguous inputs are accepted too
            x = x.reshape(-1, self.img_channels * self.img_HW**2)
            fz = fz.reshape(-1, self.img_channels * self.img_HW**2)
        else:
            raise NotImplementedError

        if self.likelihood_family == "MoL":
            # mixture of logistic likelihood
            ll = -utils.discretized_mix_logistic_loss(x, fz, grayscale=self.grayscale)
        elif self.likelihood_family in ("GaussianFixedSigma", "GaussianLearnedSigma"):
            # Gaussian log likelihood. log(2*pi) is a plain constant, so it does not
            # need a tensor.
            nll = 0.5 * (
                ((x - fz) ** 2) * torch.exp(-2 * log_sigma)
                + 2 * log_sigma
                + math.log(2 * math.pi)
            )
            ll = -torch.sum(torch.flatten(nll, start_dim=1), dim=-1)
        else:  # Bernoulli
            # 1e-8 guards against log(0) for saturated sigmoid outputs.
            ll = torch.sum(
                torch.flatten(
                    x * torch.log(fz + 1e-8) + (1 - x) * torch.log(1 - fz + 1e-8),
                    start_dim=1,
                ),
                dim=-1,
            )
        return ll

    def forward(self, x, eval_individual=False):
        """Compute the ELBO of ``x``.

        Returns:
            If ``eval_individual``: per-example ``(elbo, ll, kl)``. Otherwise the
            batch means ``(-elbo.mean(), ll.mean(), kl.mean())``, where the first
            element is the loss to minimise.
        """
        z, qz_mu, qz_std = self.q_z(x)
        fz = self.p_x(z)

        # For likelihood : <log p(x|z)>_q
        ll = self.loglikelihood_x_y(x, fz)

        qz = D.independent.Independent(D.normal.Normal(qz_mu, qz_std), 1)
        pz = D.independent.Independent(
            D.normal.Normal(torch.zeros_like(z), torch.ones_like(z)), 1
        )

        # For: KL[q(z|x) || p(z)]
        kl = D.kl.kl_divergence(qz, pz)

        elbo = ll - kl

        if eval_individual:
            return elbo, ll, kl
        return -elbo.mean(), ll.mean(), kl.mean()



class VAE_CIFAR(nn.Module):
    """Convolutional VAE for 3-channel 32x32 images (e.g. CIFAR).

    The encoder uses three convolutions down to a ``4c x 8 x 8`` map. The decoder
    mirrors it with transposed convolutions back to 32x32. The latent prior is
    a standard normal of dimension ``z_dims``.

    Args:
        px_y_family_ll: Likelihood family of p(x|z). One of "Bernoulli",
            "GaussianFixedSigma", "GaussianLearnedSigma" or "MoL" (mixture of
            discretized logistics, computed by the ``utils`` helpers).
        device: Device on which prior samples in ``sample_x`` are placed.
        grayscale: Forwarded to the ``utils`` MoL helpers. It also selects the
            MoL decoder-head size.
        sigma: Decoder standard deviation used by "GaussianFixedSigma".
        n_mix: Number of mixture components for "MoL".
        c: Base number of convolution channels.
        z_dims: Latent dimensionality.
    """

    def __init__(self, px_y_family_ll, device, grayscale=False, sigma=0.03, n_mix=10, c=32, z_dims=20):
        super().__init__()

        self.device = device
        self.c = c
        self.z_dims = z_dims
        self.img_channels = 3
        self.img_HW = 32

        self.likelihood_family = px_y_family_ll
        self.grayscale = grayscale

        # Layers for q(z|x):
        self.qz_conv1 = nn.Conv2d(in_channels=self.img_channels, out_channels=self.c, kernel_size=3, stride=1, padding=1)  # out: c x 32 x 32
        self.qz_conv2 = nn.Conv2d(in_channels=self.c, out_channels=self.c * 2, kernel_size=4, stride=2, padding=1)  # out: 2c x 16 x 16
        self.qz_conv3 = nn.Conv2d(in_channels=self.c * 2, out_channels=self.c * 4, kernel_size=5, stride=2, padding=2)  # out: 4c x 8 x 8
        self.qz_mu = nn.Linear(in_features=self.c * 4 * 8 * 8, out_features=self.z_dims)
        self.qz_pre_sp = nn.Linear(in_features=self.c * 4 * 8 * 8, out_features=self.z_dims)

        # Layers for p(x|z):
        self.px_l1 = nn.Linear(in_features=self.z_dims, out_features=self.c * 4 * 8 * 8)
        self.px_conv1 = nn.ConvTranspose2d(in_channels=self.c * 4, out_channels=self.c * 2, kernel_size=6, stride=2, padding=2)  # 16x16
        self.px_conv2 = nn.ConvTranspose2d(in_channels=self.c * 2, out_channels=self.c, kernel_size=6, stride=2, padding=2)  # 32x32
        if self.likelihood_family == "MoL":
            if self.grayscale:
                # 3 values per mixture component (presumably logit, mean, log-scale).
                px_out_channels = 3 * n_mix
            else:
                # Per component: 1 mixture logit plus (mean, log-scale, coefficient)
                # for each colour channel, i.e. 10 * n_mix for RGB.
                # NOTE(review): assumes utils.discretized_mix_logistic_loss follows the
                # PixelCNN++ layout, which derives the component count as channels / 10.
                px_out_channels = (self.img_channels * 3 + 1) * n_mix
        else:
            px_out_channels = self.img_channels
        self.px_conv3 = nn.ConvTranspose2d(in_channels=self.c, out_channels=px_out_channels, kernel_size=5, stride=1, padding=2)  # 32x32

        if self.likelihood_family == "MoL":
            # Number of logistic mixture components.
            self.n_components = n_mix

        # Decoder log standard deviation. It is a plain (non-learned) tensor for the
        # fixed-sigma case and is replaced by a learnable scalar for the Sigma-VAE.
        self.log_sigma = torch.tensor(sigma).log()
        if self.likelihood_family == "GaussianLearnedSigma":
            # Sigma VAE: a single shared, learned log-sigma initialised at 0.
            self.log_sigma = nn.Parameter(torch.zeros((), dtype=torch.float32))

    def q_z(self, x):
        """Encode ``x`` into q(z|x) and draw one reparameterised sample.

        Returns:
            Tuple ``(z_sample, z_mu, z_std)``. ``z_std`` is made positive with a softplus.
        """
        h = F.relu(self.qz_conv1(x))
        h = F.relu(self.qz_conv2(h))
        h = F.relu(self.qz_conv3(h))
        # flatten batch of multi-channel feature maps to a batch of feature vectors
        h = torch.flatten(h, start_dim=1)
        z_mu = self.qz_mu(h)
        z_std = F.softplus(self.qz_pre_sp(h))
        return self.reparameterize(z_mu, z_std), z_mu, z_std

    def p_x(self, z):
        """Decode latents ``z`` of shape (batch, z_dims).

        Returns raw (unbounded) decoder outputs for "MoL". For every other family
        it returns sigmoid pixel means in (0, 1) with shape (batch, 3, 32, 32).
        """
        h = F.relu(self.px_l1(z))
        # unflatten batch of feature vectors to a batch of multi-channel feature maps
        h = h.view(h.size(0), self.c * 4, 8, 8)
        h = F.relu(self.px_conv1(h))
        h = F.relu(self.px_conv2(h))
        h = self.px_conv3(h)
        return h if self.likelihood_family == "MoL" else torch.sigmoid(h)

    def reparameterize(self, mu, std):
        """Return ``mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterisation trick)."""
        # randn_like matches mu's shape, dtype and device. This avoids a float32
        # CPU allocation followed by a host-to-device copy.
        eps = torch.randn_like(mu)
        return mu + eps * std

    def _sample_mol_image(self, params):
        """Sample pixels from MoL decoder outputs and map them from [-1, 1] to [0, 1]."""
        sample = utils.sample_from_discretized_mix_logistic(params, grayscale=self.grayscale)
        return ((sample + 1) / 2).clamp(min=0.0, max=1.0)

    def sample_x(self, num=10):
        """Decode ``num`` latents drawn from the N(0, I) prior.

        For "MoL", pixels are sampled from the mixture. For the other families, the
        decoder means are returned.
        """
        z = torch.randn(num, self.z_dims).to(self.device)
        fz = self.p_x(z)
        if self.likelihood_family == "MoL":
            fz = self._sample_mol_image(fz)
        return fz

    def reconstruction(self, x, use_sample=False):
        """Reconstruct ``x`` via the encoder mean, or a posterior sample if ``use_sample``."""
        with torch.no_grad():
            z_sample, z_mean, _ = self.q_z(x)
            fz = self.p_x(z_sample if use_sample else z_mean)

        if self.likelihood_family == "MoL":
            fz = self._sample_mol_image(fz)
        return fz

    def loglikelihood_x_y(self, x, fz):
        """Compute the per-example log-likelihood log p(x|z) given decoder output ``fz``.

        - For MNIST, we use a Bernoulli p(x|z).
        - For colour images, we can try out:
        1. N(f(z), (c I)^2): a Gaussian with constant variance.
        2. N(f(z), (sigma I)^2): a Gaussian with a shared, learnt variance.
        3. Mixture of logistics:
                The input data is assumed to be originally uint8 (0, ..., 255) and
            rescaled by 1/255, i.e. discrete values in {0, 1/255, ..., 255/255}.
            The original discretized logistic mixture log-prob implementation
            expects this data rescaled to [-1, 1], which is done here.

        See the paper 'Simple and Effective VAE Training with Calibrated Decoders'
            by Oleh Rybkin, Kostas Daniilidis, Sergey Levine
        https://arxiv.org/pdf/2006.13202.pdf

        code : https://github.com/orybkin/sigma-vae-pytorch/blob/master/model.py

        Returns:
            Tensor of shape (batch,) holding the log-likelihood summed over pixels.

        Raises:
            NotImplementedError: for an unknown likelihood family.
        """
        if self.likelihood_family == "GaussianFixedSigma":
            # Constant variance: use the configured self.log_sigma.
            log_sigma = self.log_sigma
        elif self.likelihood_family == "GaussianLearnedSigma":
            # Sigma VAE learns the decoder variance as another parameter. Learning the
            # variance can become unstable in some cases. Softly limiting log_sigma
            # to a minimum of -6 keeps training stable.
            min_log_sigma = -6
            log_sigma = min_log_sigma + F.softplus(self.log_sigma - min_log_sigma)
        elif self.likelihood_family == "MoL":
            x = x * 2 - 1  # Transform from [0, 1] to [-1, 1]
        elif self.likelihood_family == "Bernoulli":
            # reshape (not view) so non-contiguous inputs are accepted too
            x = x.reshape(-1, self.img_channels * self.img_HW**2)
            fz = fz.reshape(-1, self.img_channels * self.img_HW**2)
        else:
            raise NotImplementedError

        if self.likelihood_family == "MoL":
            # mixture of logistic likelihood
            ll = -utils.discretized_mix_logistic_loss(x, fz, grayscale=self.grayscale)
        elif self.likelihood_family in ("GaussianFixedSigma", "GaussianLearnedSigma"):
            # Gaussian log likelihood. log(2*pi) is a plain constant, so it does not
            # need a tensor.
            nll = 0.5 * (
                ((x - fz) ** 2) * torch.exp(-2 * log_sigma)
                + 2 * log_sigma
                + math.log(2 * math.pi)
            )
            ll = -torch.sum(torch.flatten(nll, start_dim=1), dim=-1)
        else:  # Bernoulli
            # 1e-8 guards against log(0) for saturated sigmoid outputs.
            ll = torch.sum(
                torch.flatten(
                    x * torch.log(fz + 1e-8) + (1 - x) * torch.log(1 - fz + 1e-8),
                    start_dim=1,
                ),
                dim=-1,
            )
        return ll

    def forward(self, x, eval_individual=False):
        """Compute the ELBO of ``x``.

        Returns:
            If ``eval_individual``: per-example ``(elbo, ll, kl)``. Otherwise the
            batch means ``(-elbo.mean(), ll.mean(), kl.mean())``, where the first
            element is the loss to minimise.
        """
        z, qz_mu, qz_std = self.q_z(x)
        fz = self.p_x(z)

        # For likelihood : <log p(x|z)>_q
        ll = self.loglikelihood_x_y(x, fz)

        qz = D.independent.Independent(D.normal.Normal(qz_mu, qz_std), 1)
        pz = D.independent.Independent(
            D.normal.Normal(torch.zeros_like(z), torch.ones_like(z)), 1
        )

        # For: KL[q(z|x) || p(z)]
        kl = D.kl.kl_divergence(qz, pz)

        elbo = ll - kl

        if eval_individual:
            return elbo, ll, kl
        return -elbo.mean(), ll.mean(), kl.mean()
