
import torch.nn as nn
import torch.distributions as td
import torch
from utils.constants import Cte
import torch.nn.functional as F

def get_likelihood(name, domain_size):
    lik_cls = None
    if name == Cte.BETA: # Beta
        lik_cls =  BetaLikelihood
    elif name == Cte.BERNOULLI: # Bernoulli
        lik_cls =  BernoulliLikelihood
    elif name == Cte.CATEGORICAL: # Categorical
        lik_cls =  CategoricalLikelihood
    elif name == Cte.CONTINOUS_BERN: # Continuous Bernoulli
        lik_cls =  ContinousBernoulliLikelihood
    elif name == Cte.DELTA: # Delta
        lik_cls =  DeltaLikelihood
    elif name == Cte.GAUSSIAN:
        lik_cls =  NormalLikelihood
    elif name == 'normal_fix':
        lik_cls  = NormalLikelihoodFixStd
    else:
        raise NotImplementedError()

    return lik_cls(domain_size)

class BaseLikelihood(nn.Module):
    def __init__(self, domain_size):
        super().__init__()
        self._domain_size = domain_size

    @property
    def domain_size(self):
        return self._domain_size

    @property
    def params_size(self):
        raise NotImplementedError()


class BetaLikelihood(BaseLikelihood):
    def __init__(self, domain_size):
        super().__init__(domain_size)

    @property
    def params_size(self):
        return self._domain_size*2

    def forward(self, logits, return_mean=False):
        logits = F.softplus(logits)
        latent_dim = logits.size(1) // 2
        c0, c1 = torch.split(logits, split_size_or_sections=latent_dim, dim=1)
        p = td.Beta(c0, c1)
        if return_mean:
            return p.mean, p
        else:
            return p

class BernoulliLikelihood(BaseLikelihood):
    def __init__(self, domain_size):
        super().__init__(domain_size)

    @property
    def params_size(self):
        return self._domain_size

    def forward(self, logits, return_mean=False):
        p = td.Bernoulli(logits=logits)
        if return_mean:
            return p.mean, p
        else:
            return p







class CategoricalLikelihood(BaseLikelihood):
    def __init__(self, domain_size):
        super().__init__(domain_size)

    @property
    def params_size(self):
        return self._domain_size

    def forward(self, logits, return_mean=False):
        p = td.Categorical(logits=logits)
        if return_mean:
            return F.softmax(logits), p
        else:
            return p

class ContinousBernoulliLikelihood(BaseLikelihood):
    def __init__(self, domain_size):
        super().__init__(domain_size)

    @property
    def params_size(self):
        return self._domain_size

    def forward(self, logits, return_mean=False):
        p = td.ContinuousBernoulli(logits=logits)
        if return_mean:
            return F.sigmoid(logits), p
        else:
            return p


class DeltaLikelihood(BaseLikelihood):
    def __init__(self, domain_size):
        super().__init__(domain_size)

        self.lambda_ = 1.0

    @property
    def params_size(self):
        return self._domain_size

    def set_lambda(self, value):
        self.lambda_ = value

    def forward(self, logits, return_mean=False):
       # logits = torch.sigmoid(logits)
        p = Delta(logits, lambda_=self.lambda_)
        if return_mean:
            return p.mean, p
        else:
            return p


class NormalLikelihood(BaseLikelihood):
    def __init__(self, domain_size):
        super().__init__(domain_size)
        self.clip_std = 0.0
        self.fix_std = None


    def set_fix_std(self, value):
        self.fix_std = value

    @property
    def params_size(self):
        return self._domain_size*2

    def forward(self, logits, return_mean=False):
        latent_dim = logits.size(1)//2
        mu, log_var = torch.split(logits, split_size_or_sections=latent_dim, dim=1)
        std = torch.exp(log_var / 2) + 0.0001
        if self.clip_std > 0:
            std = torch.clip(std, max=self.clip_std)
        if isinstance(self.fix_std, float):
            std = torch.ones_like(std, requires_grad=False)*self.fix_std
        # std = 0.001*torch.sigmoid(log_var)

        p = td.Normal(mu, std)
        if return_mean:
            return mu, p
        else:
            return p


class NormalLikelihoodFixStd(BaseLikelihood):
    def __init__(self, domain_size):
        super().__init__(domain_size)


        self.fix_std = 0.01



    @property
    def params_size(self):
        return self._domain_size

    def forward(self, logits, return_mean=False):

        std = torch.ones_like(logits, requires_grad=False)*self.fix_std
        # std = 0.001*torch.sigmoid(log_var)

        p = td.Normal(logits, std)
        if return_mean:
            return logits, p
        else:
            return p



# %%

class Delta(td.Distribution):
    def __init__(self, center=None, lambda_=1.0, validate_args=None):
        if center is None:
            raise ValueError("`center` must be specified.")
        self.center = center
        self.lambda_=lambda_
        self._param = self.center
        batch_shape = self._param.size()
        super(Delta, self).__init__(batch_shape, validate_args=validate_args)

    @property
    def mean(self):
        return self.center

    def sample(self, sample_shape=torch.Size()):
        return self.center

    def rsample(self, sample_shape=torch.Size()):
        raise NotImplementedError()

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)


        return  - (1/self.lambda_)*(value - self.center)**2
