import torch
from robustopt_torch.funcutils import allequal
from functools import partial

# Really each distribution sampler should return an instance of a discrete distribution

def is_discrete(distribution):
    """Report whether ``distribution`` consists solely of discrete particles.

    A ``DiscreteDist`` is discrete.  A ``MixtureDist`` is discrete exactly
    when every one of its component distributions is (checked recursively).
    Any other object is reported as not discrete.
    """
    if isinstance(distribution, DiscreteDist):
        return True
    if not isinstance(distribution, MixtureDist):
        return False
    # A mixture is only as discrete as its least discrete component.
    return all(is_discrete(component)
               for component in distribution.distributions)

def get_particles_and_weights(distribution):
    """Return copies of the particles and weights of a discrete distribution.

    For a ``DiscreteDist`` the stored values and weights are detached and
    cloned.  For a ``MixtureDist`` the particles of all components are
    concatenated, and each component's weights are scaled by that
    component's mixture weight.

    Raises:
        ValueError: if ``distribution`` is neither a ``DiscreteDist`` nor a
            ``MixtureDist``.
    """
    if isinstance(distribution, DiscreteDist):
        particles = distribution.vals.detach().clone()
        return particles, distribution.weights.detach().clone()
    if not isinstance(distribution, MixtureDist):
        raise ValueError("Attempted to get particles for a non-discrete distribution")

    per_component = (get_particles_and_weights(component)
                     for component in distribution.distributions)
    particle_groups, weight_groups = zip(*per_component)

    # Repeat each mixture weight once per particle of its component so it
    # lines up element-wise with the concatenated component weights.
    group_sizes = torch.as_tensor([len(group) for group in weight_groups])
    mixture_share = torch.repeat_interleave(distribution.weights, group_sizes)
    joint_weights = mixture_share * torch.cat(weight_groups)
    return torch.cat(particle_groups), joint_weights

def get_samps(distribution, batch_size = None):
    """Draw samples from ``distribution``, or enumerate it when possible.

    Args:
        distribution: object exposing ``sample(batch_size)``; when
            ``batch_size`` is ``None`` it must be a discrete distribution.
        batch_size: number of samples to draw, or ``None`` to request every
            particle together with its weight.

    Returns:
        ``(samples, None)`` when ``batch_size`` is given, otherwise the
        ``(particles, weights)`` pair of the discrete distribution.

    Raises:
        ValueError: if ``batch_size`` is ``None`` and the distribution is
            not discrete.
    """
    # Identity check: ``== None`` would defer to a user-defined ``__eq__``.
    if batch_size is not None:
        return distribution.sample(batch_size), None
    if not is_discrete(distribution):
        raise ValueError("Batch size is invalid!")
    return get_particles_and_weights(distribution)

