import torch
import numpy as np

from utils.functions import log_t_normalizing_const, log_Pareto_normalizing_const

def select_sampler(type="t1",device='cpu', seed=1):
    """Build the sampler registered under the short key ``type``.

    Recognised keys are ``"t1"`` (``T_Sampler``), ``"p1"``
    (``Pareto_Sampler``) and ``"p2"`` (``Nonlinear_P_Sampler``).

    Args:
        type: Sampler key. The name shadows the builtin but is kept because
            callers may pass it by keyword.
        device: Forwarded as the first constructor argument.
        seed: Forwarded as the second constructor argument.

    Returns:
        A new sampler instance, or ``None`` when ``type`` is not recognised.
    """
    # Pick the class first, then instantiate once at the end.
    if type == "t1":
        sampler_cls = T_Sampler
    elif type == "p1":
        sampler_cls = Pareto_Sampler
    elif type == "p2":
        sampler_cls = Nonlinear_P_Sampler
    else:
        return None
    return sampler_cls(device, seed)


class Sampler():
    """Base class that combines per-component draws/densities into a mixture.

    Subclasses are expected to override ``sampling`` (draw from a single
    component) and ``density`` (evaluate a single component's density).
    """

    def __init__(self, device, SEED):
        self.device = device
        # Only stored here; nothing in this class seeds a generator with it.
        self.SEED = SEED

    def sample_generation(self, K = 1, N = 1000, ratio_list = None, mu_list = None, var_list = None, nu_list = None) :
        """Draw a shuffled sample from a ``K``-component mixture.

        The ``N`` draws are split across components with
        ``np.random.multinomial(N, ratio_list)``; each component is sampled
        through ``self.sampling`` and the pieces are concatenated and
        randomly permuted along the first dimension.

        Args:
            K: Number of components to sample (indices ``0 .. K-1``).
            N: Total number of draws handed to the multinomial split.
            ratio_list: Mixture weights; ``None`` means a single component
                with weight ``1.0``.
            mu_list, var_list, nu_list: Per-component parameters. They default
                to ``None`` but are indexed for every component, so they must
                be supplied by the caller.

        Returns:
            The concatenated component samples in random row order.
        """
        # Avoid a shared mutable default list.
        if ratio_list is None:
            ratio_list = [1.0]
        N_list = np.random.multinomial(N, ratio_list)
        result_list = [self.sampling(N_list[ind], mu_list[ind], var_list[ind], nu_list[ind]) for ind in range(K)]
        result = torch.cat(result_list)
        shuffled_ind = torch.randperm(result.shape[0])
        return result[shuffled_ind]

    def density_contour(self, x, K, sample_nu_list, mu_list, var_list, ratio_list) :
        """Return the mixture density at ``x``.

        Computes the sum over the first ``K`` components of
        ``ratio_list[k] * self.density(x, nu_k, mu_k, var_k)``.
        """
        output = 0
        for ind in range(K) :
            output += ratio_list[ind] * self.density(x, sample_nu_list[ind], mu_list[ind], var_list[ind])
        return output

    def sampling(self, N, mu, cov, nu):
        """Draw ``N`` samples from one component; must be overridden."""
        raise NotImplementedError(
            f"{type(self).__name__} does not implement sampling()"
        )

    def density(self, x=None, nu=None, mu=None, var=None):
        """Evaluate one component's density; must be overridden.

        The signature mirrors how ``density_contour`` calls it, so a subclass
        without an override fails with a clear error rather than an
        argument-count ``TypeError``.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement density()"
        )


class T_Sampler(Sampler):
    """Student-t sampler and density; ``nu == 0`` selects the Gaussian case."""

    def sampling(self,N, mu, cov, nu) :
        """Draw ``N`` samples centred at ``mu``.

        Zero-mean multivariate normal noise with covariance ``cov`` is drawn
        first. When ``nu != 0`` each draw is rescaled by ``sqrt(nu / v)`` with
        ``v`` drawn from a chi-square distribution with ``nu`` degrees of
        freedom; when ``nu == 0`` the noise is left Gaussian.

        Args:
            N: Number of draws.
            mu: Location tensor; also fixes the shape of the zero mean.
            cov: Covariance matrix passed to ``MultivariateNormal``.
            nu: Degrees of freedom, or ``0`` for Gaussian noise.

        Returns:
            ``mu + noise`` moved to ``self.device``.
        """
        sample_shape = torch.Size([int(N)])
        MVN_dist = torch.distributions.MultivariateNormal(torch.zeros_like(mu), cov)
        eps = MVN_dist.sample(sample_shape=sample_shape)

        if nu != 0 :
            chi_dist = torch.distributions.chi2.Chi2(torch.tensor([nu]))
            # The chi-square draw is created on CPU, whereas eps follows
            # mu/cov; move it next to eps so the in-place scaling below does
            # not hit a device mismatch when the inputs live on an accelerator.
            v = chi_dist.sample(sample_shape=sample_shape).to(eps.device)
            eps *= torch.sqrt(nu/v)

        return (mu + eps).to(self.device)

    def density(self, x, nu, mu = torch.zeros(1), var = torch.ones(1,1)) :
        """Evaluate the one-dimensional component density at ``x``.

        Args:
            x: Evaluation point(s).
            nu: Degrees of freedom; ``0`` selects the Gaussian formula.
            mu: Location.
            var: Squared scale; must be a tensor because ``torch.sqrt`` is
                applied to it.

        Returns:
            The density values as a tensor.
        """
        if nu == 0 :
            # Gaussian: exp(-log(2*pi)/2 - (x - mu)^2 / (2*var)) / sqrt(var)
            const_term = - 0.5 * np.log(2 * np.pi)
            exp_term = - 0.5 * (mu - x).pow(2) / var
            return torch.exp(const_term + exp_term) / torch.sqrt(var)
        else :
            # NOTE(review): log_t_normalizing_const comes from utils.functions;
            # presumably the log normalizing constant of a 1-D Student-t
            # with nu degrees of freedom - confirm against that helper.
            const_term = log_t_normalizing_const(nu, 1)
            power_term = -torch.log(1 + (mu - x).pow(2) / (nu * var)) * (nu + 1) / 2
            return torch.exp(const_term + power_term) / torch.sqrt(var)

    
class Pareto_Sampler(Sampler):
    """Sampler and density built on ``torch.distributions.pareto.Pareto``."""

    def sampling(self, N, mu, cov, nu) :
        """Draw ``N`` Pareto samples shifted so that their support starts at 0.

        The Pareto scale is ``cov.reshape(-1) * nu`` and the shape parameter
        is ``nu``; the scale is subtracted from every draw afterwards.

        Args:
            N: Number of draws.
            mu: Converted to a tensor on ``self.device`` but otherwise not
                used by this method.
            cov: Tensor that is flattened and multiplied by ``nu``.
            nu: Pareto ``alpha``; wrapped in a tensor when it is not one.

        Returns:
            The shifted samples on ``self.device``.
        """
        if isinstance(mu, torch.Tensor):
            mu = mu.to(self.device)
        else:
            mu = torch.tensor([mu]).to(self.device)

        if isinstance(nu, torch.Tensor):
            nu = nu.to(self.device)
        else:
            nu = torch.tensor([nu]).to(self.device)

        pareto_scale = cov.reshape(-1).to(self.device) * nu
        pareto = torch.distributions.pareto.Pareto(scale=pareto_scale, alpha=nu, validate_args=None)
        draws = pareto.sample(sample_shape=torch.tensor([N]))
        # Pareto support starts at its scale; subtract it to start at zero.
        draws -= pareto_scale
        return draws.to(self.device)

    def density(self, x, nu, mu=torch.ones(1), var=torch.ones(1)):
        """Evaluate the component density at ``x``.

        Points with ``x < 1e-6`` get density ``0``. Elsewhere the value is
        ``exp(log_Pareto_normalizing_const(1, nu, nu) - log(var)
        - (nu + 1) * log(x))``.

        Args:
            x: Evaluation point(s); converted to a tensor if needed.
            nu: Shape parameter; converted to a tensor if needed.
            mu: Converted to a tensor if needed; not used in the formula.
            var: Value whose log is subtracted from the constant term.

        Returns:
            Tensor of density values.
        """
        if not isinstance(nu, torch.Tensor):
            nu = torch.tensor(nu)
        if not isinstance(mu, torch.Tensor):
            mu = torch.tensor(mu)
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x)

        # NOTE(review): log_Pareto_normalizing_const is imported from
        # utils.functions; its exact definition is not visible here.
        log_const = log_Pareto_normalizing_const(1,nu,nu) - np.log(var)
        log_kernel = -(nu + 1) * torch.log(x)

        # log(x) is not finite for x <= 0; those entries are masked to zero.
        in_support = x >= 1e-6
        return torch.where(in_support, torch.exp(log_const + log_kernel), torch.tensor(0.0))
    
class Nonlinear_P_Sampler(Sampler):
    """Two-column sampler whose second column is the first plus shifted Pareto noise."""

    def sampling(self, N, mu, cov, nu) :
        """Draw up to ``N`` pairs ``(x, y)`` with ``y = x + eps1``.

        ``x`` and ``eps1`` are independent Pareto draws (scale
        ``cov.reshape(-1) * nu``, alpha ``nu``) shifted by the scale so they
        start at 0. Entries where ``x >= 30`` or ``y >= 30`` are discarded, so
        fewer than ``N`` rows may be returned.

        Args:
            N: Number of draws requested before filtering.
            mu: Unused by this sampler.
            cov: Tensor that is flattened and multiplied by ``nu``.
            nu: Pareto ``alpha``.

        Returns:
            Tensor with two columns ``[x, y]`` on ``self.device``; it has zero
            rows when nothing survives the filter.
        """
        scale_Pa = cov.reshape(-1).to(self.device) * nu
        Pa_dist = torch.distributions.pareto.Pareto(scale=scale_Pa, alpha=nu, validate_args=None)
        sample_shape = torch.Size([int(N)])

        eps1 = Pa_dist.sample(sample_shape=sample_shape).to(self.device)
        # A second noise draw is made but never used; it is kept so that the
        # random stream, and therefore seeded output, stays unchanged.
        Pa_dist.sample(sample_shape=sample_shape)
        x = Pa_dist.sample(sample_shape=sample_shape).to(self.device)

        x -= scale_Pa # shifting to 0
        eps1 -= scale_Pa
        y = x + eps1

        # Keep only pairs where both coordinates are below the cut-off.
        mask = (x < 30) & (y < 30)
        x = x[mask].reshape(-1,1)
        y = y[mask].reshape(-1,1)

        # max() raises on an empty tensor (N == 0 or everything filtered out),
        # so only report the maxima when there is something to report.
        if x.numel() > 0:
            print(f"maximum value of each element : {x.max(), y.max()}")
        samples = torch.cat([x,y], dim = 1)
        return samples.to(self.device)