""" Packages import """
import numpy as np
from scipy.stats import truncnorm as trunc_norm
from scipy.stats import norm #, laplace_asymmetric
from .utils import convert_tg_mean


class AbstractArm(object):
    """
    Common base of all arms: keeps the first two moments of the reward
    distribution together with a private random generator.
    """

    def __init__(self, mean, variance, random_state):
        """
        Store the moments and seed a generator owned by this arm.
        :param mean: float, expectation of the arm
        :param variance: float, variance of the arm
        :param random_state: int, seed to make experiments reproducible
        """
        self.mean, self.variance = mean, variance
        # Each arm has its own RandomState so that arms do not share a stream.
        self.local_random = np.random.RandomState(random_state)

    def sample(self, size=1):
        """
        Placeholder sampling method; concrete arms override it.
        :param size: int, number of samples requested (ignored here)
        :return: None
        """
        return None


class ArmBernoulli(AbstractArm):
    """Arm whose rewards are 1. with probability p and 0. otherwise."""

    def __init__(self, p, random_state=0):
        """
        :param p: float, success probability (this is also the arm's mean)
        :param random_state: int, seed to make experiments reproducible
        """
        self.p = p
        bernoulli_variance = p * (1. - p)
        super(ArmBernoulli, self).__init__(
            mean=p,
            variance=bernoulli_variance,
            random_state=random_state,
        )

    def sample(self, size=1):
        """
        Draw `size` uniforms on [0, 1) and mark those falling below p.
        :param size: int, number of samples
        :return: np.array of floats, each equal to 0. or 1.
        """
        uniforms = self.local_random.rand(size)
        # Multiplying the boolean mask by 1. converts it to floats.
        return (uniforms < self.p) * 1.

class ArmRademacher(AbstractArm):
    """Arm with rewards in {-1., +1.} and expectation mu."""

    def __init__(self, mu, random_state=0):
        """
        :param mu: float, mean of the arm; +1 is drawn with
            probability (1 + mu) / 2
        :param random_state: int, seed to make experiments reproducible
        """
        self.mu = mu
        # Probability of drawing +1 so that the expectation equals mu.
        self.p = (1+mu)/2
        rademacher_variance = 1-mu**2
        super(ArmRademacher, self).__init__(
            mean=mu,
            variance=rademacher_variance,
            random_state=random_state,
        )

    def sample(self, size=1):
        """
        Draw `size` uniforms on [0, 1) and map them to +1. / -1.
        :param size: int, number of samples
        :return: np.array of floats, each equal to -1. or 1.
        """
        uniforms = self.local_random.rand(size)
        return np.where(uniforms < self.p, 1., -1.)


class ArmBeta(AbstractArm):
    """Arm with Beta(a, b) distributed rewards."""

    def __init__(self, a, b, random_state=0):
        """
        :param a: int, alpha coefficient in beta distribution
        :param b: int, beta coefficient in beta distribution
        :param random_state: int, seed to make experiments reproducible
        """
        self.a, self.b = a, b
        total = a + b
        beta_mean = a / total
        beta_variance = (a * b) / (total ** 2 * (total + 1))
        super(ArmBeta, self).__init__(
            mean=beta_mean,
            variance=beta_variance,
            random_state=random_state,
        )

    def sample(self, size=1):
        """
        Draw rewards from the arm's private generator.
        :param size: int, number of samples
        :return: np.array of floats sampled from Beta(a, b)
        """
        rng = self.local_random
        return rng.beta(self.a, self.b, size)


class ArmGaussian(AbstractArm):
    """Arm with normally distributed rewards."""

    def __init__(self, mu, eta, random_state=0):
        """
        :param mu: float, mean parameter in Gaussian distribution
        :param eta: float, std parameter in Gaussian distribution
        :param random_state: int, seed to make experiments reproducible
        """
        self.mu, self.eta = mu, eta
        gaussian_variance = eta**2
        super(ArmGaussian, self).__init__(
            mean=mu,
            variance=gaussian_variance,
            random_state=random_state,
        )

    def sample(self, size=1):
        """
        Draw rewards from the arm's private generator.
        :param size: int, number of samples
        :return: np.array of floats sampled from N(mu, eta**2)
        """
        rng = self.local_random
        return rng.normal(self.mu, self.eta, size)


