import torch
from torch import tensor, distributions as dist
import numpy as np
from ..utils import split_mu_sigma
from torch.distributions.kl import register_kl, kl_divergence
from torch.nn.functional import relu, softplus

from .. import params


# Disable argument/support validation by default for every torch distribution.
# This sets a class-level default on torch's Distribution base class, so it is
# global to the process, not just to distributions created in this module.
dist.Distribution.set_default_validate_args(False)


def generate_distribution(logits, distribution='normal', temperature=None) -> dist.Distribution:
    """
    Build a torch distribution from raw network outputs.

    :param logits: raw parameter tensor. For location-scale families it is split
        into (mu, sigma) by ``split_mu_sigma``; for categorical families it holds
        the class logits; for 'uniform' it is used as both bounds.
    :param distribution: either a bare name (str) or a spec sequence (tuple or
        list) whose first element is the name:

        NORMAL:      ("normal", sigma_nonlin, sigma_param)
        LAPLACE:     ("laplace", sigma_nonlin, sigma_param)
        LOGNORMAL:   ("lognormal", sigma_nonlin, sigma_param)
        LOGLAPLACE:  ("loglaplace", sigma_nonlin, sigma_param)
        SOFTLAPLACE: ("softlaplace", sigma_nonlin, sigma_param)
        CATEGORICAL: ("categorical", num_classes)
        ONE_HOT:     ("one_hot", num_classes)
        UNIFORM:     "uniform"

        With a bare name, location-scale families read ``sigma_nonlin`` and
        ``sigma_param`` from ``params.model_params.distribution_base`` and
        ``params.model_params.distribution_sigma_param``; categorical families
        treat all non-batch elements of ``logits`` as the classes.
    :param temperature: optional temperature forwarded to the family-specific
        builder (not used for 'uniform').
    :return: torch.distributions.Distribution object
    :raises ValueError: for an unknown distribution name or a spec of the wrong length
    """
    is_bare_name = isinstance(distribution, str)
    distribution_name = distribution if is_bare_name else distribution[0]

    # location scale distributions
    if distribution_name in ('normal', 'laplace', 'lognormal', 'loglaplace', 'softlaplace'):
        if not is_bare_name and len(distribution) != 3:
            raise ValueError("Expected 3 parameters for normal or laplace distribution")

        # Imported again at call time even though the module also imports it;
        # presumably so a ``params`` object rebound on the parent package after
        # this module was loaded is picked up. NOTE(review): confirm before removing.
        from .. import params
        model_params = params.model_params
        beta = model_params.gradient_smoothing_beta

        if is_bare_name:
            sigma_nonlin = model_params.distribution_base
            sigma_param = model_params.distribution_sigma_param
        else:
            _, sigma_nonlin, sigma_param = distribution

        mu, sigma = split_mu_sigma(logits)
        return generate_loc_scale(mu, sigma, distribution_name, sigma_nonlin, sigma_param, beta, temperature)

    # categorical distributions
    if distribution_name in ('categorical', 'one_hot'):
        if is_bare_name:
            # every non-batch element is one class
            num_classes = logits.shape[1:].numel()
        elif len(distribution) == 2:
            num_classes = distribution[1]
        else:
            raise ValueError("Expected 2 parameters for categorical or one_hot distribution")
        return generate_categorical(logits, distribution_name, num_classes, temperature)

    if distribution_name == 'uniform':
        # Degenerate Uniform with low == high == logits: samples equal ``logits``.
        return dist.Uniform(low=logits, high=logits)

    raise ValueError(f'Unknown distribution: {distribution}')


    
