"""Implementations of Normal distributions."""

import numpy as np
import torch
import typing
from torch import nn

from enflows.distributions.base import Distribution
from enflows.utils import torchutils


class StandardNormal(Distribution):
    """Multivariate Normal with zero mean and identity covariance.

    The distribution is unconditional: whenever a context is supplied, only
    its leading (batch) size and, for sampling, its device are used.
    """

    def __init__(self, shape):
        """Constructor.

        Args:
            shape: list, tuple or torch.Size, the shape of the input variables.
        """
        super().__init__()
        self._shape = torch.Size(shape)

        # Log of the Gaussian normalising constant, 0.5 * D * log(2 * pi),
        # where D is the number of elements of one event. Stored as a
        # non-persistent buffer so it is not written to the state dict.
        log_normalizer = 0.5 * np.prod(shape) * np.log(2 * np.pi)
        self.register_buffer(
            "_log_z",
            torch.tensor(log_normalizer, dtype=torch.float64),
            persistent=False,
        )

    def _log_prob(self, inputs, context):
        """Return the log density of each row of `inputs`.

        Args:
            inputs: tensor whose shape after the first dimension equals the
                shape given to the constructor.
            context: ignored.

        Raises:
            ValueError: if the non-batch shape of `inputs` is not the
                configured shape.
        """
        event_shape = inputs.shape[1:]
        if event_shape != self._shape:
            raise ValueError(
                "Expected input of shape {}, got {}".format(
                    self._shape, event_shape
                )
            )
        squared_norm = torchutils.sum_except_batch(inputs ** 2, num_batch_dims=1)
        return -0.5 * squared_norm - self._log_z

    def _sample(self, num_samples, context):
        """Draw standard Normal samples.

        Without a context, `num_samples` events are drawn on the device of the
        `_log_z` buffer. With a context, `num_samples` events are drawn per
        context row on the context's device and the leading dimension is split
        into `[context_size, num_samples]` by `torchutils.split_leading_dim`.
        """
        if context is not None:
            # Only the size and device of the context matter, not its values.
            n_contexts = context.shape[0]
            flat_samples = torch.randn(
                n_contexts * num_samples, *self._shape, device=context.device
            )
            return torchutils.split_leading_dim(
                flat_samples, [n_contexts, num_samples]
            )
        return torch.randn(num_samples, *self._shape, device=self._log_z.device)

    def _mean(self, context):
        """Return a zero tensor: one event, or one event per context row."""
        if context is not None:
            # Only the size of the context matters, not its values.
            return context.new_zeros(context.shape[0], *self._shape)
        return self._log_z.new_zeros(self._shape)


class ConditionalDiagonalNormal(Distribution):
    """A diagonal multivariate Normal whose parameters are functions of a context."""

    def __init__(self, shape, context_encoder=None):
        """Constructor.

        Args:
            shape: list, tuple or torch.Size, the shape of the input variables.
            context_encoder: callable or None, encodes the context to the distribution parameters.
                If None, defaults to the identity function.
        """
        super().__init__()
        self._shape = torch.Size(shape)
        if context_encoder is None:
            # nn.Identity instead of a local lambda: it behaves the same when
            # called, but keeps the module picklable and adds no parameters.
            self._context_encoder = nn.Identity()
        else:
            self._context_encoder = context_encoder
        # Log of the Gaussian normalising constant, 0.5 * D * log(2 * pi).
        self.register_buffer("_log_z",
                             torch.tensor(0.5 * np.prod(shape) * np.log(2 * np.pi),
                                          dtype=torch.float64),
                             persistent=False)

    def _compute_params(self, context):
        """Compute the means and log stds from the context.

        The encoder output is split in half along its last dimension: the
        first half holds the means, the second half the log standard
        deviations. Both halves are reshaped to `(batch, *shape)`.

        Args:
            context: tensor passed to the context encoder; must not be None.

        Returns:
            Tuple `(means, log_stds)`.

        Raises:
            ValueError: if `context` is None.
            RuntimeError: if the last dimension of the encoder output is odd,
                or its first dimension differs from that of `context`.
        """
        if context is None:
            raise ValueError("Context can't be None.")

        params = self._context_encoder(context)
        if params.shape[-1] % 2 != 0:
            raise RuntimeError(
                "The context encoder must return a tensor whose last dimension is even."
            )
        if params.shape[0] != context.shape[0]:
            raise RuntimeError(
                "The batch dimension of the parameters is inconsistent with the input."
            )

        split = params.shape[-1] // 2
        means = params[..., :split].reshape(params.shape[0], *self._shape)
        log_stds = params[..., split:].reshape(params.shape[0], *self._shape)
        return means, log_stds

    def _log_prob(self, inputs, context):
        """Return the log density of each row of `inputs` given its context row.

        Raises:
            ValueError: if the non-batch shape of `inputs` is not the
                configured shape, or if `context` is None.
        """
        if inputs.shape[1:] != self._shape:
            raise ValueError(
                "Expected input of shape {}, got {}".format(
                    self._shape, inputs.shape[1:]
                )
            )

        # Compute parameters.
        means, log_stds = self._compute_params(context)
        assert means.shape == inputs.shape and log_stds.shape == inputs.shape

        # Compute log prob: quadratic term, minus the log-determinant of the
        # diagonal scale (sum of log stds), minus the normalising constant.
        norm_inputs = (inputs - means) * torch.exp(-log_stds)
        log_prob = -0.5 * torchutils.sum_except_batch(
            norm_inputs ** 2, num_batch_dims=1
        )
        log_prob -= torchutils.sum_except_batch(log_stds, num_batch_dims=1)
        log_prob -= self._log_z
        return log_prob

    def _sample(self, num_samples, context):
        """Draw `num_samples` samples for every row of `context`.

        Samples are computed as `mean + std * noise`, with the leading
        dimension finally split into `[context_size, num_samples]` by
        `torchutils.split_leading_dim`.
        """
        # Compute parameters.
        means, log_stds = self._compute_params(context)
        stds = torch.exp(log_stds)
        means = torchutils.repeat_rows(means, num_samples)
        stds = torchutils.repeat_rows(stds, num_samples)

        # Generate samples.
        context_size = context.shape[0]
        noise = torch.randn(
            context_size * num_samples, *self._shape, device=means.device
        )
        samples = means + stds * noise
        return torchutils.split_leading_dim(samples, [context_size, num_samples])

    def _mean(self, context):
        """Return the means produced by the encoder for `context`."""
        means, _ = self._compute_params(context)
        return means


