import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from .baseline import baseline



def gamma_pow_div(nu, gamma, m, theta1, theta2, mu):
    """Compute the gamma-power divergence regulariser, one value per row.

    The result is ``const * (term_1 - term_2)``:

    * ``const`` depends only on ``nu``, ``gamma`` and ``m``.
    * ``term_1`` depends on ``theta1``, ``theta2`` and ``mu``.
    * ``term_2`` depends on ``theta1``.

    Args:
        nu: Shape parameter. The formula divides by ``nu - 1``, so it must not be 1.
        gamma: Power exponent. ``gamma / (1 + gamma)`` is taken, so it must not be -1.
        m: Latent dimensionality, forwarded to ``log_pareto_normalizing_const``.
        theta1: Scale tensor of the first distribution. Its log is taken, so
            entries should be positive. It is reduced over dim 1.
        theta2: Scale tensor of the second distribution, broadcastable with
            ``theta1``. Its log is taken, so entries should be positive.
        mu: Location tensor, divided elementwise by ``theta2``.

    Returns:
        Tensor with dimension 1 reduced (one value per row).
    """
    gamma_ratio = gamma / (1 + gamma)
    # NOTE(review): `.norm(dim=1)` is the Euclidean (p=2) norm of mu / theta2;
    # confirm this is the norm the divergence derivation calls for.
    mu_diff_norm = (mu / theta2).norm(dim=1)
    const = (nu + m) * (gamma_ratio * log_pareto_normalizing_const(m, nu - 1, nu)).exp()

    # prod(theta2) ** (-gamma_ratio), evaluated in log space.
    term_1 = (-gamma_ratio * theta2.log().sum(dim=1)).exp()
    # Out-of-place multiply: an in-place `*=` would overwrite the tensor that
    # autograd saved for exp()'s backward, so backprop would fail whenever
    # theta2 requires grad.
    term_1 = term_1 * (1 + (theta1 / theta2).sum(dim=1) / (nu - 1) + mu_diff_norm / nu)

    term_2 = (-gamma_ratio * theta1.log().sum(dim=1)).exp() * ((nu + m - 1) / (nu - 1))
    return const * (term_1 - term_2)

def log_pareto_normalizing_const(m, nu1, nu2):
    """Return ``log( Gamma(nu1 + m) / (Gamma(nu1) * (2 * nu2) ** m) )`` as a tensor.

    For integer ``m`` this is the log normaliser of the density proportional to
    ``(1 + ||x||_1 / nu2) ** (-(nu1 + m))`` on R^m.

    Args:
        m: Dimension (may be fractional; it only enters through the formula).
        nu1: Shape argument of the gamma functions.
        nu2: Scale argument inside the ``log(2 * nu2)`` term.

    Returns:
        0-dim torch tensor holding the log constant.
    """
    return torch.lgamma(torch.tensor(nu1 + m)) - (
        torch.lgamma(torch.tensor(nu1)) + m * np.log(2 * nu2)
    )

def log_t_normalizing_const(d, nu):
    """Return the log normalising constant of a d-dimensional Student-t density.

    Evaluates ``lgamma((nu + d) / 2) - lgamma(nu / 2) - (d / 2) * (log(nu) + log(pi))``.
    This is the log of the factor in front of ``(1 + |x|^2 / nu) ** (-(nu + d) / 2)``.

    Args:
        d: Dimensionality.
        nu: Degrees of freedom.

    Returns:
        0-dim torch tensor holding the log constant.
    """
    half_d = d / 2
    return torch.lgamma(torch.tensor((nu + d) / 2)) - (
        torch.lgamma(torch.tensor(nu / 2)) + half_d * (np.log(nu) + np.log(np.pi))
    )