class DiscreteDist:
    """Weighted empirical distribution over a fixed set of particles.

    Particles are stored along the first dimension of ``vals`` and
    ``weights`` holds one probability per particle.  Both tensors are
    detached and cloned on construction.
    """

    def __init__(self, vals, weights = None, generator = None,
                 sampling_kernel = None, log_sampling_density = None):
        """Build the distribution.

        Args:
            vals: tensor of particles, indexed along dimension 0.
            weights: optional tensor with one non-negative weight per
                particle, summing to 1.  Uniform weights are used if omitted.
            generator: optional ``torch.Generator`` used for all sampling.
            sampling_kernel: optional callable invoked as
                ``sampling_kernel(samples, generator = ...)`` on every batch
                of selected particles; its result is what callers receive.
            log_sampling_density: optional callable invoked as
                ``log_sampling_density(x, vals)`` by ``log_density``.

        Raises:
            ValueError: if the weights are negative, do not sum to 1, or do
                not match the number of particles.
        """
        self.vals = vals.detach().clone()
        self.generator = generator
        self.sampling_kernel = sampling_kernel
        self.log_sampling_density = log_sampling_density
        if weights is not None:
            self.weights = weights.detach().clone()
            if (self.weights < 0.0).any():
                raise ValueError("Weights cannot be negative!")
            # NOTE(review): 1e-14 is tighter than float32 round-off, so
            # single-precision weights may be rejected here -- confirm intent.
            if torch.abs(self.weights.sum() - 1.0) > 1e-14:
                raise ValueError("Weights must sum to 1!")
            if self.weights.numel() != self.vals.shape[0]:
                raise ValueError("Must have compatible dimensions with values")
        else:
            self.weights = torch.ones(self.vals.shape[0])

        # Normalise so the uniform default (all ones) becomes a probability
        # vector as well.
        self.weights = self.weights / self.weights.sum()

    def moment(self, k):
        """Return the weighted ``k``-th raw moment, summed over coordinates.

        Each particle's entries are raised to the power ``k``, multiplied by
        that particle's weight, and everything is summed into one scalar.
        """
        # One weight per particle (dimension 0 of ``vals``); the trailing
        # singleton axes let it broadcast over any remaining dimensions.
        per_particle = self.weights.reshape(
            (-1,) + (1,) * (self.vals.dim() - 1))
        return (per_particle * self.vals ** k).sum()

    def sample(self, k = 1, out = None):
        """Draw ``k`` particles with replacement according to the weights.

        Args:
            k: number of samples, optionally wrapped in a 1-tuple.  Passing
                ``float("inf")`` returns every particle once, in random order.
            out: optional output tensor forwarded to ``torch.index_select``.

        Returns:
            The selected particles, passed through ``sampling_kernel`` when
            one was supplied.

        Raises:
            ValueError: if ``k`` is a tuple with more than one entry.
        """
        if isinstance(k, tuple):
            k, *rest = k
            if len(rest) != 0:
                raise ValueError("Number of samples must be a single scalar number")

        if k == float("inf"):
            s_indexes = torch.randperm(len(self.vals), generator = self.generator)
        elif allequal(self.weights):
            # Uniform index sampling; ``allequal`` is an imported helper that
            # presumably reports whether all weights coincide.
            s_indexes = torch.randint(self.weights.numel(), (k,), generator =
                                      self.generator)
        else:
            s_indexes = torch.multinomial(self.weights, k, replacement=True,
                                          generator=self.generator)

        samples = torch.index_select(self.vals, 0, s_indexes, out = out)
        if self.sampling_kernel is None:
            return samples
        return self.sampling_kernel(samples, generator = self.generator)

    def log_density(self, x):
        """Return the log of the weighted mixture of sampling densities at ``x``.

        The particle axis is reduced with ``logsumexp``.  The weights are
        added along the last dimension, so ``log_sampling_density(x, vals)``
        is assumed to place the particle axis last -- TODO confirm against
        the kernels actually used.

        Raises:
            ValueError: if no ``log_sampling_density`` was supplied.
        """
        if self.log_sampling_density is None:
            raise ValueError("Discrete distribution has no sampling density.")
        log_terms = (self.log_sampling_density(x, self.vals) +
                     self.weights.log())
        # ``torch.logsumexp`` requires an explicit reduction dimension.
        return torch.logsumexp(log_terms, dim = -1)

    def __getitem__(self, key):
        """Return ``vals[key]``, passed through ``sampling_kernel`` if set."""
        selected = self.vals[key]
        if self.sampling_kernel is None:
            return selected
        return self.sampling_kernel(selected, generator = self.generator)

    def __setitem__(self, key, value):
        """Overwrite the stored particle(s) at ``key`` with ``value``."""
        self.vals[key] = value

