""" Packages import """
import numpy as np
from .MAB import GenericMAB
from .tracker import Tracker2
from .utils import rd_argmax


class TruncGaussianMAB(GenericMAB):
    """
    Bandit problem whose arms are all built by GenericMAB with the 'TG' method
    (presumably truncated Gaussian arms -- confirm against GenericMAB).
    """
    def __init__(self, p):
        """
        Initialization
        :param p: per-arm parameters, one entry per arm, forwarded unchanged to GenericMAB.
            Their exact format is defined by GenericMAB's 'TG' method -- TODO confirm.
        """
        # Initialization of arms from GenericMAB
        super().__init__(methods=['TG']*len(p), p=p)
        self.best_arm = self.get_best_arm()
        self.Cp = None

    def get_best_arm(self):
        """
        Defines the best arm, with a tie-breaking rule in favor of the distribution with the lowest variance.
        :return: index of the arm with maximal mean; among ties, the one with the smallest `scale`
        """
        # Indices of every arm attaining the maximal mean
        ind = np.nonzero(self.means == np.amax(self.means))[0]
        std = [self.MAB[arm].scale for arm in ind]
        u = np.argmin(std)
        return ind[u]

    def PHE(self, T, a, distrib=None):
        """
        Perturbed-History Exploration: each arm's index is its reward sum plus a
        Binomial(int(a*Na), 0.5) pseudo-reward sum, divided by the real plus pseudo counts.
        :param T: Time Horizon
        :param a: proportion of perturbed history. a=1 -> same proportion, a=0-> no perturbed history
        :param distrib: distribution of the perturbed history (currently unused: the
            perturbation is always Binomial with p=0.5)
        :return: Tracker2 object holding the history of the run
        """
        tr = Tracker2(self.means, T, store_rewards_arm=True)
        for t in range(T):
            if t < self.nb_arms:
                # Initial round-robin: pull each arm once
                arm = t
            else:
                idx_mean = np.zeros(self.nb_arms)
                for k in range(self.nb_arms):
                    # Built-in int() replaces the np.int alias, which was removed in NumPy 1.24
                    n_pseudo = int(a*tr.Na[k])
                    ph = np.random.binomial(n=n_pseudo, p=0.5)
                    idx_mean[k] = (tr.Sa[k]+ph)/(tr.Na[k]+n_pseudo)
                arm = rd_argmax(idx_mean)
            reward = self.MAB[arm].sample()[0]
            tr.update(t, arm, reward)
        return tr

    def TS(self, T):
        """
        Bernoulli TS with rewards turned into binary outputs
        Each reward is used as the success probability of a Bernoulli draw, so rewards
        must lie in [0, 1]; the Beta posterior is built from these binary outcomes.
        :param T: Time Horizon
        :return: Tracker2 object holding the history of the run
        """
        def f(S, N):
            # Beta(S+1, N-S+1) posterior sample for S successes out of N pulls
            return np.random.beta(S+1, N-S+1)
        tr = Tracker2(self.means, T)
        # Per-arm cumulative sum of the binarized rewards
        bin_Sa = np.zeros(self.nb_arms)
        for t in range(T):
            if t < self.nb_arms:
                # Initial round-robin: pull each arm once
                arm = t
            else:
                arm = rd_argmax(f(bin_Sa, tr.Na))
            reward = self.MAB[arm].sample()[0]
            bin_Sa[arm] += np.random.binomial(n=1, p=reward)
            tr.update(t, arm, reward)
        return tr

    def IMED(self, T):
        """
        Bernoulli IMED
        Rewards are binarized as in TS (Bernoulli draw with the reward as probability),
        and the arm minimizing Na*kl(mean, best mean) + log(Na) is pulled.
        :param T: Time Horizon
        :return: Tracker2 object holding the history of the run
        """
        def kl_ber(x, y):
            # Bernoulli KL divergence kl(x, y) with guards against log(0) and division by 0
            if x == y:
                return 0
            elif x > 1 - 1e-6:
                # index_func always passes y = max empirical mean >= x, so y is also ~1 here
                return 0
            elif y == 0 or y == 1:
                return np.inf
            elif x < 1e-6:
                # x*log(x/y) is dropped for x ~ 0 to avoid evaluating log(0)
                return (1 - x) * np.log((1 - x) / (1 - y))
            return x * np.log(x / y) + (1 - x) * np.log((1 - x) / (1 - y))

        def index_func(bin_Sa, x):
            # Negated IMED index so that rd_argmax selects the arm with the smallest index
            mu_max = np.max(bin_Sa/x.Na)
            idx = [
                x.Na[k]*kl_ber(bin_Sa[k]/x.Na[k], mu_max)+np.log(x.Na[k])
                for k in range(self.nb_arms)
            ]
            return -np.array(idx)
        tr = Tracker2(self.means, T)
        # Per-arm cumulative sum of the binarized rewards
        bin_Sa = np.zeros(self.nb_arms)
        for t in range(T):
            if t < self.nb_arms:
                # Initial round-robin: pull each arm once
                arm = t
            else:
                arm = rd_argmax(index_func(bin_Sa, tr))
            reward = self.MAB[arm].sample()[0]
            bin_Sa[arm] += np.random.binomial(n=1, p=reward)
            tr.update(t, arm, reward)
        return tr