class ArmLogGaussian(AbstractArm):
    """Arm with log-normally distributed rewards."""

    def __init__(self, mu, eta, alpha=None, random_state=0):
        """
        :param mu: float, mean parameter in log-Gaussian distribution
            (mean of the underlying Gaussian)
        :param eta: float, std parameter in log-Gaussian distribution
            (std of the underlying Gaussian)
        :param alpha: float or None, risk aversion level; only stored on
            the instance, it does not affect sampling
        :param random_state: int, seed to make experiments reproducible
        """
        self.mu = mu
        self.eta = eta
        # AbstractArm.__init__ accepts only mean/variance/random_state, so
        # alpha is kept on the instance rather than forwarded (forwarding it
        # raised a TypeError on every construction).
        self.alpha = alpha

        super(ArmLogGaussian, self).__init__(
            mean=np.exp(mu + 0.5 * eta ** 2),
            variance=(np.exp(eta ** 2) - 1) * np.exp(2 * mu + eta ** 2),
            random_state=random_state,
        )

    def sample(self, size=1):
        """
        Sampling strategy
        :param size: int, sample size
        :return: np.array of floats sampled from the log-normal distribution
        """
        return self.local_random.lognormal(self.mu, self.eta, size)


class ArmMultinomial(AbstractArm):
    """Arm with a finite support X and probabilities P on that support."""

    def __init__(self, X, P, random_state=0):
        """
        :param X: np.array, support of the distribution
        :param P: np.array, associated probabilities
        :param random_state: int, seed to make experiments reproducible
        """
        # Validate the probability vector before storing anything.
        assert np.min(P) >= 0.0, 'p should be nonnegative.'
        assert np.isclose(np.sum(P), 1.0), 'p should should sum to 1.'

        self.X, self.P = np.array(X), np.array(P)

        first_moment = np.dot(self.X, self.P)
        super(ArmMultinomial, self).__init__(
            mean=first_moment,
            # Var[X] = E[X^2] - E[X]^2
            variance=np.dot(self.X ** 2, self.P) - first_moment ** 2,
            random_state=random_state,
        )

    def sample(self, size=1):
        """
        Pick indices of the support according to P and return the
        corresponding support values.
        :param size: int, number of samples
        :return: np.array, rewards taken from X
        """
        picked = self.local_random.choice(len(self.P), size=size, p=self.P)
        return self.X[picked]


class ArmEmpirical(AbstractArm):
    """Arm that resamples uniformly from a fixed array of values."""

    def __init__(self, X, random_state=0):
        """
        Same as ArmMultinomial but with uniform probability.
        Allows for faster sampling using randint rather than choice
        (~x4 speed up).
        :param X: np.array, support of the distribution
        :param random_state: int, seed to make experiments reproducible
        """
        values = np.array(X)
        self.X = values
        self.n = len(values)

        empirical_mean = np.mean(self.X)
        empirical_variance = np.var(self.X)
        super(ArmEmpirical, self).__init__(
            mean=empirical_mean,
            variance=empirical_variance,
            random_state=random_state,
        )

    def sample(self, size=1):
        """
        Draw uniform random indices into X and return the selected values.
        :param size: int, number of samples
        :return: np.array, rewards taken from X
        """
        picked = self.local_random.randint(len(self.X), size=size)
        return self.X[picked]


class ArmExponential(AbstractArm):
    """Arm with exponentially distributed rewards of scale (mean) p."""

    def __init__(self, p, random_state=0):
        """
        :param p: float, parameter in exponential distribution; it is used
            as the scale, so it is also the mean of the arm
        :param random_state: int, seed to make experiments reproducible
        """
        self.p = p
        exponential_variance = p**2
        super(ArmExponential, self).__init__(
            mean=p,
            variance=exponential_variance,
            random_state=random_state,
        )

    def sample(self, size=1):
        """
        Draw rewards from the arm's private generator.
        :param size: int, number of samples
        :return: np.array of floats sampled from the exponential distribution
        """
        rng = self.local_random
        return rng.exponential(self.p, size)


class ArmDirac():
    """Deterministic arm: every reward equals the constant c."""

    def __init__(self, c, random_state=0):
        """
        :param c: float, location of the mass
        :param random_state: int, seed to make experiments reproducible;
            defaults to 0 like the other arms in this module. The generator
            is created for interface parity but sampling never uses it.
        """
        self.mean = c
        self.variance = 0
        self.local_random = np.random.RandomState(random_state)

    def sample(self, size=1):
        """
        Sampling strategy
        :param size: int, number of samples
        :return: np.array of floats, all equal to the arm's mean
        """
        return self.mean * np.ones(size)


