import torch.distributions as D
import numpy as np
import torch
import typing
from torch import nn

from flowcon.distributions.base import Distribution
from flowcon.utils import torchutils


class MOG(Distribution):
    """A uniformly weighted mixture of diagonal Gaussians, optionally box-truncated.

    Each row of ``means``/``stds`` parameterizes one mixture component; the
    entries of a row are the per-dimension mean / standard deviation of an
    axis-aligned (diagonal) Gaussian. All components get the same weight.

    If ``low`` and/or ``high`` are given, sampling rejects draws that have any
    coordinate outside ``[low, high]``, and ``_log_prob`` asserts that its
    inputs lie inside those bounds. Note that the returned log-density is the
    one of the *untruncated* mixture; it is not renormalized for the bounds.
    """

    def __init__(self, means, stds,
                 low: typing.Optional[float] = None,
                 high: typing.Optional[float] = None):
        """Build the mixture.

        Args:
            means: tensor whose first dimension indexes the components and
                whose second dimension is the event dimension.
            stds: tensor of the same shape as ``means`` with the per-dimension
                standard deviations.
            low: optional lower bound applied to every coordinate.
            high: optional upper bound applied to every coordinate.
        """
        super().__init__()
        assert means.shape == stds.shape
        self._shape = torch.Size([means.shape[1]])
        self.n_components = means.shape[0]
        self.means = means
        self.stds = stds
        self._low = low
        self._high = high

        # Equal (unnormalized) weights; Categorical normalizes them.
        equal_components = torch.ones(self.n_components, device=means.device)
        mix = D.Categorical(equal_components)
        # Independent(..., 1) turns the last dimension into the event dimension,
        # i.e. each component is a diagonal Gaussian.
        comp = D.Independent(D.Normal(self.means, self.stds), 1)
        self.gmm = D.MixtureSameFamily(mix, comp)

    def _in_bounds(self, samples):
        """Return a boolean mask marking rows whose coordinates all respect the bounds.

        A bound that is ``None`` is treated as absent, so one-sided truncation
        works as well as two-sided truncation.
        """
        mask = torch.ones(samples.shape[:-1], dtype=torch.bool, device=samples.device)
        if self._low is not None:
            mask = mask & torch.all(samples.ge(self._low), dim=-1)
        if self._high is not None:
            mask = mask & torch.all(samples.le(self._high), dim=-1)
        return mask

    def _log_prob(self, inputs, context):
        """Return the mixture log-density of ``inputs``; ``context`` is ignored.

        Raises:
            AssertionError: if a bound is set and some input violates it.
            ValueError: if the trailing shape of ``inputs`` does not match the
                event shape of the mixture.
        """
        if self._low is not None:
            assert torch.all(inputs.ge(self._low)), f"Some inputs are smaller than {self._low}"
        if self._high is not None:
            assert torch.all(inputs.le(self._high)), f"Some inputs are greater than {self._high}"

        if inputs.shape[1:] != self._shape:
            raise ValueError(f"Expected input of shape {self._shape}, got {inputs.shape[1:]}")

        return self.gmm.log_prob(inputs)

    def _sample(self, num_samples, context):
        """Draw samples, rejecting those outside the configured bounds.

        Args:
            num_samples: number of samples to draw (per context row if a
                context is given).
            context: optional tensor; only the size of its leading dimension
                is used, its values are ignored.

        Returns:
            ``num_samples`` samples when ``context`` is None; otherwise the
            ``context.shape[0] * num_samples`` samples passed through
            ``torchutils.split_leading_dim`` (presumably reshaped to
            ``[context_size, num_samples, ...]`` - confirm against the helper).

        Note:
            Rejection sampling loops until enough samples are accepted, so
            bounds that exclude (almost) all of the mixture's mass can make
            this call run for a very long time.
        """
        if context is None:
            tot_samples = num_samples
        else:
            # The value of the context is ignored, only its leading size matters.
            context_size = context.shape[0]
            tot_samples = context_size * num_samples

        samples = self.gmm.sample(torch.Size([tot_samples]))
        if self._low is not None or self._high is not None:
            samples = samples[self._in_bounds(samples)]
            # Refill in chunks of a fifth of the request, but never zero:
            # a zero-sized refill would make the loop below spin forever
            # whenever fewer than five samples are requested.
            refill_size = max(tot_samples // 5, 1)
            while len(samples) < tot_samples:
                new_samples = self.gmm.sample(torch.Size([refill_size]))
                new_samples = new_samples[self._in_bounds(new_samples)]
                samples = torch.cat((samples, new_samples), dim=0)

        # The refill may overshoot; keep exactly the requested number.
        samples = samples[:tot_samples]
        assert len(samples) == tot_samples

        if context is None:
            return samples
        else:
            return torchutils.split_leading_dim(samples, [context_size, num_samples])