class ParetoVAE(baseline):
    """Variational autoencoder with a Pareto-type latent model.

    Latent samples come from ``Pareto_reparameterize``, which scales Laplace
    noise by a chi-squared mixing variable. The regulariser is ``alpha`` times
    the gamma-power divergence from ``gamma_pow_div``, taken against a prior of
    scale ``k``. ``__init__`` derives ``gamma``, ``k`` and ``alpha`` from
    ``nu``, the dimensions, ``recon_sigma`` and ``reconstruction_loss``.
    """

    def __init__(self, n_dim=1, m_dim=1, nu=3.0, recon_sigma=1.0, reg_weight=1.0,
                 num_hidden=64, device='cpu', reconstruction_loss='l2'):
        """Build the encoder/decoder and precompute the loss constants.

        Args:
            n_dim: Data dimensionality (encoder input / decoder output width).
            m_dim: Latent dimensionality.
            nu: Degrees-of-freedom parameter; must be greater than 1.
            recon_sigma: Scale that normalises the reconstruction loss.
            reg_weight: Weight of the regulariser in ``total_loss``.
            num_hidden: Width of the hidden layer in encoder and decoder.
            device: Device on which ``generate`` creates its input tensors.
            reconstruction_loss: ``"l2"`` (squared error) or ``"l1"``
                (absolute error); also selects how ``k`` and ``alpha`` are derived.

        Raises:
            ValueError: If ``nu <= 1``, if the ``k`` denominator is zero, or if
                ``reconstruction_loss`` is neither ``"l1"`` nor ``"l2"``.
        """
        super().__init__(n_dim=n_dim, m_dim=m_dim)

        self.model_name = f"ParetoVAE_nu_{nu}_mdim_{m_dim}"
        self.reconstruction_loss = reconstruction_loss
        if nu <= 1:
            raise ValueError("Degrees of freedom nu must be greater than 1.")

        self.device = device
        self.name = "GParetoVAE"
        # Store the dimensions before deriving anything from them, instead of
        # relying on baseline.__init__ having set self.n_dim / self.m_dim.
        self.n_dim = n_dim
        self.m_dim = m_dim
        self.nu = nu
        self.nu_n = nu + n_dim
        self.nu_m = nu + m_dim
        self.nu_nm = nu + m_dim + n_dim
        self.gamma = -2 / (2 * nu + 2 * m_dim + n_dim)
        self.gamma_ratio = self.gamma / (1 + self.gamma)
        self.decoder_sigma = 1
        self.k = 1
        self.recon_sigma = recon_sigma
        self.reg_weight = reg_weight
        self.alpha = 1

        # Compute k and alpha.
        # Note: for computing the alternative forms below, C1 is not needed.
        denom = 1 + (m_dim + 1) * self.gamma
        if denom == 0:
            # The original code built this exception without raising it. That
            # silently kept k = alpha = 1 and skipped the loss-type check below.
            raise ValueError("The denominator of the k value is zero.")

        if reconstruction_loss == "l1":
            self.gamma = -1 / self.nu_nm
            self.gamma_ratio = self.gamma / (1 + self.gamma)

            C2_logterm1 = -n_dim * np.log(recon_sigma)
            C2_logterm1 += log_pareto_normalizing_const(n_dim + m_dim, nu, nu)
            C2_logterm2 = np.log(self.nu_nm - 1)
            C2_logterm2 -= np.log(nu - 1)

            log_C2 = self.gamma_ratio * (C2_logterm1 - C2_logterm2)

            # 1. k: scale parameter for the alternative prior.
            log_k_base = -n_dim * np.log(recon_sigma)
            log_k_base += - np.log(self.nu_n - 1) + np.log(nu - 1)
            log_k_base += log_pareto_normalizing_const(n_dim, nu - 1, nu)
            log_k_base *= 1 / (self.nu_n - 1)
            self.k = (nu / self.nu_n) * np.exp(log_k_base)
            # 2. alpha: the theoretical coefficient of the gamma-power regulariser.
            self.alpha = -self.gamma * nu / log_C2.exp()

        elif reconstruction_loss == "l2":
            self.nu_m = self.nu + self.m_dim
            self.nu_n = self.nu + self.n_dim/2
            self.gamma = -2 / (2*self.nu + 2*self.m_dim + self.n_dim)
            self.gamma_ratio = self.gamma / (1 + self.gamma)

            ## Compute C2 ##
            ### 1) Normalizing constant C
            C2_logterm1 = log_pareto_normalizing_const(
                self.m_dim, self.nu_m, self.nu_m
            )
            C2_logterm1 += log_t_normalizing_const(
                self.n_dim, 2*self.nu + 2*self.m_dim)
            C2_logterm1 += -self.n_dim * np.log(self.recon_sigma)
            C2_logterm1 += self.n_dim/2 * np.log(2 + 2*self.m_dim / self.nu)
            ### 2) Remaining term
            C2_logterm2 = np.log(self.nu_m + self.n_dim/2 - 1)
            C2_logterm2 -= np.log(self.nu - 1)
            log_C2 = self.gamma_ratio * (C2_logterm1 - C2_logterm2)

            # 1. alpha: the theoretical coefficient of the gamma-power regulariser.
            self.alpha = -self.gamma * self.nu / (2 * log_C2.exp())

            # 2. k: scale parameter for the alternative prior.
            const = np.log(2 / (2 + self.n_dim/self.nu))
            base = -self.n_dim * np.log(recon_sigma)
            base += -self.n_dim/2 * np.log(np.pi)
            base += log_pareto_normalizing_const(self.n_dim/2, self.nu - 1, self.nu)
            exponent = 1 / (self.nu_n - 1)
            k_log = exponent * base + const
            self.k = k_log.exp()
        else:
            raise ValueError("Reconstruction loss should be either l1 or l2.")

        print("nu : ", self.nu)
        print("k : ", self.k)
        # NOTE(review): this prints alpha / 2, while reg_loss uses self.alpha
        # unhalved; confirm which value the label is meant to report.
        print("alpha : ", self.alpha / 2)

        # Encoder: one hidden layer followed by separate location / log-scale heads.
        self.encoder = nn.Sequential(
            nn.Linear(n_dim, num_hidden),
            nn.LeakyReLU(),
            nn.BatchNorm1d(num_hidden),
        )
        self.latent_mu = nn.Linear(num_hidden, m_dim)
        self.latent_logvar = nn.Linear(num_hidden, m_dim)

        # Decoder: one hidden layer back to data space.
        self.decoder = nn.Sequential(
            nn.Linear(m_dim, num_hidden),
            nn.LeakyReLU(),
            nn.BatchNorm1d(num_hidden),
            nn.Linear(num_hidden, n_dim),
        )

    def Pareto_reparameterize(self, mu, lambda_phi):
        """Draw reparameterised samples ``mu + eps * (nu_m / v) * lambda_phi``.

        ``eps`` is standard Laplace noise. ``v`` is chi-squared noise with
        ``self.nu_m`` degrees of freedom. Both are drawn elementwise with the
        shape of ``lambda_phi``. They are drawn via ``.sample()``, so gradients
        flow only through ``mu`` and ``lambda_phi``.

        Args:
            mu: Location tensor.
            lambda_phi: Scale tensor (expected positive); its shape sets the
                noise shape.

        Returns:
            Tensor of samples on ``mu``'s device.
        """
        sample_shape = lambda_phi.shape
        nu_m = self.nu_m
        # Noise follows mu's device (not self.device), so the model still works
        # after being moved with `.to(...)`.
        target_device = mu.device

        # Chi-squared distributed mixing variable.
        v = torch.distributions.Chi2(df=nu_m).sample(sample_shape).to(target_device)
        # Standard Laplace noise.
        eps = torch.distributions.Laplace(loc=0.0, scale=1.0).sample(sample_shape).to(target_device)

        # NOTE(review): a fresh v is drawn per element and nu_m / v has no square
        # root; confirm this matches the intended posterior family.
        return mu + eps * (nu_m / v) * lambda_phi

    def encode(self, x):
        """Encode ``x`` and draw one latent sample per row.

        Returns:
            Tuple ``(z, lambda_phi_log, latent_mu)``: the latent sample, the
            log-scale head output and the location head output.
        """
        hidden = self.encoder(x)
        lambda_phi_log = self.latent_logvar(hidden)
        latent_mu = self.latent_mu(hidden)
        z = self.Pareto_reparameterize(latent_mu, lambda_phi_log.exp())
        return z, lambda_phi_log, latent_mu

    def decode(self, z):
        """Map latent codes ``z`` back to data space."""
        return self.decoder(z)

    def recon_loss(self, x, recon_x):
        """Reconstruction loss, summed over dim 1 and averaged over dim 0.

        ``"l2"`` uses squared error divided by ``2 * recon_sigma**2``; any other
        setting uses absolute error divided by ``recon_sigma``.
        """
        if self.reconstruction_loss == "l2":
            return F.mse_loss(recon_x, x, reduction='none').sum(dim=1).mean(dim=0) / (2 * self.recon_sigma**2)
        else:
            return F.l1_loss(recon_x, x, reduction='none').sum(dim=1).mean(dim=0) / self.recon_sigma

    def reg_loss(self, lambda_phi_log, latent_mu):
        """Batch mean of ``alpha * gamma_pow_div`` against a prior of scale ``k``."""
        lambda_phi = lambda_phi_log.exp()
        k_vec = self.k * torch.ones_like(lambda_phi)
        divs = self.alpha * gamma_pow_div(self.nu_n, self.gamma, self.m_dim, lambda_phi, k_vec, latent_mu)
        loss = divs.mean(dim=0)
        return loss

    def total_loss(self, x, recon_x, lambda_phi_log, latent_mu):
        """Return ``(recon, reg, recon + reg_weight * reg)``."""
        # Arguments passed in recon_loss's declared (x, recon_x) order.
        recon = self.recon_loss(x, recon_x)
        reg = self.reg_loss(lambda_phi_log, latent_mu)
        return recon, reg, recon + self.reg_weight * reg

    def generate(self, N=1000):
        """Decode ``N`` samples drawn with location 0 and scale ``k``.

        The decoder contains BatchNorm1d, so in training mode it normalises with
        the statistics of these ``N`` samples.
        """
        eps = self.Pareto_reparameterize(
            torch.zeros(N, self.m_dim, device=self.device),
            self.k * torch.ones(N, self.m_dim, device=self.device),
        )
        return self.decode(eps)

    def reconstruct(self, x):
        """Encode then decode ``x`` (autograd enabled)."""
        return self.decode(self.encode(x)[0])

    def recon_data(self, x):
        """Encode then decode ``x`` without tracking gradients."""
        with torch.no_grad():
            enc_z, *_ = self.encode(x)
            recon_x = self.decode(enc_z)
        return recon_x

    def forward(self, x):
        """Run encode/decode and return ``(recon, reg, total)`` from ``total_loss``."""
        enc_z, lambda_phi_log, latent_mu = self.encode(x)
        recon_x = self.decode(enc_z)
        return self.total_loss(x, recon_x, lambda_phi_log, latent_mu)
