from typing import Optional

import torch
import torch.distributions as tdist
import torch.nn as nn

from vae.layers.misc import BaseModule


def get_distribution(dist_type, **kwargs):
    """
    Instantiate the distribution layer registered under ``dist_type``.

    :param dist_type: one of 'normal', 'categorical', 'bernoulli',
        'sigmoid' or 'laplace'
    :param kwargs: forwarded unchanged to the selected layer's constructor
    :return: the newly constructed distribution layer
    :raises ValueError: if ``dist_type`` is none of the names above
    """
    if dist_type == 'normal':
        return Normal(**kwargs)
    if dist_type == 'categorical':
        return Categorical(**kwargs)
    if dist_type == 'bernoulli':
        return Bernoulli(**kwargs)
    if dist_type == 'sigmoid':
        return Sigmoid(**kwargs)
    if dist_type == 'laplace':
        return Laplace(**kwargs)

    # Nothing matched: reject the request instead of returning None.
    raise ValueError('Illegal distribution type.')


class Distribution(BaseModule):
    """
    Common base for the distribution layers defined in this module.

    Subclasses implement ``_infer_distribution``; ``forward`` calls it and
    tags the returned dictionary with this layer's distribution type.
    """

    def __init__(self, dist_type, scale_regularizer=None):
        """
        :param dist_type: type of output distribution
        :param scale_regularizer: optional object stored on the layer; this
            base class only keeps it, it is up to subclasses to use it
        """
        super().__init__()
        self.dist_type = dist_type
        self.scale_regularizer = scale_regularizer

    def forward(self, *args, **kwargs):
        """
        Run the subclass inference and label its result.

        :param args: forwarded to ``_infer_distribution``
        :param kwargs: forwarded to ``_infer_distribution``
        :return: the dictionary produced by ``_infer_distribution`` with an
            additional ``'type'`` entry holding ``self.dist_type``
        """
        inferred = self._infer_distribution(*args, **kwargs)
        # The dictionary is tagged in place, not copied.
        inferred['type'] = self.dist_type
        return inferred

    def _infer_distribution(self, *args, **kwargs):
        """
        Hook that subclasses must override.

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError('Define in subclass.')


class LocScaleDistribution(Distribution):
    """
    Assumes a distribution that has a mean and a scale parameter, such as a
    normal or a laplace distribution.

    Subclasses set ``self.dist`` to the torch distribution class that is
    called as ``self.dist(mu, scale)``.
    """

    def __init__(self, *args, **kwargs):
        """
        :param args: forwarded to the parent constructor
        :param kwargs: forwarded to the parent constructor
        """
        super().__init__(*args, **kwargs)
        # Define in subclass
        self.dist = None

    def _infer_distribution(self, inp, temp=1., **kwargs):
        """
        Build the distribution from ``inp`` and draw samples from it.

        :param inp: either a ``[mu, scale]`` / ``(mu, scale)`` pair of
            already inferred parameters, or a logits tensor that is split
            into two halves
        :param temp: factor the scale is multiplied with (1. leaves the
            scale untouched)
        :param kwargs: forwarded to ``_sample`` (``k``, ``use_mode``)
        :return: dict with the keys ``'dist'`` and ``'samples'``
        """
        mu, scale = self._process_input(inp, temp)

        if self.scale_regularizer:
            scale = self.scale_regularizer(scale)

        dist = self.dist(mu, scale)
        samples = self._sample(dist, **kwargs)
        return {'dist': dist,
                'samples': samples}

    def _process_input(self, inp, temp):
        """
        Extract ``mu`` and ``scale`` from ``inp`` and apply the temperature.

        :param inp: ``[mu, scale]`` / ``(mu, scale)`` pair or logits tensor
        :param temp: factor the scale is multiplied with
        :return: tuple ``(mu, scale)``
        """
        if isinstance(inp, (list, tuple)):
            # parameters have already been inferred
            mu, scale = inp
        else:
            # input are logits
            mu, scale = self._separate_logits(inp)
        if temp != 1.:
            # Multiply out of place on purpose. ``scale`` is either the
            # caller's own tensor or a view into the logits (``chunk``
            # returns views), so ``scale *= temp`` would silently modify the
            # caller's data, and autograd rejects in-place edits of views
            # that come from multi-output ops such as ``chunk``.
            scale = scale * temp
        return mu, scale

    @staticmethod
    def _separate_logits(logits):
        """
        Split a logits tensor into its ``mu`` and ``scale`` halves.

        :param logits: tensor with 2 to 5 dimensions
        :return: tuple ``(mu, scale)`` of equally sized chunks
        :raises ValueError: if the number of dimensions is not in 2..5
        """
        n_dims = logits.dim()
        if n_dims in (2, 3):
            # coming from dense layer
            mu, scale = logits.chunk(2, dim=-1)
        elif n_dims in (4, 5):
            # coming from convolutional layer
            mu, scale = logits.chunk(2, dim=-3)
        else:
            raise ValueError('Illegal logit length.')

        return mu, scale

    @staticmethod
    def _sample(dist,
                k: Optional[int] = None,
                use_mode: bool = False) -> torch.Tensor:
        """ Sample from parameterized distribution.

        :param dist: torch distribution object
        :param k: number of importance samples; if given, the result has an
            extra leading dimension of size ``k``
        :param use_mode: return ``dist.mean`` instead of a random sample
            (mean and mode coincide for the normal and laplace
            distributions)
        :return: samples
        """
        if use_mode:
            msg = 'No point to use importance samples if all samples are ' \
                  'identical'
            assert k is None, msg

        if k is None:
            samples = dist.mean if use_mode else dist.rsample()
        else:
            samples = dist.rsample(torch.Size([k]))

        return samples


class Normal(LocScaleDistribution):
    """
    Location-scale layer that builds ``torch.distributions.Normal``.
    """

    def __init__(self, *args, **kwargs):
        """
        :param args: positional arguments for the parent constructor that
            follow ``dist_type`` (e.g. the scale regularizer)
        :param kwargs: keyword arguments for the parent constructor
        """
        # dist_type is passed positionally and first: handing it over as a
        # keyword after *args made every positional argument collide with it
        # ("got multiple values for argument 'dist_type'").
        super().__init__('normal', *args, **kwargs)
        self.dist = tdist.Normal


class Laplace(LocScaleDistribution):
    """
    Location-scale layer that builds ``torch.distributions.Laplace``.
    """

    def __init__(self, *args, **kwargs):
        """
        :param args: positional arguments for the parent constructor that
            follow ``dist_type`` (e.g. the scale regularizer)
        :param kwargs: keyword arguments for the parent constructor
        """
        # dist_type is passed positionally and first: handing it over as a
        # keyword after *args made every positional argument collide with it
        # ("got multiple values for argument 'dist_type'").
        super().__init__('laplace', *args, **kwargs)
        self.dist = tdist.Laplace


class Categorical(Distribution):
    """
    Layer that wraps logits in a ``torch.distributions.Categorical``.
    """

    def __init__(self, *args, **kwargs):
        """
        :param args: positional arguments for the parent constructor that
            follow ``dist_type``
        :param kwargs: keyword arguments for the parent constructor
        """
        # dist_type is passed positionally and first: handing it over as a
        # keyword after *args made every positional argument collide with it
        # ("got multiple values for argument 'dist_type'").
        super().__init__('categorical', *args, **kwargs)

    def _infer_distribution(self, logits):
        """
        :param logits: unnormalized log-probabilities handed to
            ``torch.distributions.Categorical``
        :return: dict with the distribution under ``'dist'`` and ``None``
            under ``'samples'``
        """
        # this layer is typically used as reconstruction layer, where
        # samples are typically not used
        return {'dist': tdist.Categorical(logits=logits),
                'samples': None}


class Bernoulli(Distribution):
    """
    Layer that wraps logits in a ``torch.distributions.Bernoulli``.
    """

    def __init__(self, *args, **kwargs):
        """
        :param args: positional arguments for the parent constructor that
            follow ``dist_type``
        :param kwargs: keyword arguments for the parent constructor
        """
        # dist_type is passed positionally and first: handing it over as a
        # keyword after *args made every positional argument collide with it
        # ("got multiple values for argument 'dist_type'").
        super().__init__('bernoulli', *args, **kwargs)

    def _infer_distribution(self, logits):
        """
        :param logits: log-odds handed to ``torch.distributions.Bernoulli``
        :return: dict with the distribution under ``'dist'`` and ``None``
            under ``'samples'``
        """
        # this layer is typically used as reconstruction layer, where
        # samples are typically not used
        return {'dist': tdist.Bernoulli(logits=logits),
                'samples': None}


class Sigmoid(Distribution):
    """
    Stochastic layer produces sigmoid-activated values. Those can then,
    for example, be used to calculate binary cross-entropy.
    """

    def __init__(self, *args, **kwargs):
        """
        :param args: positional arguments for the parent constructor that
            follow ``dist_type``
        :param kwargs: keyword arguments for the parent constructor
        """
        # dist_type is passed positionally and first: handing it over as a
        # keyword after *args made every positional argument collide with it
        # ("got multiple values for argument 'dist_type'").
        super().__init__('sigmoid', *args, **kwargs)
        self.sigmoid = nn.Sigmoid()

    def _infer_distribution(self, logits):
        """
        :param logits: tensor passed through the sigmoid
        :return: dict with ``None`` under ``'dist'`` (no torch distribution
            object is built) and the sigmoid output under ``'samples'``
        """
        activated = self.sigmoid(logits)
        return {'dist': None,
                'samples': activated}