class DiagonalNormal(Distribution):
    """A diagonal multivariate Normal with trainable parameters.

    The trainable mean and log standard deviation are stored flattened, as
    parameters of shape `(1, prod(shape))`, and are reshaped to
    `(1, *shape)` whenever they are combined with data. Any context is
    ignored, except that its leading size sets the batch size of the outputs
    of `_sample` and `_mean`.
    """

    def __init__(self, shape):
        """Constructor.

        Args:
            shape: list, tuple or torch.Size, the shape of the input variables.
        """
        super().__init__()
        self._shape = torch.Size(shape)
        self.mean_ = nn.Parameter(torch.zeros(shape).reshape(1, -1))
        self.log_std_ = nn.Parameter(torch.zeros(shape).reshape(1, -1))
        # Log of the Gaussian normalising constant, 0.5 * D * log(2 * pi).
        self.register_buffer("_log_z",
                             torch.tensor(0.5 * np.prod(shape) * np.log(2 * np.pi),
                                          dtype=torch.float64),
                             persistent=False)

    def _broadcastable_params(self):
        """Return (means, log_stds) reshaped to `(1, *shape)`."""
        # Reshaping at use time (rather than changing the stored layout) keeps
        # existing state dicts loadable while letting multi-dimensional event
        # shapes broadcast against `(batch, *shape)` tensors.
        means = self.mean_.reshape(1, *self._shape)
        log_stds = self.log_std_.reshape(1, *self._shape)
        return means, log_stds

    def _log_prob(self, inputs, context):
        """Return the log density of each row of `inputs`.

        Args:
            inputs: tensor whose shape after the first dimension equals the
                shape given to the constructor.
            context: ignored.

        Raises:
            ValueError: if the non-batch shape of `inputs` is not the
                configured shape.
        """
        if inputs.shape[1:] != self._shape:
            raise ValueError(
                "Expected input of shape {}, got {}".format(
                    self._shape, inputs.shape[1:]
                )
            )

        # Compute parameters.
        means, log_stds = self._broadcastable_params()

        # Compute log prob.
        norm_inputs = (inputs - means) * torch.exp(-log_stds)
        log_prob = -0.5 * torchutils.sum_except_batch(
            norm_inputs ** 2, num_batch_dims=1
        )
        log_prob -= torchutils.sum_except_batch(log_stds, num_batch_dims=1)
        log_prob -= self._log_z
        return log_prob

    def _sample(self, num_samples, context):
        """Draw samples as `mean + std * noise`.

        Without a context the result has shape `(num_samples, *shape)`. With
        a context, `num_samples` events are drawn per context row and the
        leading dimension is split into `[context_size, num_samples]` by
        `torchutils.split_leading_dim`. The values of the context are ignored.
        """
        means, log_stds = self._broadcastable_params()
        stds = torch.exp(log_stds)

        if context is None:
            tot_samples = num_samples
        else:
            context_size = context.shape[0]
            tot_samples = context_size * num_samples

        # Noise is created on the parameters' device and dtype so it can be
        # combined with them directly.
        noise = torch.randn(
            tot_samples, *self._shape, device=means.device, dtype=means.dtype
        )
        samples = means + stds * noise

        if context is None:
            return samples
        return torchutils.split_leading_dim(samples, [context_size, num_samples])

    def _mean(self, context):
        """Return the trainable mean.

        Returns:
            A tensor of shape `shape` if `context` is None, otherwise the same
            mean expanded to `(context.shape[0], *shape)`.
        """
        # The parameter is `mean_`; the previous `self.mean` did not refer to it.
        mean = self.mean_.reshape(self._shape)
        if context is None:
            return mean
        # The value of the context is ignored, only its size is taken into account.
        return mean.expand(context.shape[0], *self._shape)

