"""
Reference:
Lopez, Romain, et al. "Deep generative modeling for single-cell transcriptomics." Nature methods 15.12 (2018): 1053-1058.
https://github.com/scverse/scvi-tools/blob/1.0.2/scvi/distributions/_negative_binomial.py
"""


import warnings
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch
import torch.nn.functional as F
from torch.distributions import Distribution, Gamma, constraints
from torch.distributions import Poisson as PoissonTorch
from torch.distributions.utils import (
    broadcast_all,
    lazy_property,
    logits_to_probs,
    probs_to_logits,
)


class AbstractDistribution:
    """Minimal interface shared by the simple distribution wrappers below.

    Concrete subclasses are expected to override both methods; the base
    implementations only signal that the operation is missing.
    """

    def sample(self):
        """Draw a sample. Must be provided by subclasses."""
        raise NotImplementedError

    def mode(self):
        """Return the mode. Must be provided by subclasses."""
        raise NotImplementedError


class DiracDistribution(AbstractDistribution):
    """Point-mass distribution: all probability sits on a single stored value."""

    def __init__(self, value):
        # The one value this distribution can ever produce.
        self.value = value

    def sample(self):
        """Return the stored value (sampling a point mass is deterministic)."""
        stored = self.value
        return stored

    def mode(self):
        """Return the stored value, which is also the mode."""
        stored = self.value
        return stored