class ArmTG(AbstractArm):
    """Arm built from a Gaussian whose samples are clipped to [0, 1]."""

    def __init__(self, mu, scale, random_state=0):
        """
        :param mu: float, mean of the untruncated Gaussian
        :param scale: float, std of the untruncated Gaussian
        :param random_state: int, seed to make experiments reproducible
        """
        self.mu, self.scale = mu, scale
        # Standardised bounds corresponding to the interval [0, 1].
        lower = -mu/scale
        upper = (1-mu)/scale
        # Frozen scipy distribution kept on the instance; sample() below
        # does not use it.
        self.dist = trunc_norm(lower, b=upper, loc=mu, scale=scale)
        self.dist.random_state = random_state
        super(ArmTG, self).__init__(
            mean=convert_tg_mean(mu, scale),
            variance=scale**2,
            random_state=random_state,
        )

    def sample(self, size=1):
        """
        Draw Gaussian values and clip them: values below 0 become 0,
        values above 1 become 1, values strictly inside (0, 1) are kept.
        :param size: int, number of samples
        :return: np.array of floats in [0, 1]
        """
        draws = self.local_random.normal(self.mu, self.scale, size)
        strictly_inside = (draws > 0) * (draws < 1)
        return draws * strictly_inside + (draws > 1)


class ArmPoisson(AbstractArm):
    """Arm with Poisson distributed rewards of rate p."""

    def __init__(self, p, random_state=0):
        """
        :param p: float, Poisson parameter (rate); it is both the mean and
            the variance of the arm
        :param random_state: int, seed to make experiments reproducible
        """
        self.p = p
        super(ArmPoisson, self).__init__(
            mean=p,
            # A Poisson distribution has variance equal to its rate.
            variance=p,
            random_state=random_state,
        )

    def sample(self, size=1):
        """
        Sampling strategy
        :param size: int, number of samples (defaults to 1 like the
            other arms of this module)
        :return: np.array of ints sampled from the Poisson distribution
        """
        return self.local_random.poisson(self.p, size=size)


class ArmNegativeExponential(AbstractArm):
    """Arm whose rewards are the negation of exponential draws of scale p."""

    def __init__(self, p, random_state=0):
        """
        :param p: float, parameter in exponential distribution (used as the
            scale); the arm's mean is -p
        :param random_state: int, seed to make experiments reproducible
        """
        self.p = p
        negated_mean = -p
        exponential_variance = p**2
        super(ArmNegativeExponential, self).__init__(
            mean=negated_mean,
            variance=exponential_variance,
            random_state=random_state,
        )

    def sample(self, size=1):
        """
        Draw exponential values and flip their sign.
        :param size: int, number of samples
        :return: np.array of non-positive floats
        """
        positive_draws = self.local_random.exponential(self.p, size)
        return -positive_draws


class ArmPareto(AbstractArm):
    """Arm with classical Pareto rewards (minimum value 1, exponent alpha)."""

    def __init__(self, alpha, random_state=0):
        """
        :param alpha: float, exponent in Pareto distribution; the stored
            mean alpha / (alpha - 1) is only meaningful for alpha > 1
        :param random_state: int, seed to make experiments reproducible
        """
        self.alpha = alpha
        super(ArmPareto, self).__init__(
            mean=alpha / (alpha - 1),
            variance=None,  # left unset, as before
            random_state=random_state,
        )

    def sample(self, size=1):
        """
        Sampling strategy
        :param size: int, number of samples
        :return: np.array of floats >= 1 sampled from the classical Pareto
        """
        # RandomState.pareto draws from the Lomax (Pareto II) distribution,
        # whose mean is 1 / (alpha - 1). Adding 1 yields the classical Pareto
        # with minimum 1, whose mean alpha / (alpha - 1) matches self.mean.
        return 1 + self.local_random.pareto(self.alpha, size)