import torch.distributions as D

class MOG(Distribution):
    """An equally weighted mixture of diagonal Gaussians with given means and stds.

    Optional `low` / `high` bounds restrict the support: `_log_prob` asserts
    that inputs lie inside the bounds and `_sample` rejects draws outside
    them. The density returned by `_log_prob` is that of the unbounded
    mixture; it is not renormalised for the bounds.
    """

    def __init__(self, means, stds,
                 low: typing.Optional[float] = None,
                 high: typing.Optional[float] = None):
        """Constructor.

        Args:
            means: tensor of shape (n_components, dim) with component means.
            stds: tensor of the same shape as `means` with per-dimension
                standard deviations.
            low: optional lower bound applied to every dimension.
            high: optional upper bound applied to every dimension.
        """
        super().__init__()
        assert means.shape == stds.shape
        self._shape = torch.Size([means.shape[1]])
        self.n_components = means.shape[0]
        self.means = means
        self.stds = stds
        # Always define both bounds (possibly as None) so that later
        # `is not None` checks cannot hit a missing attribute.
        self._low = low
        self._high = high

        equal_components = torch.ones(self.n_components, device=means.device)
        mix = D.Categorical(equal_components)
        comp = D.Independent(D.Normal(self.means, self.stds), 1)
        self.gmm = D.MixtureSameFamily(mix, comp)

    def _within_bounds(self, samples):
        """Return a boolean mask of the rows of `samples` that respect the bounds."""
        mask = torch.ones(samples.shape[0], dtype=torch.bool, device=samples.device)
        if self._low is not None:
            mask = mask & torch.all(samples.ge(self._low), dim=-1)
        if self._high is not None:
            mask = mask & torch.all(samples.le(self._high), dim=-1)
        return mask

    def _log_prob(self, inputs, context):
        """Return the mixture log density of each row of `inputs`.

        Args:
            inputs: tensor of shape (batch, dim).
            context: ignored.

        Raises:
            AssertionError: if any input lies outside the configured bounds.
            ValueError: if the non-batch shape of `inputs` is not `(dim,)`.
        """
        # Note: the context is ignored.
        if self._low is not None:
            assert torch.all(inputs.ge(self._low)), f"Some inputs are smaller than {self._low}"
        if self._high is not None:
            assert torch.all(inputs.le(self._high)), f"Some inputs are greater than {self._high}"

        if inputs.shape[1:] != self._shape:
            raise ValueError(f"Expected input of shape {self._shape}, got {inputs.shape[1:]}")

        return self.gmm.log_prob(inputs)

    def _sample(self, num_samples, context):
        """Draw samples from the mixture, rejecting those outside the bounds.

        Without a context the result has shape `(num_samples, dim)`. With a
        context, `num_samples` events are drawn per context row and the
        leading dimension is split into `[context_size, num_samples]` by
        `torchutils.split_leading_dim`.

        NOTE: rejection sampling keeps drawing until enough samples are
        accepted, so it does not terminate if no draw can satisfy the bounds.
        """
        if context is None:
            tot_samples = num_samples
        else:
            # The value of the context is ignored, only its size is taken into account.
            context_size = context.shape[0]
            tot_samples = context_size * num_samples

        samples = self.gmm.sample(torch.Size([tot_samples]))
        if self._low is not None or self._high is not None:
            samples = samples[self._within_bounds(samples)]
            # Draw at least one candidate per round: `tot_samples // 5` is 0
            # for fewer than five requested samples, which would never fill up.
            refill_size = max(1, tot_samples // 5)
            while len(samples) < tot_samples:
                new_samples = self.gmm.sample(torch.Size([refill_size]))
                new_samples = new_samples[self._within_bounds(new_samples)]
                samples = torch.cat((samples, new_samples), dim=0)

        # The refill loop may overshoot; keep exactly the requested number.
        samples = samples[:tot_samples]
        assert len(samples) == tot_samples

        if context is None:
            return samples
        return torchutils.split_leading_dim(samples, [context_size, num_samples])

    def _mean(self, context):
        """Not implemented for this distribution.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()