class MixtureDist:
    """Finite mixture of distributions that share one ambient space."""

    def __init__(self, *distributions, weights = None, generator = None):
        """Build the mixture.

        Args:
            *distributions: component distributions; each must expose
                ``sample`` (and ``moment`` if ``moment`` is used).
            weights: optional tensor with one non-negative weight per
                component.  Uniform weights are used if omitted; weights are
                normalised to sum to 1 either way.
            generator: optional ``torch.Generator`` used to pick components.

        Raises:
            ValueError: if the components produce samples of different
                shapes, if any weight is negative, or if the number of
                weights differs from the number of components.
        """
        # One probe sample per component is drawn to compare sample shapes.
        if not allequal(dist.sample((1,)).squeeze().shape for dist in distributions):
            raise ValueError("The ambient spaces of the distributions are not "
                             "the same.")

        self.distributions = distributions
        self.generator = generator
        if weights is not None:
            self.weights = weights.detach().clone()
            if (self.weights < 0.0).any():
                raise ValueError("Weights cannot be negative!")
            if self.weights.numel() != len(distributions):
                raise ValueError("The number of distributions must match the " \
                                 "number of weights!")
        else:
            self.weights = torch.ones(len(distributions))

        self.weights = self.weights / self.weights.sum()

    def moment(self, k):
        """Return the mixture-weighted sum of the components' ``k``-th moments."""
        return sum(weight * dist.moment(k) for weight, dist in zip(self.weights,
                                                                   self.distributions))

    def sample(self, k = 1, out = None):
        """Draw ``k`` samples by first picking a component for each sample.

        Args:
            k: number of samples, optionally wrapped in a 1-tuple.  Passing
                ``float("inf")`` returns every particle of a discrete mixture
                once, in random order.
            out: optional pre-allocated tensor that receives the samples.

        Returns:
            Tensor with one sample per row (``out`` itself when provided).

        Raises:
            ValueError: if ``k`` is a tuple with more than one entry, if the
                mixture has no components, or if ``k`` is infinite but the
                mixture is not discrete.
        """
        if isinstance(k, tuple):
            k, *rest = k
            if len(rest) != 0:
                raise ValueError("Number of samples must be a single scalar number")

        if not self.distributions:
            raise ValueError("Mixture contains no distributions to sample from!")

        if k == float("inf"):
            if not is_discrete(self):
                raise ValueError("Requested a random permutation of all particles, " \
                                 "but the distribution is not discrete!")
            parts, _ = get_particles_and_weights(self)
            # Use the mixture's own generator so seeded runs are reproducible.
            return parts[torch.randperm(len(parts), generator = self.generator)]

        if allequal(self.weights):
            s_indexes = torch.randint(self.weights.numel(), (k,), generator =
                                      self.generator)
        else:
            s_indexes = torch.multinomial(self.weights, k, replacement=True,
                                          generator=self.generator)

        # Probe sample that supplies the shape and dtype of the output buffer.
        samp = self.distributions[0].sample((1,))
        if out is None:
            out = torch.empty((k, *samp.squeeze().shape), dtype = samp.dtype)
        for i, dist in enumerate(self.distributions):
            chosen = s_indexes == i
            # Hand components a plain int, and never ask for zero samples.
            count = int(chosen.sum())
            if count == 0:
                continue
            out[chosen] = dist.sample((count,))
        return out

class KernelParam:
    """Container holding a single kernel parameter value."""

    def __init__(self, param):
        """Store ``param`` on the instance as ``self.param``."""
        self.param = param