# class ArmGaussianMixture(AbstractArm):
#     def __init__(self, p, means, sigmas, random_state=0):
#         """
#         :param p: array, probability of each Gaussian in the mixture
#         :param means: array, means for each Gaussian in the mixture
#         :param sigmas: array, stds for each Gaussian in the mixture
#         :param random_state: int, seed to make experiments reproducible
#         """
#         assert np.min(p) >= 0.0, 'p should be nonnegative.'
#         assert np.isclose(np.sum(p), 1.0), 'p should should sum to 1.'
#         self.p = np.array(p)
#         self.means = np.array(means)
#         self.sigmas = np.array(sigmas)
#         self.mean = np.inner(self.p, self.means)
#         var = np.inner(self.p, self.sigmas**2) + np.inner(self.p, (self.means-self.mean)**2)
#         print(p, means, sigmas)
#         super(ArmGaussianMixture, self).__init__(mean=mean,
#                                                  variance=var,
#                                                  random_state=random_state
#                                                  )
#
#     def sample(self, size=1):
#         """
#         Sampling strategy
#         :return: float, a sample from the arm
#         """
#         i = self.local_random.choice(np.arange(self.p.shape[0]), p=self.p, size=size)
#         return self.local_random.normal(loc=self.means[i], scale=self.sigmas[i], size=size)

class ArmGaussianMixture:
    """Arm whose rewards follow a finite mixture of Gaussians."""

    def __init__(self, p, means, sigmas, random_state=0):
        """
        :param p: array, probability of each Gaussian in the mixture
        :param means: array, means for each Gaussian in the mixture
        :param sigmas: array, stds for each Gaussian in the mixture
        :param random_state: int, seed to make experiments reproducible
        """
        # Validate the mixture weights before storing anything.
        assert np.min(p) >= 0.0, 'p should be nonnegative.'
        assert np.isclose(np.sum(p), 1.0), 'p should should sum to 1.'
        self.p, self.means, self.sigmas = np.array(p), np.array(means), np.array(sigmas)
        self.mean = np.inner(self.p, self.means)
        # Total variance = expected within-component variance
        #                  + variance of the component means.
        self.variance = np.inner(self.p, self.sigmas**2) + np.inner(self.p, (self.means-self.mean)**2)
        self.local_random = np.random.RandomState(random_state)

    def sample(self, size=1):
        """
        Pick a mixture component for every sample, then draw from the
        Gaussian of that component.
        :param size: int, number of samples
        :return: np.array of floats sampled from the mixture
        """
        component_ids = np.arange(self.p.shape[0])
        chosen = self.local_random.choice(component_ids, p=self.p, size=size)
        locs, scales = self.means[chosen], self.sigmas[chosen]
        return self.local_random.normal(loc=locs, scale=scales, size=size)

class ArmUniform(AbstractArm):
    """Arm with rewards uniformly distributed between low and high."""

    def __init__(self, low, high, random_state=0):
        """
        :param low: float, lower bound of support
        :param high: float, upper bound of support
        :param random_state: int, seed to make experiments reproducible
        """
        self.low, self.high = low, high
        midpoint = 0.5 * (low + high)
        uniform_variance = 1 / 12 * (high - low) ** 2
        super(ArmUniform, self).__init__(
            mean=midpoint,
            variance=uniform_variance,
            random_state=random_state,
        )

    def sample(self, size=1):
        """
        Draw rewards from the arm's private generator.
        :param size: int, sample size
        :return: np.array of floats sampled uniformly between low and high
        """
        rng = self.local_random
        return rng.uniform(low=self.low, high=self.high, size=size)


