import numpy as np


class DistributionSampler:

    def __init__(self, sample_func, lb, ub):
        self.sample_func = sample_func
        self.lb = lb
        self.ub = ub
        self.dim = self.sample_func(n=1).shape[1]

        n_samples = int(1e6)
        samples_not_ok = np.ones(n_samples, dtype=np.bool)
        samples = np.zeros((n_samples, self.dim))
        while np.any(samples_not_ok):
            n_new_samples = np.sum(samples_not_ok)
            samples[samples_not_ok] = self.sample_func(n=n_new_samples)
            samples_not_ok = np.logical_or(np.any(samples < lb, axis=-1),
                                           np.any(samples > ub, axis=-1))

        self.mu = np.mean(samples, axis=0)
        diffs = (samples - self.mu[None, :]) / np.sqrt(samples.shape[0])
        self.cov = np.einsum("ni,nj->ij", diffs, diffs)

    def sample(self):
        ok = False
        while not ok:
            sample = self.sample_func(n=1)[0, :]
            ok = np.all(sample >= self.lb) and np.all(sample <= self.ub)
        return sample

    def mean(self):
        return self.mu

    def covariance_matrix(self):
        return self.cov


class GaussianSampler:

    def __init__(self, mean, variance, bounds):
        self.bounds = bounds
        self._mean = mean
        self.covariance = variance if isinstance(variance, np.ndarray) else np.eye(self._mean.shape[0]) * variance

    def sample(self):
        return np.clip(np.random.multivariate_normal(self._mean, self.covariance), self.bounds[0], self.bounds[1])

    def mean(self):
        return self._mean.copy()

    def covariance_matrix(self):
        return self.covariance.copy()


class DiscreteSampler:

    def __init__(self, log_likelihoods):
        self.likelihoods = np.exp(log_likelihoods)

    def sample(self):
        return np.argmax(np.random.uniform(0., 1.) <= np.cumsum(self.likelihoods))


class UniformSampler:

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

    def sample(self):
        norm_sample = np.random.uniform(low=-1, high=1, size=self.lower_bound.shape)
        return self._scale_context(norm_sample)

    def mean(self):
        return 0.5 * self.lower_bound + 0.5 * self.upper_bound

    def covariance_matrix(self):
        return np.diag((0.5 * (self.upper_bound - self.lower_bound)) ** 2)

    def _scale_context(self, context):
        b = 0.5 * (self.upper_bound + self.lower_bound)
        m = 0.5 * (self.upper_bound - self.lower_bound)
        return m * context + b
