""" Packages import """
import numpy as np
from scipy.optimize import brentq
from .MAB import GenericMAB
from .utils import rollavg_bottlneck, rd_argmax, get_leader
from .tracker import Tracker2


class RadeMAB(GenericMAB):
    """
    Bandit problem whose arms are all created with the 'Rade' method.

    NOTE(review): 'Rade' presumably denotes Rademacher arms with rewards in
    {-1, +1} -- confirm against GenericMAB. The policies of this class map
    rewards / empirical means from [-1, 1] to [0, 1] with v -> (v + 1) / 2
    so that Bernoulli-style tools (Beta posterior, Bernoulli KL) can be used.
    """
    def __init__(self, mu):
        """
        Initialization
        :param mu: np.array, parameter of each arm; forwarded to GenericMAB as
            `p` (presumably the mean of each arm -- confirm against GenericMAB)
        """
        # Initialization of arms from GenericMAB: one 'Rade' arm per entry of mu
        super().__init__(methods=['Rade']*len(mu), p=mu)

    @staticmethod
    def kl(x, y):
        """
        Bernoulli Kullback-Leibler divergence between two means given on the
        [-1, 1] scale.
        :param x: float, first mean, in [-1, 1]
        :param y: float, second mean, in [-1, 1]
        :return: float, KL(B((x+1)/2), B((y+1)/2)), with the special cases
            handled below
        """
        # Map both means from [-1, 1] to Bernoulli parameters in [0, 1]
        x = (x+1)/2
        y = (y+1)/2
        if x == y:
            return 0
        elif x > 1 - 1e-6:
            # NOTE(review): returns 0 although the analytical value for x = 1
            # is -log(y); kept unchanged as it looks like a deliberate
            # shortcut -- confirm with the callers' expectations.
            return 0
        elif y == 0 or y == 1:
            # Second distribution is degenerate while the first is not
            return np.inf
        elif x < 1e-6:
            # x is (numerically) 0: drop the x * log(x / y) term, whose
            # limit is 0, to avoid evaluating log(0)
            return (1 - x) * np.log((1 - x) / (1 - y))
        return x * np.log(x / y) + (1 - x) * np.log((1 - x) / (1 - y))

    def TS(self, T):
        """
        Beta-Bernoulli TS with rescaled rewards
        :param T: Time Horizon
        :return: result of self.Index_Policy run with the sampled index
        """
        def f(x):
            # (Sa + Na) / 2 and (Na - Sa) / 2 play the role of success and
            # failure counts (presumably Sa is the per-arm reward sum and Na
            # the per-arm number of pulls -- confirm against the tracker);
            # the +1 terms correspond to a uniform Beta(1, 1) prior.
            return np.random.beta((x.Sa+x.Na)/2+1, (x.Na-x.Sa)/2+1)
        return self.Index_Policy(T, f)

    def SAMBA(self, T, alpha):
        """
        SAMBA policy: maintains a sampling distribution over the arms and
        updates it after each reward.
        Modifying the reward to get into the [0,1] range
        :param T: Time Horizon
        :param alpha: float, step size of the probability updates
        :return: Tracker2 object filled during the run
        """
        tr = Tracker2(self.means, T, store_rewards_arm=True)
        # Start from the uniform distribution over arms
        prob = np.ones(self.nb_arms)/self.nb_arms
        for t in range(T):
            arm = np.random.choice(np.arange(self.nb_arms), p=prob)
            # Rescale the sampled reward from [-1, 1] to [0, 1]; note that it
            # is the rescaled reward that is passed to the tracker
            reward = (1+self.MAB[arm].sample()[0])/2
            tr.update(t, arm, reward)
            # Leader = arm with the largest current probability
            a_star = rd_argmax(prob)
            if a_star == arm:
                # Leader was played: shrink every probability. The leader's
                # own entry is also modified here but overwritten below.
                prob = prob - alpha * prob**2 * reward / prob[a_star]
            else:
                # A non-leader was played: increase only its probability
                prob[arm] += alpha * prob[arm] * reward
            # Renormalise through the leader so that prob sums to 1
            p_sub = prob.sum() - prob[a_star]  ## at this step it might hold that prob.sum() >= 1
            prob[a_star] = 1 - p_sub
        return tr

    def UCB1(self, T, rho=1.):
        """
        UCB1 index policy on rewards rescaled to [0, 1]
        :param T: Time Horizon
        :param rho: coefficient for the upper bound
        :return: result of self.Index_Policy run with the UCB1 index
        """
        def index_func(x):
            # Empirical mean rescaled from [-1, 1] to [0, 1], consistently
            # with kl/TS/SAMBA. The previous expression, mean / 2 + 1, was
            # shifted by a constant 0.5 on every arm, which left the ranking
            # of the arms unchanged but did not lie in [0, 1].
            mu = (x.Sa / x.Na + 1) / 2
            return mu + rho * np.sqrt(2 * np.log(x.t + 1) / x.Na)
        return self.Index_Policy(T, index_func)

    def MED(self, T):
        """
        MED policy using the rescaled Bernoulli KL divergence of this class
        :param T: Time Horizon
        :return: result of self.Generic_MED run with that divergence
        """
        def div(k, tr):
            # Divergence between arm k's empirical mean and the best
            # empirical mean; both are on the [-1, 1] scale and rescaled
            # inside self.kl
            best_emp = np.max(tr.Sa/tr.Na)
            return self.kl(tr.Sa[k]/tr.Na[k], best_emp)
        return self.Generic_MED(T, divergence=div)