class ArmGaussianMisspecifiedParetoTail(AbstractArm):
    """Gaussian arm that can optionally be contaminated with Pareto draws."""

    def __init__(self, mu, eta, breakpoint, alpha, misspecified, random_state=0):
        """
        :param mu: float, mean parameter in Gaussian distribution
        :param eta: float, std parameter in Gaussian distribution
        :param breakpoint: float, point in the support after which
            the tail changes to Pareto
        :param alpha: float, parameter for the Pareto tail
        :param misspecified: bool, whether or not to contaminate the tail
        :param random_state: int, seed to make experiments reproducible
        """
        self.mu, self.eta = mu, eta
        self.breakpoint, self.alpha = breakpoint, alpha
        self.misspecified = misspecified

        if not self.misspecified:
            # Plain Gaussian arm: Phi and c are not needed and not set.
            super(ArmGaussianMisspecifiedParetoTail, self).__init__(
                mean=self.mu,
                variance=eta**2,
                random_state=random_state,
            )
            return

        # Standardised breakpoint and Gaussian density evaluated there.
        z = (self.breakpoint - self.mu) / self.eta
        phi = 1 / (np.sqrt(2 * np.pi) * self.eta) * np.exp(-z ** 2 / 2)
        # Gaussian mass below the breakpoint.
        self.Phi = norm.cdf(z)
        # Weight given to the Pareto component.
        self.c = self.breakpoint / self.alpha * phi
        self.mean = self.mu * self.Phi - self.eta * phi + self.c * self.alpha / (self.alpha - 1) * self.breakpoint
        super(ArmGaussianMisspecifiedParetoTail, self).__init__(
            mean=self.mean,
            variance=eta**2,  # variance is not eta**2, TODO later
            random_state=random_state,
        )

    def sample(self, size=1):
        """
        Draw from the Gaussian, or from the Gaussian/Pareto mix when the
        arm is misspecified.
        :param size: int, number of samples
        :return: np.array of floats
        """
        if not self.misspecified:
            return self.local_random.normal(self.mu, self.eta, size)

        # Randomly mix proportionally to the respective mass of
        # the Gaussian and the Pareto.
        gaussian_share = self.Phi / (self.Phi + self.c)
        ind = (self.local_random.rand(size) < gaussian_share) * 1
        # Both components are drawn for every sample (Gaussian first, then
        # Pareto); the indicator selects which one is kept.
        gaussian_draws = self.local_random.normal(self.mu, self.eta, size)
        pareto_draws = self.local_random.pareto(self.alpha, size)
        return gaussian_draws * ind + pareto_draws * (1 - ind)


class ArmGaussianMisspecifiedExpTail(AbstractArm):
    """Gaussian arm that can optionally be contaminated with exponential draws."""

    def __init__(self, mu, eta, breakpoint, beta, misspecified, random_state=0):
        """
        :param mu: float, mean parameter in Gaussian distribution
        :param eta: float, std parameter in Gaussian distribution
        :param breakpoint: float, point in the support after which
            the tail changes to Exponential
        :param beta: float, scale of the exponential tail
        :param misspecified: bool, whether or not to contaminate the tail
        :param random_state: int, seed to make experiments reproducible
        """
        self.mu, self.eta = mu, eta
        self.breakpoint, self.beta = breakpoint, beta
        self.misspecified = misspecified

        if not self.misspecified:
            # Plain Gaussian arm: Phi, c and exp_ are not needed and not set.
            super(ArmGaussianMisspecifiedExpTail, self).__init__(
                mean=self.mu,
                variance=eta**2,
                random_state=random_state,
            )
            return

        # Standardised breakpoint and Gaussian density evaluated there.
        z = (self.breakpoint - self.mu) / self.eta
        phi = 1 / (np.sqrt(2 * np.pi) * self.eta) * np.exp(-z ** 2 / 2)
        self.exp_ = np.exp(-self.breakpoint / self.beta)
        # Gaussian mass below the breakpoint.
        self.Phi = norm.cdf(z)
        # Scaling constant of the exponential component.
        self.c = self.beta / self.eta * phi / self.exp_
        self.mean = self.mu * self.Phi - self.eta * phi + self.c * (self.breakpoint + self.beta) * self.exp_
        super(ArmGaussianMisspecifiedExpTail, self).__init__(
            mean=self.mean,
            variance=eta**2,  # variance is not eta**2, TODO later
            random_state=random_state,
        )

    def sample(self, size=1):
        """
        Draw from the Gaussian, or from the Gaussian/Exponential mix when
        the arm is misspecified.
        :param size: int, number of samples
        :return: np.array of floats
        """
        if not self.misspecified:
            return self.local_random.normal(self.mu, self.eta, size)

        # Randomly mix proportionally to the respective mass of
        # the Gaussian and the Exponential.
        gaussian_share = self.Phi / (self.Phi + self.c * self.exp_)
        ind = (self.local_random.rand(size) < gaussian_share) * 1
        # Both components are drawn for every sample (Gaussian first, then
        # Exponential); the indicator selects which one is kept.
        gaussian_draws = self.local_random.normal(self.mu, self.eta, size)
        exponential_draws = self.local_random.exponential(self.beta, size)
        return gaussian_draws * ind + exponential_draws * (1 - ind)