""" Packages import """
import numpy as np
from .MAB import GenericMAB
from .utils import rollavg_bottlneck, rd_argmax, get_leader
from .tracker import Tracker2


class GaussianMAB(GenericMAB):
    """
    Gaussian Bandit Problem

    Every arm is created through GenericMAB with the 'G' method. The code of
    this class reads each arm's mean as ``arm.mu`` and uses ``arm.eta`` as its
    standard deviation (it is the scale of the sampled normals and the sigma
    of the KL divergence).
    """
    def __init__(self, p):
        """
        Initialization
        :param p: per-arm parameters, forwarded unchanged to GenericMAB
            (presumably one [mean, standard deviation] entry per arm --
            TODO confirm against GenericMAB)
        """
        # Initialization of arms from GenericMAB
        super().__init__(methods=['G'] * len(p), p=p)
        # Parameters used for stop learning policy
        self.best_arm = self.get_best_arm()
        best_eta = self.MAB[self.best_arm].eta
        # Careful: Cp is the bound only with same variance for each arm
        self.Cp = sum(
            (self.mu_max - arm.mu) / self.kl2(arm.mu, self.mu_max, arm.eta, best_eta)
            for arm in self.MAB if arm.mu != self.mu_max
        )

    def get_best_arm(self):
        """
        Defines the best arm, with a tie-breaking rule in favor of the distribution with the lowest variance.
        :return: index of the arm with the largest mean; among arms sharing
            that mean, the one with the smallest ``eta``
        """
        # Indices of every arm whose mean equals the maximal mean
        ind = np.nonzero(self.means == np.amax(self.means))[0]
        std = [self.MAB[arm].eta for arm in ind]
        return ind[np.argmin(std)]

    @staticmethod
    def kl(mu1, mu2):
        """
        Implementation of the Kullback-Leibler divergence for two Gaussian distributions of unit variance
        :param mu1: float, mean of the first distribution
        :param mu2: float, mean of the second distribution
        :return: float, KL(N(mu1, 1), N(mu2, 1))
        """
        return (mu2-mu1)**2/2

    @staticmethod
    def kl2(mu1, mu2, sigma1, sigma2):
        """
        Implementation of the Kullback-Leibler divergence for two Gaussian distributions
        with possibly different variances
        :param mu1: float, mean of the first distribution
        :param mu2: float, mean of the second distribution
        :param sigma1: float, standard deviation of the first distribution
        :param sigma2: float, standard deviation of the second distribution
        :return: float, KL(N(mu1, sigma1^2), N(mu2, sigma2^2))
        """
        return np.log(sigma2/sigma1) + 0.5 * (sigma1**2/sigma2**2 + (mu2-mu1)**2/sigma2**2 - 1)

    def TS(self, T):
        """
        Thompson Sampling with known variance, and an improper uniform prior
         on the mean
        :param T: Time Horizon
        :return: the result of self.Index_Policy run with the sampled index
        """
        eta = np.array([arm.eta for arm in self.MAB])

        def f(x):
            # One draw per arm, centered on the empirical mean with
            # standard deviation eta / sqrt(number of pulls)
            return np.random.normal(x.Sa/x.Na, eta/np.sqrt(x.Na))
        return self.Index_Policy(T, f)

    def TS_star(self, T):
        r"""
        Implementation of TS^\star with Gaussian rewards

        The first nb_arms rounds pull each arm once. Afterwards the reference
        arm is the one with the best empirical mean; one normal sample is drawn
        per arm and every other arm whose sample reaches the reference
        empirical mean becomes a candidate. The pulled arm is chosen uniformly
        among the reference arm and the candidates.
        :param T: Time Horizon
        :return: Tracker2 holding the history of the run
        """
        tr = Tracker2(self.means, T, store_rewards_arm=True)
        eta = np.array([arm.eta for arm in self.MAB])
        for t in range(T):
            if t < self.nb_arms:
                # Initial round-robin (t % nb_arms == t on this branch)
                arm = t
            else:
                emp_means = tr.Sa / tr.Na
                ref_arm = rd_argmax(emp_means)
                ref_mean = emp_means[ref_arm]
                samples = np.random.normal(emp_means, eta / np.sqrt(tr.Na))
                challengers = [k for k in range(self.nb_arms)
                               if k != ref_arm and samples[k] >= ref_mean]
                arm = np.random.choice([ref_arm] + challengers)
            reward = self.MAB[arm].sample()[0]
            tr.update(t, arm, reward)
        return tr

    def MED(self, T):
        """
        Implementation of MED for Gaussian variables with known (but potentially different) variances
        :param T: Time Horizon
        :return: the result of self.Generic_MED run with the Gaussian divergence
        """
        eta = np.array([arm.eta for arm in self.MAB])

        def div(k, tr):
            # Squared gap between arm k's empirical mean and the best
            # empirical mean, scaled by 1 / (2 * eta[k]^2)
            emp_means = tr.Sa / tr.Na
            best_emp = emp_means[rd_argmax(emp_means)]
            return 0.5 * (emp_means[k] - best_emp)**2 / eta[k]**2
        return self.Generic_MED(T, divergence=div)

    def kl_ucb(self, T, f):
        """
        Index policy whose index is the empirical mean plus sqrt(2 * f(t) / Na)
        (the arms' eta values are not used by this index)
        :param T: Time Horizon
        :param f: function of the current time x.t giving the exploration level
        :return: the result of self.Index_Policy run with this index
        """
        def index_func(x):
            return x.Sa / x.Na + np.sqrt(f(x.t)*2 / x.Na)
        return self.Index_Policy(T, index_func)