def generate_loc_scale(mu, sigma, distribution_name, sigma_nonlin, sigma_param, beta, temperature=None):
    """
    Build a location-scale distribution from a location and an unconstrained scale.

    :param mu: location tensor
    :param sigma: raw scale tensor, mapped to a positive value by ``sigma_nonlin``
    :param distribution_name: 'normal', 'laplace', 'lognormal', 'loglaplace' or 'softlaplace'
    :param sigma_nonlin: 'logstd' -> exp(sigma * beta), 'std' -> softplus with ``beta``,
        'none' -> sigma is used unchanged
    :param sigma_param: 'std' if the mapped value is a standard deviation,
        'var' if it is a variance (its square root is taken)
    :param beta: sharpness of the exp / softplus nonlinearity
    :param temperature: if given, log(temperature) is added to ``sigma`` before
        the nonlinearity
    :return: torch.distributions.Distribution object
    :raises ValueError: for an unknown ``sigma_nonlin``, ``sigma_param`` or
        ``distribution_name`` (checked in that order)
    """
    if temperature is not None:
        # Shift in pre-activation space; with 'logstd' and beta == 1 this
        # multiplies the resulting scale by ``temperature``.
        sigma = sigma + torch.ones_like(sigma) * np.log(temperature)

    # Step 1: map the raw value to something positive.
    if sigma_nonlin == 'logstd':
        positive = torch.exp(sigma * beta)
    elif sigma_nonlin == 'std':
        positive = softplus(sigma, beta=beta)
    elif sigma_nonlin == 'none':
        positive = sigma
    else:
        raise ValueError(f'Unknown sigma_nonlin {sigma_nonlin}')

    # Step 2: interpret it as a standard deviation or a variance.
    if sigma_param == 'std':
        scale = positive
    elif sigma_param == 'var':
        scale = torch.sqrt(positive)
    else:
        raise ValueError(f'Unknown sigma_param {sigma_param}')

    # Log-space families: floor the location at -5.2983, i.e. exp(-5.2983) = 0.005.
    if distribution_name in ('lognormal', 'loglaplace'):
        mu = torch.clamp(mu, min=-5.2983)

    if distribution_name == 'normal':
        return dist.Normal(loc=mu, scale=scale)
    if distribution_name == 'laplace':
        return dist.Laplace(loc=mu, scale=scale)
    if distribution_name == 'lognormal':
        return dist.LogNormal(loc=mu, scale=scale)
    if distribution_name == 'loglaplace':
        return LogLaplace(loc=mu, scale=scale)
    if distribution_name == 'softlaplace':
        return SoftLaplace(loc=mu, scale=scale)
    raise ValueError(f'Unknown distribution {distribution_name}')
        
        
def generate_categorical(logits, distribution_name,  num_classes, temperature=None):
    """
    Build a (one-hot) categorical distribution over groups of ``num_classes`` logits.

    ``logits`` is reshaped to ``(batch, -1, num_classes)``: each run of
    ``num_classes`` consecutive elements (in memory order of the trailing dims)
    becomes one categorical variable. For logits with more than two dims the
    flattened variable axis is folded into the event shape, so the result has
    batch shape ``(batch,)``; 2-D logits keep batch shape ``(batch, 1)``.

    NOTE(review): this assumes the class axis is the innermost axis; a
    channels-first tensor such as ``(B, C, H, W)`` would need permuting first.

    :param logits: tensor whose first dim is the batch dim
    :param distribution_name: 'categorical', or 'one_hot' (straight-through one-hot samples)
    :param num_classes: classes per variable; must divide the per-sample element count
    :param temperature: accepted for interface symmetry; not used
    :return: torch.distributions.Independent
    :raises ValueError: for an unknown ``distribution_name``
    """
    families = {
        'categorical': dist.Categorical,
        'one_hot': dist.OneHotCategoricalStraightThrough,
    }
    if distribution_name not in families:
        raise ValueError(f'Unknown categorical distribution {distribution_name}')

    # reshape (not view) so non-contiguous inputs are accepted as well
    grouped = logits.reshape(logits.shape[0], -1, num_classes)
    # After the reshape the base batch shape is always (batch, n_vars); only
    # n_vars may be reinterpreted as an event dim (never the batch dim).
    reinterpreted = 1 if logits.dim() > 2 else 0
    return dist.Independent(families[distribution_name](logits=grouped), reinterpreted)
    


class MixtureOfGaussians(dist.mixture_same_family.MixtureSameFamily):
    """
    Mixture of diagonal Gaussians.

    :param logits: unnormalized mixture weights; the last dim indexes components
    :param loc: component means; the last dim is reinterpreted as the event dim,
        so the dim before it must match the number of components
    :param scale: component standard deviations, broadcastable with ``loc``
    :param validate_args: forwarded to ``MixtureSameFamily``
    """

    def __init__(self, logits, loc, scale, validate_args=None):
        gaussians = dist.Normal(loc=loc, scale=scale)
        weights = dist.Categorical(logits=logits)
        super().__init__(
            mixture_distribution=weights,
            component_distribution=dist.Independent(gaussians, 1),
            validate_args=validate_args,
        )