class KDEDist:
    """Weighted set of particles together with a kernel.

    NOTE(review): ``sample`` returns the stored particles themselves; the
    kernel is only stored, never applied, in this class -- confirm whether
    callers add the kernel noise.
    """

    def __init__(self, vals, kernel, weights = None,
                 variable_args = None, generator = None):
        """Build the distribution.

        Args:
            vals: tensor of particles, indexed along dimension 0.
            kernel: kernel object, stored as ``self.kernel``.
            weights: optional tensor with one non-negative weight per
                particle, summing to 1.  Uniform weights are used if omitted.
            variable_args: optional dict stored as ``self.variable_args``;
                a fresh empty dict is created for each instance if omitted.
            generator: optional ``torch.Generator`` used for all sampling.

        Raises:
            ValueError: if the weights are negative, do not sum to 1, or do
                not match the number of particles.
        """
        self.vals = vals.detach().clone()
        self.kernel = kernel
        # A ``{}`` default in the signature would be shared by every instance.
        self.variable_args = {} if variable_args is None else variable_args
        self.generator = generator
        if weights is not None:
            self.weights = weights.detach().clone()
            if (self.weights < 0.0).any():
                raise ValueError("Weights cannot be negative!")
            # NOTE(review): 1e-14 is tighter than float32 round-off, so
            # single-precision weights may be rejected here -- confirm intent.
            if torch.abs(self.weights.sum() - 1.0) > 1e-14:
                raise ValueError("Weights must sum to 1!")
            if self.weights.numel() != self.vals.shape[0]:
                raise ValueError("Must have compatible dimensions with values")
        else:
            self.weights = torch.ones(self.vals.shape[0])

        self.weights = self.weights / self.weights.sum()

    def moment(self, k):
        """Not implemented: KDE moments depend on the specific kernel."""
        raise NotImplementedError("KDE moments must specialize on the kernel. " \
                                  "It is currently not implemented.")

    def sample(self, k = 1, out = None):
        """Draw ``k`` particles with replacement according to the weights.

        Args:
            k: number of samples, optionally wrapped in a 1-tuple.  Passing
                ``float("inf")`` returns a copy of every particle once, in
                random order (``out`` is not used on that path).
            out: optional output tensor forwarded to ``torch.index_select``.

        Raises:
            ValueError: if ``k`` is a tuple with more than one entry.
        """
        if isinstance(k, tuple):
            k, *rest = k
            if len(rest) != 0:
                raise ValueError("Number of samples must be a single scalar number")

        if k == float("inf"):
            # Use the instance generator so seeded runs are reproducible.
            order = torch.randperm(len(self.vals), generator = self.generator)
            return self.vals[order].detach().clone()

        if allequal(self.weights):
            s_indexes = torch.randint(self.weights.numel(), (k,), generator =
                                      self.generator)
        else:
            s_indexes = torch.multinomial(self.weights, k, replacement=True,
                                          generator=self.generator)
        return torch.index_select(self.vals, 0, s_indexes, out = out)


class MultivariateNormalDist:
    """Multivariate normal distribution sampled via a Cholesky factor."""

    def __init__(self, mean, covariance, generator = None):
        """Build the distribution.

        Args:
            mean: mean tensor; it is flattened to a vector when sampling.
            covariance: square covariance matrix; its lower Cholesky factor
                is computed once here.
            generator: optional ``torch.Generator`` used for all sampling.
        """
        self.mean = mean.detach().clone()
        self.covariance = covariance.detach().clone()
        # ``torch.cholesky`` is deprecated in favour of the linalg variant.
        self._tril = torch.linalg.cholesky(self.covariance)
        self.generator = generator

    def sample(self, k = 1, out = None):
        """Draw ``k`` samples from N(mean, covariance).

        Args:
            k: number of samples, optionally wrapped in a 1-tuple.
            out: optional output tensor forwarded to ``torch.addmm``.

        Returns:
            Tensor with one sample per row, with singleton dimensions
            squeezed away (so a single sample comes back as a vector).

        Raises:
            ValueError: if ``k`` is a tuple with more than one entry.
        """
        if isinstance(k, tuple):
            k, *rest = k
            if len(rest) != 0:
                raise ValueError("Number of samples must be a single scalar number")

        tril = self._tril
        mean = self.mean.reshape(-1).to(tril.dtype)
        # Standard-normal noise, one row per sample.
        noise = torch.randn((int(k), mean.numel()), dtype = tril.dtype,
                            device = tril.device, generator = self.generator)
        # x = mean + L z has covariance L L^T.  The mean must be added after
        # the transform; multiplying it by L as well would shift it to L mean.
        return torch.addmm(mean, noise, tril.t(), out = out).squeeze_()
