import numpy as np
from abc import abstractmethod, ABCMeta


def numpy_softmax(x, axis=-1):
    """Return the softmax of ``x`` taken along ``axis``.

    The per-slice maximum is subtracted before exponentiating so large
    inputs do not overflow; this shift does not change the result.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    denominator = np.sum(exps, axis=axis, keepdims=True)
    return exps / denominator


class NumpyDistribution(object, metaclass=ABCMeta):
    """Abstract base for samplers that own a seeded NumPy ``Generator``.

    Attributes:
        seed: the seed the current ``np_rng`` was created from.
        np_rng: a ``numpy.random.Generator`` intended for subclasses to
            sample from.
    """

    def __init__(self, seed):
        self.seed = seed
        self.np_rng = np.random.default_rng(seed)

    def change_seed(self, seed):
        """Replace ``np_rng`` with a fresh generator seeded with ``seed``.

        ``self.seed`` is updated too, so it keeps describing the generator
        actually in use (previously it was left at the old value).
        """
        self.seed = seed
        self.np_rng = np.random.default_rng(seed)

    @abstractmethod
    def __call__(self, size=None):
        """Draw samples; subclasses decide how ``size`` shapes the output."""


class NormalDistribution(NumpyDistribution):
    """Gaussian sampler with mean ``loc`` and standard deviation ``scale``."""

    def __init__(self, loc, scale, seed=42):
        super().__init__(seed=seed)
        self.loc, self.scale = loc, scale

    def __call__(self, size=None):
        """Draw from ``Generator.normal``; ``size`` is passed through unchanged."""
        rng = self.np_rng
        return rng.normal(self.loc, self.scale, size)

    @property
    def mean(self):
        """Analytic mean of the distribution, which is simply ``loc``."""
        return self.loc


class LogNormalDistribution(NormalDistribution):
    """Log-normal sampler: ``exp(X)`` with ``X ~ Normal(loc, scale)``.

    ``loc`` and ``scale`` parametrize the underlying normal (log-space)
    variable, not the mean/std of the returned samples.
    """

    def __call__(self, size=None):
        # Sample the underlying normal, then map it onto the positive reals.
        return np.exp(super().__call__(size=size))

    @property
    def mean(self):
        """Analytic mean of the log-normal: ``exp(loc + scale**2 / 2)``."""
        return np.exp(self.loc + self.scale ** 2 / 2)

    @staticmethod
    def get_params(mean, mode):
        """Return ``(loc, scale)`` of a log-normal with the given mean and mode.

        For a log-normal, ``mean = exp(mu + sigma^2 / 2)`` and
        ``mode = exp(mu - sigma^2)``, so ``log(mean) - log(mode) = 1.5 * sigma^2``
        and ``mu = log(mode) + sigma^2``.

        Raises:
            ValueError: if ``mode`` is not positive or ``mean`` does not
                exceed ``mode`` (no log-normal satisfies those values).
        """
        # Explicit checks instead of ``assert``: asserts vanish under ``-O``.
        if not mode > 0:
            raise ValueError("mode must be positive for a log-normal distribution")
        if not mean > mode:
            raise ValueError("mean must be strictly greater than mode")
        log_mode = np.log(mode)
        sigma_square = 2 / 3 * (np.log(mean) - log_mode)
        loc = log_mode + sigma_square
        return loc, np.sqrt(sigma_square)


class Categorical(NumpyDistribution):
    """Categorical sampler over the integer outcomes ``0 .. len(prob) - 1``.

    Exactly one of ``probs`` (used as given) or ``logits`` (converted with a
    softmax over the last axis) must be supplied.
    """

    def __init__(self,
                 probs=None,
                 logits=None,
                 seed=42
                 ):
        super().__init__(seed=seed)

        # Both missing or both present are equally invalid.
        if (probs is None) == (logits is None):
            raise ValueError("only one of probs or logits should be provided")

        self.prob = probs if logits is None else numpy_softmax(logits, axis=-1)
        self.support = np.arange(len(self.prob))

    def __call__(self, size=None):
        """Draw outcome indices from ``support`` weighted by ``prob``."""
        outcomes = self.support
        return self.np_rng.choice(outcomes, size=size, p=self.prob)


class MixtureDistribution(NumpyDistribution):
    """Finite mixture: pick a component by a categorical draw, then use its sample.

    Args:
        distributions: component samplers; each is called as ``d(size=size)``
            and must expose a ``mean`` attribute/property.
        probs, logits: mixture weights, exactly one of which is given
            (forwarded to ``Categorical``).
        seed: seed for this object; the component selector uses ``seed + 1``.
    """

    def __init__(self,
                 distributions,
                 probs=None,
                 logits=None,
                 seed=42
                 ):
        super().__init__(seed)
        self.distributions = distributions
        # Component indices come from this separate generator; this instance's
        # own ``np_rng`` is not used by ``__call__``.
        self.categorical = Categorical(probs=probs, logits=logits, seed=seed + 1)

    def __call__(self, size=None):
        """Draw ``size`` samples from the mixture (one value when ``size`` is None).

        Every component is sampled with the full ``size`` and, element-wise,
        the sample of the component chosen by the categorical draw is kept.
        """
        samples = [d(size=size) for d in self.distributions]
        index = self.categorical(size=size)
        if size is None:
            # Scalar draw: ``index`` is a single integer.
            return samples[index]
        # Vector draw: a Python list cannot be indexed by an index array, so
        # stack to shape (K, *size) and select per element along axis 0.
        stacked = np.stack(samples, axis=0)
        index = np.asarray(index)
        return np.take_along_axis(stacked, index[np.newaxis, ...], axis=0)[0]

    @property
    def mean(self):
        """Weighted average of the component means, weighted by the mixture probs."""
        probs = self.categorical.prob
        mus = np.asarray([d.mean for d in self.distributions])
        return (mus * probs).sum()

