""" Packages import """
import numpy as np
from scipy.optimize import brentq
from .MAB import GenericMAB
from .utils import rollavg_bottlneck, rd_argmax, get_leader
from .tracker import Tracker2


class BetaBernoulliMAB(GenericMAB):
    """
    Bernoulli Bandit Problem

    Every arm is declared to GenericMAB with the 'B' method. The policies below
    (TS, TS_star, MED, kl_ucb) read per-arm statistics ``Sa`` and ``Na`` from a
    tracker object; presumably these are the cumulative rewards and the number
    of pulls of each arm (confirm against Tracker2).
    """
    def __init__(self, p):
        """
        Initialization
        :param p: np.array, true probabilities of success for each arm
        """
        # Initialization of arms from GenericMAB
        super().__init__(methods=['B']*len(p), p=p)
        # Complexity: sum, over the arms whose mean differs from mu_max, of
        # (mu_max - mean) / KL(B(mean), B(mu_max))
        self.Cp = sum((self.mu_max-x)/self.kl(x, self.mu_max) for x in self.means if x != self.mu_max)

    @staticmethod
    def kl(x, y):
        """
        Implementation of the Kullback-Leibler divergence for two Bernoulli distributions (B(x),B(y))
        :param x: float, success probability of the first distribution
        :param y: float, success probability of the second distribution
        :return: float, KL(B(x), B(y)); np.inf when y is 0 or 1 and x != y
        """
        if x == y:
            return 0.0
        # B(y) is degenerate and differs from B(x): the divergence is infinite.
        # This must be tested before the x ~ 1 shortcut, otherwise kl(1, 0)
        # would not be reported as infinite.
        if y == 0 or y == 1:
            return np.inf
        # For x within 1e-6 of 0 or 1, one of the two terms vanishes; drop it
        # explicitly so that log(0) is never evaluated.
        if x < 1e-6:
            return (1-x) * np.log((1-x)/(1-y))
        if x > 1-1e-6:
            return x * np.log(x/y)
        return x * np.log(x/y) + (1-x) * np.log((1-x)/(1-y))

    def TS(self, T):
        """
        Thompson Sampling with a Beta(Sa+1, Na-Sa+1) sample as index
        :param T: int, time horizon passed to Index_Policy
        :return: whatever Index_Policy returns (defined outside this class)
        """
        def f(x):
            return np.random.beta(x.Sa+1, x.Na-x.Sa+1)
        return self.Index_Policy(T, f)

    def TS_star(self, T):
        """
        Thompson Sampling variant that duels a leader arm against Beta samples
        :param T: int, number of rounds
        :return: Tracker2 instance updated with the T rounds played
        """
        tr = Tracker2(self.means, T, store_rewards_arm=True)
        for t in range(T):
            if t < self.nb_arms:
                # Initial rounds: arm t is played at round t (t % nb_arms == t here)
                arm = t
            else:
                # Leader: arm with the highest empirical mean, selected by rd_argmax
                emp_means = tr.Sa/tr.Na
                ref_arm = rd_argmax(emp_means)
                ref_mean = emp_means[ref_arm]
                # One Beta(Sa+1, Na-Sa+1) sample per arm
                samples = np.random.beta(tr.Sa + 1, tr.Na - tr.Sa + 1)
                # The leader is always a candidate; any other arm joins when its
                # sample reaches the leader's empirical mean
                candidates = [ref_arm] + [k for k in range(self.nb_arms)
                                          if k != ref_arm and samples[k] >= ref_mean]
                arm = np.random.choice(candidates)
            reward = self.MAB[arm].sample()[0]
            tr.update(t, arm, reward)
        return tr

    def MED(self, T):
        """
        MED policy using the Bernoulli KL between an arm and the best empirical mean
        :param T: int, time horizon passed to Generic_MED
        :return: whatever Generic_MED returns (defined outside this class)
        """
        def div(k, tr):
            # Divergence of arm k's empirical mean from the best empirical mean
            best_emp = np.max(tr.Sa/tr.Na)
            return self.kl(tr.Sa[k]/tr.Na[k], best_emp)
        return self.Generic_MED(T, divergence=div)

    def kl_ucb(self, T, f):
        """
        KL-UCB policy: the index of arm k is the root y of
        kl(mean_k, y) = f(t)/Na[k], searched in [mean_k - 1e-7, 1 - 1e-10]
        :param T: int, time horizon passed to Index_Policy
        :param f: function, called with x.t; f(x.t)/Na[k] is the KL threshold of arm k
            (assumed to be a pure function of t, it is evaluated once per arm)
        :return: whatever Index_Policy returns (defined outside this class)
        """
        def index_func(x):
            res = []
            for k in range(self.nb_arms):
                mean = x.Sa[k]/x.Na[k]
                if mean < 1e-6 or mean > 1-1e-6:
                    # Empirical mean within 1e-6 of 0 or 1: no root search, index set to 1
                    res.append(1)
                else:
                    threshold = f(x.t)/x.Na[k]

                    def kl_shift(y, mean=mean, threshold=threshold):
                        return self.kl(mean, y) - threshold
                    # NOTE: brentq raises ValueError when kl_shift has the same
                    # sign at both ends of the bracket
                    res.append(brentq(kl_shift, mean-1e-7, 1 - 1e-10))
            return np.array(res)
        return self.Index_Policy(T, index_func)