import numpy as np

from comab.reward_estimator.reward_estimator import RewardEstimator, F_hat_combi


class WithFixedGridEstimation(RewardEstimator):
    """Reward estimator backed by threshold counts on a fixed grid.

    The interval up to ``R`` is cut into ``D`` thresholds
    ``(u + 1) * R / D`` for ``u = 0 .. D - 1``.  For every arm and every
    slot of the second axis the estimator stores how many observed gains
    were at or below each threshold, plus how many observations were
    recorded in total.
    """

    def __init__(self, K, N, p, D, R, **kwargs):
        """Set up the grid parameters and zeroed count tables.

        ``K``, ``N``, ``p`` and any extra keyword arguments are forwarded to
        the parent constructor, which is expected to provide ``self._K``,
        ``self._N``, ``self._p`` and ``self._P`` (defined outside this file).

        :param D: number of values in the Riemann integral approximation.
        :param R: max reward (upper bound of the support of F).
        """
        super().__init__(K, N, p, **kwargs)
        self._D = D  # number of values in riemann integral approximation
        self._R = R  # max reward (upper-bound of support of F)
        n_slots = self._N + self._P + 1
        # _counts[k, i, u]: number of recorded gains <= (u + 1) * R / D.
        self._counts = np.zeros((self._K, n_slots, self._D))
        # _t[k, i]: number of samples (mk in overleaf).
        self._t = np.zeros((self._K, n_slots), dtype=int)

    def update_estimator(self, n, arms_with_observation, observed_gains, observed_costs, t):
        """Record one round of observations in the count tables.

        For every arm ``k`` the slot ``n + self._p`` (taken per arm) is
        updated: each threshold at or above the observed gain is incremented
        by ``arms_with_observation[k]``, and so is the sample counter.

        :param n: offset added to ``self._p`` to select the slot of each arm.
        :param arms_with_observation: per-arm weight/indicator of an
            observation -- presumably a length-``K`` 0/1 array; confirm with
            the caller.
        :param observed_gains: per-arm observed gains, compared to the grid.
        :param observed_costs: not used by this estimator.
        :param t: not used by this estimator.
        """
        arm_index = np.arange(self._K)
        slot_index = n + self._p
        thresholds = (np.arange(self._D) + 1) * self._R / self._D
        at_or_below = observed_gains[..., np.newaxis] <= thresholds
        self._counts[arm_index, slot_index, :] += arms_with_observation[..., np.newaxis] * at_or_below
        self._t[arm_index, slot_index] += arms_with_observation

    def r_hat(self, **kwargs):
        """Return the grid-averaged reward estimate, shape ``(K, N + P)``.

        For each grid index the counts are divided by the sample counts and
        passed, together with the sample counts, to ``F_hat_combi``.  The
        positive part of the difference between consecutive columns of its
        result is accumulated, and the total is divided by ``D``.

        NOTE(review): slots whose sample count is zero produce NaN in the
        division; whether ``F_hat_combi`` tolerates that is not visible here.
        """
        accumulated = np.zeros((self._K, self._N + self._P))
        for grid_idx in range(self._D):
            empirical_cdf = self._counts[..., grid_idx] / self._t
            combined = F_hat_combi(empirical_cdf, self._t)
            step_down = combined[:, :-1] - combined[:, 1:]
            accumulated += np.maximum(step_down, 0)
        return accumulated / self._D


class WithFixedGridEstimationK(WithFixedGridEstimation):
    """Fixed-grid estimator variant that extrapolates from a single slot.

    ``r_hat`` reads the empirical CDF stored for one slot and derives the
    values for all other slots by raising it to a ratio of slot indices.
    It only supports a single arm (``K == 1``).
    """

    def __init__(self, K, N, p, D, R, **kwargs):
        """Forward all arguments unchanged to the parent constructor."""
        super().__init__(K, N, p, D, R, **kwargs)

    def r_hat(self, k, **kwargs):
        """Return the reward estimate extrapolated from slot ``k``.

        The empirical CDF of slot ``k + p`` (with ``p = self._p[0]``) is
        raised to the exponents ``(j + p) / (k + p)`` for ``j = 0 .. N``.
        The positive part of the difference between consecutive exponents is
        summed over the grid and divided by ``D``.

        :param k: slot offset whose counts are used as the reference.
        :returns: array of shape ``(1, N + P)``; entries before index ``p``
            stay zero.
        :raises AssertionError: if the estimator tracks more than one arm.
        """
        assert self._K == 1
        reward = np.zeros(self._N + self._P)
        p = self._p[0]
        # A slot without samples gives 0/0; nan_to_num maps that to 0 on
        # purpose, so keep the division from warning (or raising when NumPy
        # is configured to raise on floating-point errors).
        with np.errstate(divide="ignore", invalid="ignore"):
            _F_hat = np.nan_to_num(self._counts[0, k + p, :] / self._t[0, k + p, np.newaxis])
        exponents = (np.arange(self._N + 1, dtype=int) + p) / (k + p)
        _F_hat_powers = np.power(_F_hat, exponents[:, np.newaxis])
        reward[p:] = np.sum(np.maximum(_F_hat_powers[:-1, :] - _F_hat_powers[1:, :], 0), axis=-1)
        return (reward / self._D)[np.newaxis, :]

    def F_hat_slide(self, x, i):
        """Return the empirical CDF of slot ``i`` at the grid cell holding ``x``.

        ``x`` is mapped to grid index ``int(x * D)``.  The stored count for
        index ``u`` covers gains up to ``(u + 1) * R / D``, so the value is
        read at the upper edge of the cell.

        NOTE(review): the mapping does not involve ``R``, so ``x`` is
        presumably expressed on a ``[0, 1]`` scale -- confirm with callers.

        :param x: scalar evaluation point, assumed non-negative.
        :param i: slot index along the second axis of the count tables.
        :returns: counts divided by sample counts, one value per arm (NaN
            where no sample has been recorded).
        """
        # x at or beyond the top of the grid (e.g. x == 1.0) would give index
        # D, one past the last cell; use the last cell instead.
        u = min(int(x * self._D), self._D - 1)
        return self._counts[..., i, u] / self._t[..., i]
