import torch
from models.vae import VAE
import numpy as np

class ParetoVAE(VAE):
    """VAE variant with heavy-tailed latent sampling and a gamma-power regularizer.

    Latent samples are drawn as ``mu + eps * (nu_m / v) * lambda_phi`` with
    Laplace noise ``eps`` and Gamma noise ``v`` (see ``_reparameterize``), and
    the regularization term is ``alpha`` times the quantity computed by
    ``_gamma_pow_div``.

    NOTE(review): this class reads ``self.nu``, ``self.input_dim``,
    ``self.latent_dim``, ``self.device``, ``self.encoder``,
    ``self.enc_to_latent1``, ``self.enc_to_latent2`` and ``self.decode``
    without defining them; presumably they are provided by the ``VAE`` base
    class -- confirm against ``models.vae``.
    """

    def __init__(
        self,
        nu=2.1,
        input_shape=1,
        latent_dim=32,
        reconstruction="l1",
        mode="cnn",
        device="cuda" if torch.cuda.is_available() else "cpu",
        normalization="batchnorm",
        activation="relu",
        dropout_rate=0.0,
    ):
        """Build the model and precompute the constants ``k`` and ``alpha``.

        Args:
            nu: Degrees of freedom; must be strictly greater than 1.
            reconstruction: ``"l1"`` or ``"l2"``; selects which closed-form
                expressions are used for ``k``, ``alpha`` and ``gamma``.
            input_shape, latent_dim, mode, device, normalization, activation,
                dropout_rate: Forwarded unchanged to ``VAE.__init__`` together
                with ``nu`` and ``reconstruction``.

        Raises:
            ValueError: If ``nu <= 1``, if the denominator
                ``1 + (latent_dim + 1) * gamma`` is exactly zero, or if
                ``reconstruction`` is neither ``"l1"`` nor ``"l2"``.
        """
        # Fail fast: reject an invalid nu before the base class does any work.
        if nu <= 1:
            raise ValueError("Degrees of freedom nu must be greater than 1.")

        super().__init__(
            input_shape=input_shape,
            latent_dim=latent_dim,
            nu=nu,
            reconstruction=reconstruction,
            mode=mode,
            device=device,
            normalization=normalization,
            activation=activation,
            dropout_rate=dropout_rate,
        )

        self.name = "ParetoVAE"
        # Defaults below correspond to the "l1" case; the "l2" setup
        # overrides nu_n, nu_m, gamma and gamma_ratio.
        self.nu_n = nu + self.input_dim
        self.nu_m = nu + self.latent_dim
        self.nu_nm = nu + self.latent_dim + self.input_dim
        self.gamma = -1 / self.nu_nm
        self.gamma_ratio = self.gamma / (1 + self.gamma)
        self.decoder_sigma = 1

        # Compute k and alpha
        # Note : For computing alternative forms, we don't need C1.
        denom = 1 + (self.latent_dim + 1) * self.gamma
        if denom == 0:
            # Previously the exception was constructed but never raised, so
            # execution fell through with alpha/log_C2 undefined.
            raise ValueError("The denominator of the k value is zero.")

        if reconstruction == "l1":
            log_C2 = self._setup_l1_constants()
        elif reconstruction == "l2":
            log_C2 = self._setup_l2_constants()
        else:
            raise ValueError("Reconstruction loss should be either l1 or l2.")

        print("C2 : ", log_C2.exp())
        print("nu : ", self.nu)
        print("k : ", self.k)
        print("alpha : ", self.alpha)
        print("gamma : ", self.gamma)

    def _setup_l1_constants(self):
        """Set ``self.k`` and ``self.alpha`` for the "l1" case; return ``log_C2``."""
        C2_logterm1 = np.log(self.nu_nm - 1)
        C2_logterm1 -= np.log(self.nu - 1)
        C2_logterm2 = self._log_pareto_normalizing_const(
            self.input_dim + self.latent_dim, self.nu, self.nu
        )
        C2_logterm2 += -self.input_dim * np.log(self.decoder_sigma)
        log_C2 = self.gamma_ratio * (C2_logterm1 - C2_logterm2)

        # 1. k: scale parameter for the alternative prior
        log_k_base = -self.input_dim * np.log(self.decoder_sigma)
        log_k_base += self._log_pareto_normalizing_const(
            self.input_dim, self.nu - 1, self.nu
        )
        log_k_base *= 1 / (self.nu_n - 1)
        self.k = (self.nu / self.nu_n) * np.exp(log_k_base)

        # 2. alpha: the theoretical coefficient of the gamma-power regularizer.
        self.alpha = -self.gamma * self.nu / log_C2.exp()
        return log_C2

    def _setup_l2_constants(self):
        """Set ``nu_n``, ``nu_m``, ``gamma``, ``gamma_ratio``, ``alpha`` and ``k``
        for the "l2" case; return ``log_C2``."""
        self.nu_n = self.nu + self.input_dim / 2
        self.nu_m = self.nu + self.latent_dim

        self.gamma = -2 / (
            2 * self.nu + 2 * self.latent_dim + self.input_dim
        )
        self.gamma_ratio = self.gamma / (1 + self.gamma)

        # 1) Normalizing-constant terms
        C2_logterm1 = self._log_pareto_normalizing_const(
            self.latent_dim, self.nu_m, self.nu_m
        )
        C2_logterm1 += self._log_t_normalizing_const(
            self.input_dim, 2 * self.nu + 2 * self.latent_dim
        )
        C2_logterm1 += -self.input_dim * np.log(self.decoder_sigma)
        C2_logterm1 += (
            self.input_dim / 2 * np.log(2 + 2 * self.latent_dim / self.nu)
        )
        # 2) Remaining term
        C2_logterm2 = np.log(self.nu_m + self.input_dim / 2 - 1)
        C2_logterm2 -= np.log(self.nu - 1)
        log_C2 = self.gamma_ratio * (C2_logterm1 - C2_logterm2)

        # 1. alpha: the theoretical coefficient of the gamma-power regularizer.
        self.alpha = -self.gamma * self.nu / (2 * log_C2.exp())

        # 2. k: scale parameter for the alternative prior
        const = np.log(2 / (2 + self.input_dim / self.nu))
        base = -self.input_dim * np.log(self.decoder_sigma)
        base += -self.input_dim / 2 * np.log(np.pi)
        base += self._log_pareto_normalizing_const(
            self.input_dim / 2, self.nu - 1, self.nu
        )
        exponent = 1 / (self.nu_n - 1)
        k_log = exponent * base + const
        self.k = k_log.exp()
        return log_C2

    def forward(self, x):
        """Encode ``x``, decode the sampled latent, and return all three parts.

        Returns:
            tuple: ``(z, x_hat, (mu, lambda_phi_log))`` where ``z`` and the
            inner pair come from ``encode`` and ``x_hat = self.decode(z)``.
        """
        z, (mu, lambda_phi_log) = self.encode(x)
        x_hat = self.decode(z)
        return z, x_hat, (mu, lambda_phi_log)

    def encode(self, x):
        """Map ``x`` to a latent sample plus the encoder outputs.

        Returns:
            tuple: ``(z, (mu, lambda_phi_log))``; ``lambda_phi_log`` is
            exponentiated before being used as the scale in sampling.
        """
        x = self.encoder(x)
        mu = self.enc_to_latent1(x)
        lambda_phi_log = self.enc_to_latent2(x)
        # Reparameterization trick
        z = self._reparameterize(mu, lambda_phi_log.exp())
        return z, (mu, lambda_phi_log)

    def _reparameterize(self, mu, lambda_phi):
        """
        Reparameterized sampling: ``mu + eps * (nu_m / v) * lambda_phi``.

        ``eps`` is standard Laplace noise and ``v`` is Gamma noise with
        concentration ``nu_m = self.nu + self.latent_dim`` and rate 1; both
        are drawn with the shape of ``lambda_phi`` and moved to
        ``self.device``.

        Args:
            mu (Tensor): Location term added to the scaled noise.
            lambda_phi (Tensor): Scale tensor; its shape is the sample shape.

        Returns:
            Tensor: Samples from the reparameterized distribution.
        """
        sample_shape = lambda_phi.shape
        nu_m = self.nu + self.latent_dim

        # Gamma(nu_m, 1) distributed noise
        gamma_dist = torch.distributions.Gamma(concentration=nu_m, rate=1.0)
        v = gamma_dist.sample(sample_shape).to(self.device)

        # Standard Laplace noise
        laplace_dist = torch.distributions.Laplace(loc=0.0, scale=1.0)
        eps = laplace_dist.sample(sample_shape).to(self.device)

        # Reparameterized sample
        samples = mu + eps * (nu_m / v) * lambda_phi
        return samples

    def regularization(self, derivates):
        """Return the regularization loss for a batch of encoder outputs.

        Args:
            derivates: Pair ``(mu, lambda_phi_log)`` as returned by ``encode``.

        Returns:
            Tensor: ``alpha * _gamma_pow_div(...)`` summed over dim 0.
        """
        mu, lambda_phi_log = derivates
        lambda_phi = lambda_phi_log.exp()
        # Constant prior scale k, broadcast to the shape of lambda_phi.
        k_vec = self.k * torch.ones_like(lambda_phi)
        divs = self.alpha * self._gamma_pow_div(
            self.nu_n, self.gamma, self.latent_dim, lambda_phi, k_vec, mu
        )
        loss = divs.sum(dim=0)
        return loss

    def _log_pareto_normalizing_const(self, m, nu1, nu2):
        """Return ``lgamma(nu1 + m) - lgamma(nu1) - m * log(nu2)`` as a tensor."""
        nom = torch.lgamma(torch.tensor(nu1 + m))
        denom = torch.lgamma(torch.tensor(nu1)) + m * torch.log(
            torch.tensor(nu2)
        )
        return nom - denom

    def _log_t_normalizing_const(self, d, nu):
        """Return ``lgamma((nu + d) / 2) - lgamma(nu / 2) - d/2 * log(nu * pi)``
        as a tensor."""
        nom = torch.lgamma(torch.tensor((nu + d) / 2))
        denom = torch.lgamma(torch.tensor(nu / 2)) + d / 2 * (
            np.log(nu) + np.log(np.pi)
        )
        return nom - denom

    def _gamma_pow_div(self, nu, gamma, m, theta1, theta2, mu):
        """Compute ``const * (term_1 - term_2)`` row-wise (reductions over dim 1).

        Args:
            nu, gamma, m: Scalars entering the constant and the exponents.
            theta1 (Tensor): Scales from the encoder (positive; logged here).
            theta2 (Tensor): Reference scales (positive; logged here).
            mu (Tensor): Location offsets, normalized by ``theta2``.

        Returns:
            Tensor: One value per index along dim 0.
        """
        gamma_ratio = gamma / (1 + gamma)
        mu_diff_norm = (mu / theta2).norm(dim=1)
        const = (nu + m) * (
            gamma_ratio * self._log_pareto_normalizing_const(m, nu - 1, nu)
        ).exp()
        # exp(log theta1 - log theta2) is theta1 / theta2 computed in log space.
        term_1 = (-gamma_ratio * theta2.log().sum(dim=1)).exp() * (
            1
            + (torch.log(theta1) - torch.log(theta2)).exp().sum(dim=1)
            / (nu - 1)
            + mu_diff_norm / nu
        )
        term_2 = (-gamma_ratio * theta1.log().sum(dim=1)).exp() * (
            (nu + m - 1) / (nu - 1)
        )
        return const * (term_1 - term_2)

    def generate(self, n_gen=64, temperature=1.0):
        """Sample ``n_gen`` latent vectors and decode them.

        Latents are ``eps * self.nu / w * self.k * temperature`` with
        ``eps`` standard Laplace and ``w`` Gamma(``self.nu_n``, 1), each of
        shape ``(n_gen, latent_dim)``.

        Returns:
            Tensor: Decoder output, detached and moved to CPU.
        """
        # Generation: sample random latent vectors and decode

        laplace_dist = torch.distributions.Laplace(loc=0.0, scale=1.0)
        eps = laplace_dist.sample((n_gen, self.latent_dim)).to(self.device)
        gamma = torch.distributions.Gamma(concentration=self.nu_n, rate=1.0)
        w = gamma.sample((n_gen, self.latent_dim)).to(self.device)

        z_gen = eps * self.nu / w
        z_gen *= self.k * temperature
        # The result is detached anyway, so skip building an autograd graph.
        with torch.no_grad():
            x_gen = self.decode(z_gen.to(self.device)).detach().cpu()
        return x_gen
