""" Packages import """

from .MAB import GenericMAB
import numpy as np
from scipy.optimize import minimize_scalar, root_scalar


class BetaMAB(GenericMAB):
    """
    Beta Bandit Problem

    Multi-armed bandit whose arms are all declared with the 'Beta' method
    of GenericMAB.
    """
    def __init__(self, p):
        """
        Initialization
        :param p: np array, true values of parameters for each arm
        """
        # Initialization of arms from GenericMAB: one 'Beta' arm per entry
        # of p.
        super().__init__(methods=['Beta']*len(p), p=p)
        # Parameters used for stop learning policy
        self.best_arm = self.get_best_arm()

    def get_best_arm(self):
        """
        Returns which arm is best
        :return: index of the arm with the largest mean; when several arms
            tie, the smallest such index is returned
        """
        # NOTE(review): self.means is not set in this class -- presumably
        # GenericMAB.__init__ fills it in; confirm against the base class.
        # np.argmax over the means yields the arm index directly (first
        # maximiser on ties), rather than a position inside the tie list.
        return np.argmax(self.means)

    def MED(self, T, custom_optim=True):
        """
        Implementation of MED with the non-parametric KL-inf for bounded distributions
        :param T: forwarded unchanged to Generic_MED (presumably the time
            horizon -- confirm against GenericMAB)
        :param custom_optim: bool, if True first try the boundary check /
            root finding shortcut and only fall back to minimize_scalar
            when the root finder does not converge; if False always use
            minimize_scalar
        :return: whatever Generic_MED returns
        """
        def div(k, tr):
            """
            Empirical KL-inf of arm k against the best empirical mean:
            the maximum over lambda in [0, 1/(1-mu_star)] of
            mean(log(1 - (X - mu_star) * lambda)).
            :param k: index of the arm whose rewards are evaluated
            :param tr: tracker object exposing Sa, Na and rewards_arm
            :return: the divergence, or np.inf if the optimizer failed
            """
            # presumably Sa / Na are per-arm cumulative rewards and pull
            # counts, so this is the best empirical mean -- TODO confirm
            mu_star = np.max(tr.Sa/tr.Na)
            X = np.array(tr.rewards_arm[k])

            def log_mean(l):
                # concave dual objective E[log(1-(X-mu^*)*lambda)]
                return np.mean(np.log(1 - (X - mu_star) * l))

            # Upper end of the feasible interval for lambda. When
            # mu_star == 1 the bound 1/(1-mu_star) is infinite, so a large
            # finite value is used instead. Computed once so that both the
            # custom path and the minimize_scalar fallback share the same
            # guarded, finite bound.
            l_plus = 1e12 if mu_star == 1 else 1 / (1 - mu_star)

            # Faster optimization: many times, the maximum of the concave
            # dual objective is attained on the boundary 0 or 1/(B-mu).
            # ~x2 speedup on some bandit instances.
            # If problem, fall back to standard minimize_scalar.
            fallback = False
            if custom_optim:
                def jac(l):
                    # derivative of log_mean with respect to lambda
                    return -np.mean((X - mu_star) / (1 - (X - mu_star) * l))

                if jac(0) * jac(l_plus) >= 0:
                    # No sign change of the derivative on the interval:
                    # the maximum sits on one of the two boundaries.
                    kinf = np.maximum(log_mean(0), log_mean(l_plus))
                else:
                    # Interior stationary point: locate the root of the
                    # derivative inside the bracket.
                    ret = root_scalar(
                        jac, method='brentq', bracket=[0, l_plus]
                    )
                    if ret.converged:
                        kinf = np.max(
                            [log_mean(ret.root), log_mean(0), log_mean(l_plus)]
                        )
                    else:
                        fallback = True
            if not custom_optim or fallback:
                # minimize -E[log(1-(X-mu^*)*lambda)]
                def neg_log_mean(l):
                    return -log_mean(l)

                ret = minimize_scalar(
                    neg_log_mean, method='bounded', bounds=(0, l_plus)
                )
                if ret.success:
                    kinf = -ret.fun
                else:
                    # if error, just make this arm not eligible this turn
                    kinf = np.inf
            return kinf
        return self.Generic_MED(T, divergence=div, store_rewards_arm=True)