class DiagonalGaussianDistribution(object):
    """Gaussian with diagonal covariance built from stacked mean / log-variance.

    Parameters
    ----------
    parameters
        Tensor that is split into two equal chunks along ``dim=1``: the first
        chunk is the mean, the second the log-variance.
    deterministic
        If True, the standard deviation and variance are set to zero so that
        ``sample`` returns the mean and ``kl`` / ``nll`` return a zero tensor.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp so that exp() below can neither overflow nor underflow.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        if self.deterministic:
            # zeros_like already matches the device and dtype of the mean.
            self.var = self.std = torch.zeros_like(self.mean)
        else:
            self.std = torch.exp(0.5 * self.logvar)
            self.var = torch.exp(self.logvar)

    def _zero(self):
        """Return a one-element zero tensor on the device/dtype of the mean."""
        return torch.zeros(1, device=self.mean.device, dtype=self.mean.dtype)

    def sample(self):
        """Draw one reparameterised sample ``mean + std * noise``."""
        # randn_like creates the noise directly on the right device and dtype
        # instead of creating it on CPU and copying it over.
        return self.mean + self.std * torch.randn_like(self.mean)

    def kl(self, other=None, dims=(1, 2, 3)):
        """KL divergence to ``other``, or to a standard normal when it is None.

        Parameters
        ----------
        other
            Another ``DiagonalGaussianDistribution``; None means N(0, I).
        dims
            Dimensions that are summed over. The default ``(1, 2, 3)`` keeps
            the previous behaviour for 4-D parameters.

        Returns
        -------
        Tensor reduced over ``dims``; a one-element zero tensor when the
        distribution is deterministic.
        """
        if self.deterministic:
            return self._zero()
        dims = tuple(dims)
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=dims,
            )
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=dims,
        )

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative log-likelihood of ``sample``, summed over ``dims``.

        Returns a one-element zero tensor when the distribution is
        deterministic.
        """
        if self.deterministic:
            return self._zero()
        logtwopi = float(np.log(2.0 * np.pi))
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=tuple(dims),
        )

    def mode(self):
        """Return the mean, which is the mode of a Gaussian."""
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Element-wise KL divergence KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))).
    Arguments may be Tensors or plain scalars (at least one must be a
    Tensor); everything is combined with the usual broadcasting rules.
    """
    # The first Tensor argument serves as the device/dtype reference for
    # any scalar log-variance.
    reference = next(
        (
            candidate
            for candidate in (mean1, logvar1, mean2, logvar2)
            if isinstance(candidate, torch.Tensor)
        ),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    def _ensure_tensor(value):
        # torch.exp() needs a Tensor; broadcasting alone does not convert scalars.
        if isinstance(value, torch.Tensor):
            return value
        return torch.tensor(value).to(reference)

    logvar1 = _ensure_tensor(logvar1)
    logvar2 = _ensure_tensor(logvar2)

    variance_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_diff = ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    total = -1.0 + logvar2 - logvar1 + variance_ratio + scaled_sq_diff
    return 0.5 * total

def log_zinb_positive(
    x: torch.Tensor, mu: torch.Tensor, phi: torch.Tensor, pi: torch.Tensor, eps=1e-8
):
    """Element-wise log likelihood of a minibatch under a zinb model.

    Parameters
    ----------
    x
        Data
    mu
        mean of the negative binomial (has to be positive support) (shape: minibatch x vars)
    phi
        inverse dispersion parameter (has to be positive support) (shape: minibatch x vars)
    pi
        logit of the dropout parameter (real support) (shape: minibatch x vars)
    eps
        numerical stability constant

    Returns
    -------
    Tensor of log likelihoods with the broadcast shape of the inputs; no
    reduction is applied.

    Notes
    -----
    We parametrize the bernoulli using the logits, hence the softplus functions appearing.

    Entries with ``x < eps`` use the zero-count formula, all others the
    non-zero formula. The branch is selected with ``torch.where`` rather than
    by multiplying with 0/1 masks, so a non-finite value in the branch that
    is *not* selected (e.g. ``lgamma(0) - lgamma(0)`` when ``x == 0`` and
    ``phi == 0``) no longer turns the returned value into NaN.
    """
    # phi is the dispersion rate. If .ndimension() == 1, it is shared for all cells (regardless of batch or labels)
    if phi.ndimension() == 1:
        phi = phi.view(
            1, phi.size(0)
        )  # In this case, we reshape phi for broadcasting

    # Uses log(sigmoid(x)) = -softplus(-x)
    softplus_pi = F.softplus(-pi)
    log_phi_eps = torch.log(phi + eps)
    log_phi_mu_eps = torch.log(phi + mu + eps)
    pi_phi_log = -pi + phi * (log_phi_eps - log_phi_mu_eps)

    case_zero = F.softplus(pi_phi_log) - softplus_pi

    case_non_zero = (
        -softplus_pi
        + pi_phi_log
        + x * (torch.log(mu + eps) - log_phi_mu_eps)
        + torch.lgamma(x + phi)
        - torch.lgamma(phi)
        - torch.lgamma(x + 1)
    )

    # The mask-multiply formulation promoted the result to at least float32;
    # keep that output dtype so callers see no change.
    out_dtype = torch.promote_types(torch.float32, case_non_zero.dtype)
    res = torch.where(
        x < eps, case_zero.to(out_dtype), case_non_zero.to(out_dtype)
    )

    return res


def log_nb_positive(
    x: torch.Tensor,
    mu: torch.Tensor,
    phi: torch.Tensor,
    eps: float = 1e-8,
    log_fn: callable = torch.log,
    lgamma_fn: callable = torch.lgamma,
):
    """Element-wise log likelihood of a minibatch under a nb model.

    Parameters
    ----------
    x
        data
    mu
        mean of the negative binomial (has to be positive support) (shape: minibatch x vars)
    phi
        inverse dispersion parameter (has to be positive support) (shape: minibatch x vars)
    eps
        numerical stability constant added inside the logarithms
    log_fn
        log function
    lgamma_fn
        log gamma function

    Returns
    -------
    The log likelihood with the broadcast shape of the inputs; no reduction
    is applied.
    """
    # log(phi + mu) is shared by the two power terms of the pmf.
    log_denominator = log_fn(phi + mu + eps)
    dispersion_term = phi * (log_fn(phi + eps) - log_denominator)
    mean_term = x * (log_fn(mu + eps) - log_denominator)
    # Log of the coefficient Gamma(x + phi) / (Gamma(phi) * x!).
    lgamma_x_phi = lgamma_fn(x + phi)
    lgamma_phi = lgamma_fn(phi)
    lgamma_x_fact = lgamma_fn(x + 1)
    return dispersion_term + mean_term + lgamma_x_phi - lgamma_phi - lgamma_x_fact

def _convert_mean_disp_to_counts_logits(mu, phi, eps=1e-6):
    r"""NB parameterizations conversion.

    Parameters
    ----------
    mu
        mean of the NB distribution.
    phi
        inverse overdispersion.
    eps
        constant used for numerical log stability. (Default value = 1e-6)

    Returns
    -------
    type
        the number of failures until the experiment is stopped
        and the success logits, computed as ``log(mu + eps) - log(phi + eps)``.

    Raises
    ------
    ValueError
        If ``mu`` or ``phi`` is None.
    """
    # Both parameters are used below, so reject a missing one up front. The
    # previous check only fired when exactly one of them was None and let
    # (None, None) fail later with an unrelated TypeError.
    if mu is None or phi is None:
        raise ValueError(
            "If using the mu/phi NB parameterization, both parameters must be specified"
        )
    logits = (mu + eps).log() - (phi + eps).log()
    total_count = phi
    return total_count, logits


def _convert_counts_logits_to_mean_disp(total_count, logits):
    """Convert the (total_count, logits) NB parameterization to (mu, phi).

    Parameters
    ----------
    total_count
        Number of failures until the experiment is stopped.
    logits
        success logits.

    Returns
    -------
    type
        the mean, computed as ``exp(logits) * total_count``, and the inverse
        overdispersion, which is ``total_count`` itself.

    """
    mean = logits.exp() * total_count
    return mean, total_count


def _gamma(phi, mu):
    """Build the Gamma mixing distribution with shape ``phi`` and rate ``phi / mu``."""
    # torch's Gamma takes the rate (= 1 / scale), not the scale.
    return Gamma(concentration=phi, rate=phi / mu)


class Poisson(PoissonTorch):
    """Poisson distribution.

    Parameters
    ----------
    rate
        rate of the Poisson distribution.
    validate_args
        whether to validate input.
    scale
        Normalized mean expression of the distribution.
        This optional parameter is not used in any computations, but allows to store
        normalization expression levels.

    """

    def __init__(
        self,
        rate: torch.Tensor,
        validate_args: Optional[bool] = None,
        scale: Optional[torch.Tensor] = None,
    ):
        super().__init__(rate=rate, validate_args=validate_args)
        self.scale = scale

    @torch.inference_mode()
    def sample_lbd(
        self,
    ) -> torch.Tensor:
        """Return the Poisson rate; there is no latent rate to sample here."""
        return self.rate

    @torch.inference_mode()
    def sample_from_lbd(
        self,
        lbd: torch.Tensor,
    ) -> torch.Tensor:
        """Sample counts from a Poisson whose rate is ``lbd``.

        Mirrors ``NegativeBinomial.sample_from_lbd``: the returned tensor has
        the shape of ``lbd`` and is drawn from ``Poisson(lbd)``. Previously
        ``lbd`` was only used as a *sample shape* for ``self.rate``, which
        ignored its values and returned shape ``lbd.shape + rate.shape``.
        """
        return PoissonTorch(lbd).sample()


class NegativeBinomial(Distribution):
    r"""Negative binomial distribution.
    In the (`mu`, `phi`) parameterization, samples from the negative binomial are generated as follows:

    1. :math:`w \sim \textrm{Gamma}(\underbrace{\phi}_{\text{shape}}, \underbrace{\phi/\mu}_{\text{rate}})`
    2. :math:`x \sim \textrm{Poisson}(w)`

    Parameters
    ----------
    mu
        Mean of the distribution.
    phi
        Inverse dispersion.
    scale
        Normalized mean expression of the distribution. Optional; it is only
        stored on the instance and not used by the methods of this class.
    validate_args
        Raise ValueError if arguments do not match constraints
    """

    arg_constraints = {
        "mu": constraints.greater_than_eq(0),
        "phi": constraints.greater_than_eq(0),
        "scale": constraints.greater_than_eq(0),
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        mu: Optional[torch.Tensor] = None,
        phi: Optional[torch.Tensor] = None,
        scale: Optional[torch.Tensor] = None,
        validate_args: bool = False,
    ):
        self._eps = 1e-8
        mu, phi = broadcast_all(mu, phi)
        self.mu = mu
        self.phi = phi
        self.scale = scale
        if scale is None:
            # `scale` is optional, but the base class checks (and prints in
            # __repr__) every entry of `arg_constraints`, which fails on
            # None. Shadow the class-level dict on this instance without the
            # unset entry; type(self) keeps subclass constraints intact.
            self.arg_constraints = {
                name: constraint
                for name, constraint in type(self).arg_constraints.items()
                if name != "scale"
            }
        super().__init__(validate_args=validate_args)

    @property
    def mean(self):
        """Mean of the distribution (``mu``)."""
        return self.mu

    @property
    def variance(self):
        """Variance ``mean + mean**2 / phi``."""
        return self.mean + (self.mean**2) / self.phi

    @torch.inference_mode()
    def sample(
        self,
        sample_shape: Optional[Union[torch.Size, Tuple]] = None,
    ) -> torch.Tensor:
        """Sample counts: draw a Gamma rate, then Poisson counts from it."""
        lbd = self.sample_lbd(sample_shape)
        # Call Poisson directly instead of self.sample_from_lbd so that a
        # subclass overriding sample_from_lbd does not change what this
        # method returns.
        counts = PoissonTorch(
            lbd
        ).sample()  # Shape : (n_samples, n_cells_batch, n_vars)
        return counts

    @torch.inference_mode()
    def sample_lbd(
        self,
        sample_shape: Optional[Union[torch.Size, Tuple]] = None,
    ) -> torch.Tensor:
        """Sample the latent Poisson rate from the Gamma mixing distribution.

        The rate is clamped to at most 1e8 before it is returned.
        """
        sample_shape = sample_shape or torch.Size()
        gamma_d = self._gamma()
        lbd = gamma_d.sample(sample_shape)
        lbd = torch.clamp(lbd, max=1e8)
        return lbd

    @torch.inference_mode()
    def sample_from_lbd(
        self,
        lbd: torch.Tensor,
    ) -> torch.Tensor:
        """Sample counts from a Poisson whose rate is ``lbd``."""
        counts = PoissonTorch(
            lbd
        ).sample()  # Shape : (n_samples, n_cells_batch, n_vars)
        return counts

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Element-wise log probability of ``value``.

        When argument validation is enabled, a value outside the support
        only triggers a warning; the log probability is still computed.
        """
        if self._validate_args:
            try:
                self._validate_sample(value)
            except ValueError:
                warnings.warn(
                    "The value argument must be within the support of the distribution",
                )

        return log_nb_positive(value, mu=self.mu, phi=self.phi, eps=self._eps)

    def _gamma(self):
        """Return the Gamma mixing distribution for the current parameters."""
        return _gamma(self.phi, self.mu)


