import numpy as np

from comab.algo.baselines import CoMABAlgo
from comab.algo.comab_estimation_tools import neighborhood
from comab.environment import r_n, e_i


def local_greedy_allocation(R, p, N, n_t):
    """Greedily grow an allocation inside the neighborhood of ``n_t``.

    Starting from the lower bound ``lb`` returned by ``neighborhood(n_t, p, N)``,
    one unit at a time is added to the arm whose increment gives the largest
    value of ``r_n``. Arms that have reached their upper bound ``ub`` are never
    chosen. The loop stops when any of the following holds:

    * ``N`` units have been allocated;
    * every arm is at its upper bound;
    * the best single-unit increment would strictly decrease ``r_n``.

    Args:
        R: 2-D array whose first dimension is the number of arms ``K``
            (presumably the estimated reward table -- TODO confirm with caller).
        p: Parameter forwarded unchanged to ``neighborhood`` and ``r_n``.
        N: Total number of units that may be allocated.
        n_t: Current allocation; it defines the neighborhood to search.

    Returns:
        The new allocation, as a fresh array (never the one returned by
        ``neighborhood``).
    """
    K, _ = R.shape
    lb, ub = neighborhood(n_t, p, N)
    # Copy so the in-place increments below never mutate the array handed
    # back by neighborhood() (which might be shared with the caller).
    n = np.array(lb, copy=True)

    while np.sum(n) < N:  # there are still some available units
        at_cap = n >= ub
        if np.all(at_cap):
            break  # no arm may receive another unit

        # Compare the effect of adding one unit to each arm.
        tentative_rewards = np.array(
            [r_n(R, n + e_i(k, K), p) for k in range(K)], dtype=float)
        # Capped arms must never win the argmax. Masking with 0 (as before)
        # let them win whenever the other candidates were negative.
        tentative_rewards[at_cap] = -np.inf
        k = np.argmax(tentative_rewards)  # choose the best arm to add to

        # Stop only if adding a unit strictly degrades the reward. A NaN
        # comparison is False, so a NaN reward does not stop the loop, and
        # np.argmax then selects the first NaN entry.
        if np.max(tentative_rewards) < r_n(R, n, p):
            break
        n += e_i(k, K)
    return n


class LocalGreedy(CoMABAlgo):
    """Bandit policy that periodically re-plans with ``local_greedy_allocation``.

    Arm 0 starts with ``int(2 * N / 3)`` units. Each accepted update refreshes
    the reward estimator. Once the current arm-0 level has been visited often
    enough relative to ``t``, the allocation is recomputed greedily from the
    estimated rewards.
    """

    def __init__(self, K, N, p, reward_estimator):
        super().__init__(K, N, p)
        self._reward_estimator = reward_estimator
        # Initial allocation: roughly two thirds of the N units on arm 0.
        self.n[0] = int(2 * self._N / 3)
        # t_k[m] counts the updates processed while arm 0 held m units
        # (m ranges over 0..N, hence N + 1 slots).
        self.t_k = np.zeros(self._N + 1, dtype=int)

    def update(self, arms_with_observation, observed_gains, observed_costs, t):
        """Feed one round of observations and possibly re-plan the allocation.

        The round is used only if arm 0 was observed or arm 0 currently holds
        no units; otherwise it is ignored.
        """
        if not (arms_with_observation[0] or self.n[0] == 0):
            return

        self._reward_estimator.update_estimator(
            self.n, arms_with_observation, observed_gains, observed_costs, t)

        level = self.n[0]
        self.t_k[level] += 1
        # Re-plan once t_k[level] > alpha * t with alpha = 1 / (2 * log(N)).
        # The theory suggests alpha = 1 / log(N), but that did not work well
        # in practice.
        if self.t_k[level] * 2 * np.log(self._N) > t:
            estimated_rewards = self._reward_estimator.r_hat(k=self.n)
            self.n = local_greedy_allocation(
                estimated_rewards, self._p, self._N, self.n)