class MixturesOfLogistics(dist.mixture_same_family.MixtureSameFamily):
    """
    Mixture of diagonal logistic distributions.

    :param logits: unnormalized mixture weights; the last dim indexes components
    :param loc: component locations; the last dim is reinterpreted as the event
        dim, so the dim before it must match the number of components
    :param scale: component scales, broadcastable with ``loc``
    :param validate_args: forwarded to ``MixtureSameFamily``
    """
    def __init__(self, logits, loc, scale, validate_args=None):
        component = self._logistic(loc, scale)
        super(MixturesOfLogistics, self).__init__(
            mixture_distribution=dist.Categorical(logits=logits),
            component_distribution=dist.Independent(component, 1),
            validate_args=validate_args,
        )

    @staticmethod
    def _logistic(loc, scale):
        """
        Logistic(loc, scale) as ``loc + scale * logit(U)`` with ``U ~ Uniform(0, 1)``.

        The uniform bounds are created with ``zeros_like`` / ``ones_like`` so they
        live on ``loc``'s device and share its dtype (``torch.zeros(loc.shape)``
        would always be a CPU tensor of the default dtype).
        """
        base = dist.Uniform(torch.zeros_like(loc), torch.ones_like(loc))
        return dist.TransformedDistribution(
            base,
            [dist.SigmoidTransform().inv, dist.AffineTransform(loc, scale)]
        )

class SoftLaplace(dist.Laplace):
    """
    Laplace(loc, scale) pushed through softplus, giving strictly positive samples.

    Only sampling is overridden. ``mean``, ``log_prob``, ``entropy`` etc. are
    inherited from ``Laplace`` and describe the variable *before* softplus.
    """

    def __init__(self, loc, scale, validate_args=None):
        dist.Laplace.__init__(self, loc=loc, scale=scale, validate_args=validate_args)

    def _transform(self, x: torch.Tensor) -> torch.Tensor:
        """Map an unconstrained value into (0, inf) with softplus."""
        return softplus(x)

    def rsample(self, sample_shape = torch.Size()) -> torch.Tensor:
        """Differentiable sample: softplus of a reparameterized Laplace draw."""
        raw = super().rsample(sample_shape)
        return self._transform(raw)

    def sample(self, sample_shape = torch.Size()) -> torch.Tensor:
        """Sample without gradient tracking (a detached ``rsample``)."""
        drawn = self.rsample(sample_shape)
        return drawn.detach()
    

class LogLaplace(dist.Laplace):
    """
    Log-Laplace distribution: X = exp(Y) with Y ~ Laplace(loc, scale).

    ``sample`` / ``rsample`` return draws of X, and ``mean`` / ``variance`` /
    ``stddev`` are the moments of X (NaN where the moment is infinite).
    Everything else (``log_prob``, ``entropy``, ``cdf``, ``mode``, ...) is
    inherited unchanged from ``Laplace`` and therefore describes Y = log(X),
    not X.
    """
    def __init__(self, loc, scale, validate_args=None):
        # Laplace.__init__ broadcasts loc/scale to a common shape and converts
        # Python numbers to tensors; the attributes it stores must not be
        # overwritten with the raw arguments afterwards.
        super(LogLaplace, self).__init__(loc=loc, scale=scale, validate_args=validate_args)

    @property
    def mean(self):
        # E[X] = exp(loc) / (1 - scale^2); finite only for scale < 1, NaN elsewhere.
        finite = self.scale < 1
        # Use a harmless denominator outside the finite region so the unused
        # branch of torch.where cannot inject inf/NaN into gradients.
        denom = torch.where(finite, 1 - self.scale**2, torch.ones_like(self.scale))
        return torch.where(finite, torch.exp(self.loc) / denom, torch.nan)

    @property
    def variance(self):
        # Var[X] = exp(2 loc) / (1 - 4 scale^2) - exp(2 loc) / (1 - scale^2)^2;
        # finite only for scale < 0.5, NaN elsewhere. Without this override the
        # Laplace variance of Y (2 * scale^2) would be returned.
        finite = self.scale < 0.5
        ones = torch.ones_like(self.scale)
        scale_sq = self.scale**2
        second_moment_denom = torch.where(finite, 1 - 4 * scale_sq, ones)
        first_moment_denom = torch.where(finite, 1 - scale_sq, ones)
        exp_2loc = torch.exp(2 * self.loc)
        return torch.where(finite,
                           exp_2loc / second_moment_denom - exp_2loc / first_moment_denom**2,
                           torch.nan)

    @property
    def stddev(self):
        # Kept consistent with ``variance`` (NaN where scale >= 0.5).
        return torch.sqrt(self.variance)

    def sample(self, sample_shape=torch.Size()):
        # Non-differentiable draw of X.
        return self.rsample(sample_shape).detach()

    def rsample(self, sample_shape=torch.Size()):
        # Reparameterized draw of Y, mapped to X = exp(Y).
        laplace_sample = super(LogLaplace, self).rsample(sample_shape)
        return torch.exp(laplace_sample)