class ZeroInflatedNegativeBinomial(NegativeBinomial):
    r"""Zero-inflated negative binomial distribution.
    In the (`mu`, `phi`) parameterization, samples from the negative binomial are generated as follows:

    1. :math:`w \sim \textrm{Gamma}(\underbrace{\phi}_{\text{shape}}, \underbrace{\phi/\mu}_{\text{rate}})`
    2. :math:`x \sim \textrm{Poisson}(w)`

    Each sampled count is then replaced by zero with the zero-inflation
    probability ``sigmoid(pi_logits)``.

    Parameters
    ----------
    mu
        Mean of the distribution.
    phi
        Inverse dispersion.
    pi_logits
        Logits scale of zero inflation probability.
    scale
        Normalized mean expression of the distribution.
    validate_args
        Raise ValueError if arguments do not match constraints
    """

    arg_constraints = {
        "mu": constraints.greater_than_eq(0),
        "phi": constraints.greater_than_eq(0),
        "pi_logits": constraints.real,
        "scale": constraints.greater_than_eq(0),
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        mu: Optional[torch.Tensor] = None,
        phi: Optional[torch.Tensor] = None,
        pi_logits: Optional[torch.Tensor] = None,
        scale: Optional[torch.Tensor] = None,
        validate_args: bool = False,
    ):
        super().__init__(
            mu=mu,
            phi=phi,
            scale=scale,
            validate_args=validate_args,
        )
        # Re-broadcast so that mu / phi also take the shape of pi_logits.
        self.pi_logits, self.mu, self.phi = broadcast_all(
            pi_logits, self.mu, self.phi
        )

    @property
    def mean(self):
        """Mean ``(1 - pi) * mu`` with ``pi`` the zero-inflation probability."""
        pi = self.pi_probs
        return (1 - pi) * self.mu

    @property
    def variance(self):
        """Not implemented for the zero-inflated distribution."""
        raise NotImplementedError

    @lazy_property
    def pi_logits(self) -> torch.Tensor:
        """Zero-inflation logits (normally set directly in ``__init__``)."""
        return probs_to_logits(self.pi_probs, is_binary=True)

    @lazy_property
    def pi_probs(self) -> torch.Tensor:
        """Zero-inflation probability, ``sigmoid(pi_logits)``."""
        return logits_to_probs(self.pi_logits, is_binary=True)

    def _zero_inflate(self, counts: torch.Tensor) -> torch.Tensor:
        """Set each entry of ``counts`` to zero with probability ``pi_probs``."""
        is_zero = torch.rand_like(counts) <= self.pi_probs
        return torch.where(is_zero, torch.zeros_like(counts), counts)

    @torch.inference_mode()
    def sample(
        self,
        sample_shape: Optional[Union[torch.Size, Tuple]] = None,
    ) -> torch.Tensor:
        """Sample from the distribution."""
        sample_shape = sample_shape or torch.Size()
        samp = super().sample(sample_shape=sample_shape)
        return self._zero_inflate(samp)

    @torch.inference_mode()
    def sample_from_lbd(
        self,
        lbd: torch.Tensor,
    ) -> torch.Tensor:
        """Sample counts from ``Poisson(lbd)`` and apply zero inflation."""
        samp = super().sample_from_lbd(lbd)
        return self._zero_inflate(samp)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Element-wise log probability of ``value``.

        As in ``NegativeBinomial.log_prob``, the support check only runs when
        argument validation is enabled, and a violation produces a warning
        rather than an error.
        """
        if self._validate_args:
            try:
                self._validate_sample(value)
            except ValueError:
                warnings.warn(
                    "The value argument must be within the support of the distribution",
                )
        return log_zinb_positive(
            value, self.mu, self.phi, self.pi_logits, eps=self._eps
        )