import numpy as np
import torch
from models.vae import VAE


class T3VAE(VAE):
    """VAE variant that samples its latent code with a Student-t
    reparameterization and uses a matching closed-form regularizer.

    NOTE(review): ``input_dim``, ``latent_dim``, ``device``, ``encode`` and
    ``decode`` are not defined in this class; they are presumably provided by
    the ``VAE`` base class -- confirm against ``models.vae``.
    """

    def __init__(
        self,
        nu=2.1,
        input_shape=1,
        latent_dim=32,
        reconstruction="l2",
        mode="cnn",
        device="cuda" if torch.cuda.is_available() else "cpu",
        normalization="batchnorm",
        activation="relu",
        dropout_rate=0.0,
    ):
        """Build the base VAE and precompute the t-distribution constants.

        Args:
            nu: Degrees of freedom; must be strictly greater than 2.
            input_shape, latent_dim, reconstruction, mode, device,
            normalization, activation, dropout_rate: Forwarded unchanged to
                the ``VAE`` base-class constructor.

        Raises:
            ValueError: If ``nu <= 2``.
        """
        super().__init__(
            input_shape=input_shape,
            latent_dim=latent_dim,
            reconstruction=reconstruction,
            mode=mode,
            device=device,
            normalization=normalization,
            activation=activation,
            dropout_rate=dropout_rate,
        )
        self.name = "T3VAE"
        if nu <= 2:
            raise ValueError("Degrees of freedom nu must be greater than 2.")

        # ``self.nu`` is read throughout this class but ``nu`` is never passed
        # to the base constructor, so it has to be stored here. Coercing to
        # float keeps ``nu_n`` floating-point even when an integer ``nu`` is
        # given, which the Chi2 concentration tensor below requires.
        self.nu = float(nu)
        nu = self.nu

        self.nu_n = nu + self.input_dim
        self.nu_m = nu + self.latent_dim
        self.gamma = -2 / (nu + self.input_dim + self.latent_dim)

        # Both distributions live on the CPU; samples are moved to
        # ``self.device`` after drawing.
        self.MVN_dist = torch.distributions.MultivariateNormal(
            torch.zeros(self.latent_dim), torch.eye(self.latent_dim)
        )
        # The concentration has shape [1], so a batch of samples has shape
        # [B, 1] and broadcasts against [B, latent_dim] tensors.
        self.chi_dist = torch.distributions.chi2.Chi2(torch.tensor([self.nu_n]))

        self.decoder_sigma = 1

        def _log_t_normalizing_const(nu, d):
            """Return lgamma((nu + d) / 2) - lgamma(nu / 2) - d/2 * log(nu * pi)."""
            nom = torch.lgamma(torch.tensor((nu + d) / 2))
            denom = torch.lgamma(torch.tensor(nu / 2)) + d / 2 * (
                np.log(nu) + np.log(np.pi)
            )
            return nom - denom

        # Shared log-term used by both ``tau`` and ``const_2bar1``.
        log_tau_base = (
            -self.input_dim * np.log(self.decoder_sigma)
            + _log_t_normalizing_const(nu, self.input_dim)
            - np.log(nu + self.input_dim - 2)
            + np.log(nu - 2)
        )
        log_tau = 2 / (nu + self.input_dim - 2) * log_tau_base
        self.tau_sq = (nu / self.nu_n) * log_tau.exp()
        self.tau = self.tau_sq.sqrt()

        const_2bar1_term_1 = 1 + latent_dim / (nu + self.input_dim - 2)
        const_2bar1_term_2_log = -self.gamma / (1 + self.gamma) * log_tau_base
        self.const_2bar1 = const_2bar1_term_1 * const_2bar1_term_2_log.exp()

        print("tau : ", self.tau)
        print("gamma : ", self.gamma)

    def _reparameterize(self, mu, logvar):
        """Draw a latent sample with the t-reparameterization trick.

        Let ``nu_n = nu + input_dim``.
        1. Sample ``v ~ chi2(nu_n)`` and ``eps ~ N(0, I)`` independently.
        2. Return ``mu + std * eps / sqrt(v / nu_n)`` where
           ``std = sqrt(nu / nu_n * var)`` and ``var = exp(logvar)``.

        Args:
            mu: Location tensor; its first dimension is used as the batch size.
            logvar: Log-variance tensor, broadcast-compatible with ``mu``.

        Returns:
            The sampled latent tensor on ``self.device``.
        """
        batch_shape = torch.Size([mu.shape[0]])

        # eps : [B, latent_dim]
        eps = self.MVN_dist.sample(sample_shape=batch_shape).to(self.device)
        std = torch.exp(0.5 * logvar)
        std = torch.tensor(self.nu / self.nu_n).sqrt() * std
        # v : [B, 1]
        v = self.chi_dist.sample(sample_shape=batch_shape).to(self.device)
        return mu + std * eps * torch.sqrt(self.nu_n / v)

    def forward(self, x):
        """Encode ``x``, decode the sampled latent, and return all pieces.

        Returns:
            ``(z, x_hat, (mu, logvar))`` where ``z, (mu, logvar)`` is whatever
            ``self.encode(x)`` returns and ``x_hat = self.decode(z)``.
        """
        z, (mu, logvar) = self.encode(x)
        x_hat = self.decode(z)
        return z, x_hat, (mu, logvar)

    def regularization(self, derivatives):
        """Compute the regularization loss from the encoder outputs.

        Per sample the loss is
        ``||mu||^2 + nu / (nu + input_dim - 2) * sum(exp(logvar))
        - nu * const_2bar1 * exp(-gamma / (2 + 2 * gamma) * sum(logvar))``;
        this is averaged over dim 0 and ``nu_n * tau_sq`` is added.

        Args:
            derivatives: The ``(mu, logvar)`` pair; reductions run over dim 1.

        Returns:
            The loss averaged over the batch dimension.
        """
        mu, logvar = derivatives
        mu_norm_sq = torch.linalg.norm(mu, ord=2, dim=1).pow(2)
        trace_var = (
            self.nu
            / (self.nu + self.input_dim - 2)
            * torch.sum(logvar.exp(), dim=1)
        )
        log_det_var = (
            -self.gamma / (2 + 2 * self.gamma) * torch.sum(logvar, dim=1)
        )
        loss = (
            torch.mean(
                mu_norm_sq
                + trace_var
                - self.nu * self.const_2bar1 * log_det_var.exp(),
                dim=0,
            )
            + self.nu_n * self.tau_sq
        )
        return loss

    def generate(self, n_gen=64):
        """Decode ``n_gen`` latent samples drawn from the alternative prior.

        Instead of the t-prior, the alternative prior
        ``p(z) ~ t(z | nu + input_dim, tau^2 * I)`` is used, which gives more
        stable generated images.

        Args:
            n_gen: Number of samples to generate.

        Returns:
            The decoded samples as a detached CPU tensor.
        """
        tau = self.tau.to(self.device)
        gen_shape = torch.Size([n_gen])

        prior_z = self.MVN_dist.sample(sample_shape=gen_shape).to(self.device)
        # ``self.chi_dist`` is already Chi2(nu_n), so it is reused here rather
        # than building an identical distribution on every call.
        v = self.chi_dist.sample(sample_shape=gen_shape).to(self.device)
        prior_t = self.decoder_sigma * prior_z * torch.sqrt(self.nu_n / v)
        prior_t *= tau

        # The result is detached anyway, so skip building the autograd graph.
        with torch.no_grad():
            x_gen = self.decode(prior_t.to(self.device)).detach().cpu()

        return x_gen