class ConcatenatedDistribution(dist.distribution.Distribution):
    """
    Independent distributions stacked along a new axis 1.

    Components are expected to share one batch shape ``(B, *rest)``; the
    combined batch shape is ``(B, len(distributions), *rest)`` and ``mean``,
    ``variance`` and ``rsample`` all use that layout, which is also the layout
    ``log_prob`` expects (``value[:, i]`` belongs to component ``i``).

    :param distributions: list of component distributions
    :param fuse: how ``log_prob`` / ``entropy`` combine components: 'sum' or 'mean'
    """
    def __init__(self, distributions: list, fuse: str = 'sum'):
        self.distributions = distributions
        self.fuse = fuse
        dbs = distributions[0].batch_shape
        batch_shape = torch.Size([dbs[0], len(distributions), *dbs[1:]])
        super(ConcatenatedDistribution, self).__init__(batch_shape=batch_shape)

    def extend(self, distributions: list):
        """
        Append components and return a new distribution over all of them.

        NOTE(review): this also mutates ``self.distributions`` in place (the
        returned object shares the same list) while ``self.batch_shape`` is
        not updated.
        """
        self.distributions.extend(distributions)
        return ConcatenatedDistribution(self.distributions, self.fuse)

    def _fuse(self, values: torch.Tensor, dim: int) -> torch.Tensor:
        # Reduce over the component axis according to ``self.fuse``.
        if self.fuse == 'sum':
            return torch.sum(values, dim=dim)
        if self.fuse == 'mean':
            return torch.mean(values, dim=dim)
        raise ValueError(f'Unknown fuse {self.fuse}')

    @property
    def mean(self) -> torch.Tensor:
        return torch.stack([d.mean for d in self.distributions], dim=1)

    @property
    def variance(self) -> torch.Tensor:
        # Same component axis as ``mean`` / ``batch_shape`` (previously dim 0).
        return torch.stack([d.variance for d in self.distributions], dim=1)

    def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        samples = [d.rsample(sample_shape) for d in self.distributions]
        # Component axis sits right after the batch axis, past any sample dims,
        # so the result has shape ``sample_shape + batch_shape``.
        return torch.stack(samples, dim=len(sample_shape) + 1)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        log_probs = [d.log_prob(value[:, i]) for i, d in enumerate(self.distributions)]
        return self._fuse(torch.stack(log_probs, dim=1), dim=1)

    def entropy(self) -> torch.Tensor:
        entropies = [d.entropy() for d in self.distributions]
        return self._fuse(torch.stack(entropies, dim=1), dim=1)

