import numpy as np
import scipy

from comab.algo.baselines import CoMABAlgo


def I(p, q):
    """Return the Bernoulli KL divergence KL(Ber(p) || Ber(q)).

    Uses the convention 0 * log(0) = 0, so the endpoint cases reduce to
    I(0, q) = -log(1 - q) and I(1, q) = -log(q).

    Parameters
    ----------
    p : float
        Mean of the first Bernoulli distribution, in [0, 1].
    q : float
        Mean of the second Bernoulli distribution, expected in (0, 1).

    Returns
    -------
    float
        The divergence (0.0 when p == q).
    """
    divergence = 0.0
    if p > 0:
        divergence += p * np.log(p / q)
    if p < 1:
        divergence += (1 - p) * np.log((1 - p) / (1 - q))
    return divergence


def F(p, s, n, c, tol):
    """Return the KL-UCB style upper confidence bound for a Bernoulli-like arm.

    Finds the q in (p, 1) solving

        s * I(p, q) = log(n) + c * log(log(n)),

    where the log(n) term is dropped when n <= 0 and the log(log(n)) term is
    dropped when n <= 1 (both logarithms are undefined there).

    Parameters
    ----------
    p : float
        Empirical mean of the arm, in [0, 1].
    s : int
        Number of samples of the arm.
    n : int
        Count driving the exploration budget.
    c : float
        Weight of the log(log(n)) term.
    tol : float
        Absolute tolerance (``xtol``) for the root finder.

    Returns
    -------
    float
        ``1`` when the arm is unsampled (``s == 0``) or ``p`` is exactly 0 or
        1; ``p`` when the exploration budget is non-positive (no q > p
        qualifies); ``1`` when the budget exceeds the divergence over the whole
        search interval; otherwise the root found by Brent's method.
    """
    if s == 0 or p == 1 or p == 0:
        return 1

    lo, hi = p + 1e-10, 1 - 1e-10
    if lo >= hi:
        # p lies within 2e-10 of 1: the bound is indistinguishable from 1.
        return 1

    # Guard each logarithm explicitly; masking with a boolean factor would
    # evaluate 0 * -inf = nan for n == 0 (log) or n == 1 (log log).
    budget = 0.0
    if n > 0:
        budget += np.log(n)
    if n > 1:
        budget += c * np.log(np.log(n))
    if budget <= 0:
        # Only q == p satisfies s * I(p, q) <= 0, and nothing does when the
        # budget is negative (e.g. n == 2 with large c): collapse to the mean.
        return p

    def gap(q):
        return budget - s * I(p, q)

    if gap(lo) <= 0:
        return p
    if gap(hi) >= 0:
        # No sign change inside the bracket: the budget exceeds the divergence
        # everywhere on [lo, hi], so the bound is effectively 1.
        return 1
    return scipy.optimize.brentq(gap, lo, hi, xtol=tol)


class OSUB(CoMABAlgo):
    """OSUB-style unimodal bandit policy for a single player (``K == 1``).

    Arms are the integers ``0..N``. After every round the policy re-elects a
    leader (the arm with the best clipped empirical mean of gain - cost) and
    then either plays the leader or the arm with the largest upper-confidence
    index among the leader and its immediate neighbours.

    The chosen arm is written to ``self.n[0]``; presumably ``n`` is the action
    buffer defined by ``CoMABAlgo`` -- verify against the base class.
    """

    def __init__(self, K, N, p, c, **kwargs):
        super().__init__(K, N, p)
        assert K == 1
        # Cumulative (gain - cost) and pull count for each arm 0..N.
        self.sum_reward = np.zeros(self._N + 1)
        self.t_k = np.zeros(self._N + 1, dtype=int)
        # Number of completed rounds in which each arm was the leader.
        self.l_k = np.zeros(self._N + 1, dtype=int)
        # Weight of the log(log(.)) exploration term passed to F.
        self.c = c
        # Start from the middle arm.
        self.leader = self._N // 2
        self.n[0] = self.leader

    @property
    def r_n(self):
        """Per-arm empirical mean of (gain - cost), clipped to [0, 1].

        Arms that were never pulled (0 / 0) report 0.
        """
        with np.errstate(divide="ignore", invalid="ignore"):
            means = self.sum_reward / self.t_k
        return np.clip(np.nan_to_num(means), 0, 1)

    def _index(self, means, k, n):
        """Clipped upper-confidence index of arm ``k``; arm 0 is pinned to 0."""
        if k == 0:
            return 0.0
        return float(np.clip(F(means[k], self.t_k[k], n, self.c, 1e-12), 0, 1))

    def ub(self, leader, t):
        """Upper-confidence index of every arm 0..N (arm 0 is always 0).

        The leader-count of ``leader`` sets the exploration budget for all
        arms. ``t`` is accepted for interface compatibility and is unused.
        """
        means = self.r_n  # evaluate the property once, not once per arm
        n = self.l_k[leader]
        return np.array([self._index(means, k, n) for k in range(self._N + 1)])

    def masked_ub(self, leader, t):
        """Upper-confidence indices restricted to ``leader``'s neighbourhood.

        The neighbourhood is ``{max(1, leader - 1), leader, min(leader + 1, N)}``;
        every other entry is 0. Only those (at most three) arms go through the
        root-finding in F.
        """
        means = self.r_n
        n = self.l_k[leader]
        indices = np.zeros(self._N + 1)
        for k in {max(1, leader - 1), leader, min(leader + 1, self._N)}:
            indices[k] = self._index(means, k, n)
        return indices

    def update(self, arms_with_observation, observed_gains, observed_costs, t):
        """Record the last round's outcome and choose the arm for the next round."""
        played = self.n[0]
        self.sum_reward[played] += observed_gains[0] - observed_costs[0]
        self.t_k[played] += 1
        self.l_k[self.leader] += 1

        # Re-elect the leader. It is played directly when its count of
        # completed leader rounds is 1 (mod 3) -- presumably the OSUB
        # (gamma + 1) rule with gamma = 2 neighbours on a line; confirm.
        # NOTE(review): l_k excludes the upcoming round, so a first-time
        # leader (l_k == 0) is never played directly and F receives n == 0;
        # confirm this matches the intended leader counting.
        self.leader = np.argmax(self.r_n)
        if (self.l_k[self.leader] - 1) % 3 == 0:
            self.n[0] = self.leader
        else:
            self.n[0] = np.argmax(self.masked_ub(self.leader, t))