class JointDistribution(dist.distribution.Distribution):
    """
    Joint independent distributions as a single distribution object, for smvae.

    Component outputs (``mean``, ``variance``, ``loc``, ``scale``, ``rsample``)
    are concatenated along dim 1, so components with batch shapes
    ``(B, D_1)``, ``(B, D_2)``, ... give tensors of shape ``(B, D_1 + D_2 + ...)``.
    ``log_prob`` accepts values in that concatenated layout (split by each
    component's ``batch_shape[1]``) or, as before, stacked as ``(B, n_components, ...)``.

    NOTE(review): ``batch_shape`` is declared as ``(B, len(distributions), *dbs[1:])``,
    which does not match the concatenated layout; left unchanged because other
    code may read it.

    :param distributions: list of distributions
    :param fuse: 'sum' or 'mean' — how log_prob / entropy reduce over dim 1
    """
    def __init__(self, distributions: list, fuse: str = 'sum'):
        self.distributions = distributions
        self.fuse = fuse
        dbs = distributions[0].batch_shape
        batch_shape = torch.Size([dbs[0], len(distributions), *dbs[1:]])
        super().__init__(batch_shape=batch_shape)

    def extend(self, distributions: list):
        """
        Append components and return a new distribution over all of them.

        NOTE(review): this also mutates ``self.distributions`` in place (the
        returned object shares the same list).
        """
        self.distributions.extend(distributions)
        return JointDistribution(self.distributions, self.fuse)

    def _fuse(self, values: torch.Tensor, dim: int) -> torch.Tensor:
        # Reduce according to ``self.fuse``.
        if self.fuse == 'sum':
            return torch.sum(values, dim=dim)
        if self.fuse == 'mean':
            return torch.mean(values, dim=dim)
        raise ValueError(f'Unknown fuse {self.fuse}')

    @property
    def mean(self) -> torch.Tensor:
        return torch.cat([d.mean for d in self.distributions], dim=1)

    @property
    def variance(self) -> torch.Tensor:
        return torch.cat([d.variance for d in self.distributions], dim=1)

    @property
    def loc(self) -> torch.Tensor:
        return torch.cat([d.loc for d in self.distributions], dim=1)

    @property
    def scale(self) -> torch.Tensor:
        return torch.cat([d.scale for d in self.distributions], dim=1)

    def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        samples = [d.rsample(sample_shape) for d in self.distributions]
        # Concatenate along each component's batch dim 1, shifted past any sample dims.
        return torch.cat(samples, dim=len(sample_shape) + 1)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        if value.dim() == len(self.distributions[0].batch_shape):
            # Concatenated layout (as produced by ``rsample`` / ``mean``):
            # component i owns the next ``batch_shape[1]`` columns.
            widths = [d.batch_shape[1] for d in self.distributions]
            parts = torch.split(value, widths, dim=1)
        else:
            # Stacked layout ``(B, n_components, ...)``.
            parts = [value[:, i] for i in range(len(self.distributions))]
        log_probs = [d.log_prob(part) for d, part in zip(self.distributions, parts)]
        return self._fuse(torch.cat(log_probs, dim=1), dim=1)

    def entropy(self) -> torch.Tensor:
        entropies = torch.cat([d.entropy() for d in self.distributions], dim=1)
        # Reduce over the latent axis (dim 1), like ``log_prob``; reducing
        # over dim 0 would mix samples of the batch.
        return self._fuse(entropies, dim=1)

@register_kl(JointDistribution, JointDistribution)
def _kl_concat_concat(p, q):
    """
    KL(p || q) between two JointDistributions, computed component-wise.

    Each pair ``p.distributions[i]`` / ``q.distributions[i]`` is dispatched to
    ``kl_divergence`` and the per-component results are concatenated along
    dim 1, matching the layout of ``JointDistribution.mean`` / ``rsample``.
    Assuming each component KL has shape ``(batch, latent_i)``, the result is
    ``(batch, sum(latent_i))``.

    :raises ValueError: if ``p`` and ``q`` have a different number of components
    """
    if len(p.distributions) != len(q.distributions):
        raise ValueError(
            f'Cannot compute KL between JointDistributions with '
            f'{len(p.distributions)} and {len(q.distributions)} components'
        )
    kls = [kl_divergence(p_i, q_i) for p_i, q_i in zip(p.distributions, q.distributions)]
    # torch.cat keeps the components' dtype and device (a preallocated
    # torch.zeros buffer would force float32) and needs no ``.mean`` call
    # just to find the device.
    return torch.cat(kls, dim=1)