from copy import deepcopy
from functools import lru_cache

import numpy as np

from comab.algo.baselines import CoMABAlgo
from comab.algo.comab_estimation_tools import grid, v_lr_all, B_mask
from comab.environment import r_n
from comab.reward_estimator.fixed_grid_estimator import WithFixedGridEstimationK


def C_t_fn(UCB, S, N, l_hat, L_t):
    """Select the entries of ``S`` connected to ``l_hat`` by indices that clear ``L_t``.

    An entry ``k`` of ``S`` is kept when ``UCB[i] >= L_t`` holds for every
    index ``i`` in the closed interval between ``k`` and ``l_hat``.

    Args:
        UCB: array of upper bounds, compared element-wise against ``L_t``.
        S: iterable of candidate indices into ``UCB``.
        N: unused; kept so the call signature stays unchanged.
        l_hat: reference index every candidate is compared against.
        L_t: threshold the upper bounds have to reach.

    Returns:
        List of the retained entries of ``S``, in iteration order.
    """
    clears_threshold = UCB >= L_t

    def _segment_clears(k):
        # Orient the interval so it can be sliced regardless of which side
        # of l_hat the candidate sits on.
        lo, hi = (k, l_hat) if k <= l_hat else (l_hat, k)
        return np.all(clears_threshold[lo:hi + 1])

    return [k for k in S if _segment_clears(k)]

def beta_minus(n, p, D, k):
    """Evaluate the ``beta_minus`` coefficient.

    Computes ``4 * sqrt(3) * (n / (2(n + p) - (k + p)) + n / (2D)
    + 2 ** ((n + p - 1) / (k + p) - 2) * n / (k + p))``.

    Args:
        n, p, D, k: scalars or NumPy arrays that broadcast against each other.

    Returns:
        The coefficient, with the broadcast shape of the inputs.
    """
    shifted_k = k + p
    gap_term = n / (2 * (n + p) - shifted_k)
    resolution_term = n / (2 * D)
    power_term = np.power(2, ((n + p - 1) / shifted_k) - 2) * (n / shifted_k)
    return 4 * np.sqrt(3) * (gap_term + resolution_term + power_term)


def beta_plus(n, p, D, k):
    """Evaluate the ``beta_plus`` coefficient.

    Computes ``6 * sqrt(3) * (n / (2(n + p) - (k + p)) + n / (2D)
    + n / (3(k + p)))``.

    Args:
        n, p, D, k: scalars or NumPy arrays that broadcast against each other.

    Returns:
        The coefficient, with the broadcast shape of the inputs.
    """
    shifted_k = k + p
    gap_term = n / (2 * (n + p) - shifted_k)
    resolution_term = n / (2 * D)
    third_term = n / (3 * shifted_k)
    return 6 * np.sqrt(3) * ((gap_term + resolution_term) + third_term)




class GreedyGrid(CoMABAlgo):
    """Selection rule driven by play counts and lower/upper confidence bounds.

    The instance keeps, for indices ``0..N``:

    * ``_m_nt[n, t]`` -- how many counted rounds up to internal round ``t``
      had ``self.n[0] == n`` when they were recorded;
    * ``l_n`` / ``k_n`` -- per-index reference indices chosen by ``argmax``
      over masked play counts (see ``update_l_n`` / ``update_k_n``);
    * ``D`` -- per-index integer array, recomputed in ``update``.

    ``update`` advances the internal round counter ``_t`` and writes the next
    choice into ``self.n[0]``.
    """

    def __init__(self, K, N, p, T, delta, D, R, alpha, **kwargs):
        """Set up counters, the grid helpers and the reward estimator.

        Args:
            K, N, p: forwarded to ``CoMABAlgo.__init__``; ``K``, ``N``, ``p``
                are also forwarded to the reward estimator.
            T: horizon; sizes the play-count history.
            delta: sequence indexed by the internal round (``delta[self._t]``).
            D: per-index array; forwarded to the estimator and used until
                ``update`` recomputes ``self.D``.
            R: forwarded to the reward estimator.
            alpha: factor in the ``m_nt >= alpha * t`` test in ``update``.
            **kwargs: accepted and ignored.
        """
        super().__init__(K, N, p)
        self.alpha = alpha
        self.delta = delta
        self.reward_estimator = WithFixedGridEstimationK(K, N, p, D, R)
        self.T = T
        self._t = 0
        self._S = grid(self._N, p)
        # ``update`` increments ``_t`` *before* writing column ``_t``, so T
        # counted rounds touch columns 0..T; allocate T + 1 columns to keep
        # the last round in bounds.
        self._m_nt = np.zeros((self._N + 1, self.T + 1), dtype=int)
        self.l_n = np.zeros(self._N + 1, dtype=int)
        self.k_n = np.zeros(self._N + 1, dtype=int)
        self.v_l, self.v_r = v_lr_all(self._S, self._N)
        self.B_n = B_mask(self.v_l, self._N, self._S)
        self.n[0] = self.v_l[self._N // 2]
        self.D = D
        # Single-entry memo for ``_r_n``: (round it was computed for, value).
        self._r_n_cache = (None, None)

    def update_m(self, arms_with_observation, observed_gains, observed_costs, t):
        """Carry the play counts forward to column ``t`` and count ``self.n[0]``.

        Only ``t`` is used; the other parameters are accepted for signature
        compatibility.

        Returns:
            The full ``_m_nt`` history array.
        """
        counts = self._m_nt
        counts[:, t] = counts[:, t - 1]
        counts[self.n[0], t] += 1
        return counts

    def update_l_n(self):
        """Recompute ``l_n`` as the per-row argmax of masked recent counts.

        The mask is ``B_n`` with the ``v_l`` and ``v_r`` column of every row
        forced on; the weights are ``m_tilde_nt``.

        Returns:
            The updated ``l_n`` array.
        """
        rows = np.arange(self._N + 1)
        mask = deepcopy(self.B_n)
        mask[rows, self.v_l] = True
        mask[rows, self.v_r] = True
        self.l_n = np.argmax(mask * self.m_tilde_nt, axis=1)
        # NOTE(review): presumably the unchecked relation
        #   l_n + p[0] <= 3 * (n + p[0] - 1) / 2
        # is expected to hold here -- confirm.
        return self.l_n

    def update_k_n(self):
        """Recompute ``k_n`` as the per-row argmax of masked total counts.

        Row ``n`` of the mask (indexed ``mask[n, k]``) enables ``n`` itself,
        ``v_l[n]``, ``v_r[n]`` and ``l_n[n]``; entry ``[0, 0]`` is then
        switched off. The weights are ``m_nt``.

        Returns:
            The updated ``k_n`` array.
        """
        rows = np.arange(self._N + 1)
        mask = np.eye(self._N + 1, dtype=bool)
        for columns in (self.v_l, self.v_r, self.l_n):
            mask[rows, columns] = True
        mask[0, 0] = False
        self.k_n = np.argmax(mask * self.m_nt, axis=1)
        # NOTE(review): presumably the unchecked relation
        #   k_n + p[0] <= 3 * (n + p[0] - 1) / 2
        # is expected to hold here -- confirm.
        return self.k_n

    @property
    def r_n(self):
        """Reward estimates for indices ``0..N`` at the current internal round."""
        return self._r_n(self._t)

    def _r_n(self, t):
        """Compute (or reuse) the reward estimates for round ``t``.

        The value is memoised per instance for the round it was first
        requested in, so repeated reads within one round are cheap. A
        per-instance memo is used instead of ``functools.lru_cache`` because
        caching a bound method keeps every instance alive in a shared cache.
        """
        assert t == self._t
        cached_t, estimates = self._r_n_cache
        if cached_t != t:
            r_hat = self.reward_estimator.r_hat(self.k_n)
            estimates = np.array(
                [r_n(r_hat, np.array([n]), self._p) for n in range(self._N + 1)]
            )
            self._r_n_cache = (t, estimates)
        return estimates

    @property
    def m_nt(self):
        """Play counts at the current internal round (column ``_t`` of ``_m_nt``)."""
        return self._m_nt[:, self._t]

    @property
    def m_tilde_nt(self):
        """Play counts accumulated since round ``_t // 2``."""
        return self._m_nt[:, self._t] - self._m_nt[:, self._t // 2]

    @property
    def d_hat(self):
        """Per-index first ``u`` at which ``F_hat_slide`` reaches its threshold.

        For every ``n`` with ``D[n] >= 1`` the result is the ``argmax`` over
        ``u in range(D[n])`` of the test
        ``F_hat_slide(u / D[n], k_n[n] + p[0]) >= 8 log(2 (D[n] + 1) / delta[t]) / m_nt[n]``
        (0 when no ``u`` passes). Entries with ``D[n] < 1`` stay at 1.
        """
        p0 = self._p[0]
        counts = self.m_nt
        result = np.ones(self._N + 1)
        for n in range(self._N + 1):
            D_n = self.D[n]
            if D_n >= 1:
                # The threshold does not depend on u, so compute it once.
                # NOTE(review): this divides by m_nt[n] whereas the bounds
                # below use m_nt[k_n] -- confirm which is intended.
                threshold = 8 * np.log(2 * (D_n + 1) / self.delta[self._t]) / counts[n]
                passes = [
                    self.reward_estimator.F_hat_slide(u / D_n, self.k_n[n] + p0) >= threshold
                    for u in range(D_n)
                ]
                result[n] = np.argmax(np.array(passes))
        return result

    def _xi(self, base):
        """Return ``base ** ((n + p - 1) / (k_n + p)) * (d_hat + 1) / D``."""
        n = np.arange(self._N + 1)
        p = self._p[0]
        return np.power(base, (n + p - 1) / (self.k_n + p)) * (self.d_hat + 1) / (self.D)

    @property
    def xi_plus(self):
        """``xi`` factor with base 8 (used by ``_LCB``)."""
        return self._xi(8)

    @property
    def xi_minus(self):
        """``xi`` factor with base 4 (used by ``_UCB``)."""
        return self._xi(4)

    @property
    def LCB(self):
        """Element-wise maximum of the two lower bounds."""
        return np.maximum(self._LCB, self._LCB_on_grid)

    @property
    def UCB(self):
        """Element-wise minimum of the two upper bounds."""
        return np.minimum(self._UCB, self._UCB_on_grid)

    def _on_grid_mask(self):
        """Boolean array over ``0..N`` that is True exactly at the indices in ``_S``."""
        on_grid = np.zeros(self._N + 1, dtype=bool)
        on_grid[self._S] = True
        return on_grid

    def _count_radius(self):
        """Return ``sqrt(2 log(t) / m_nt)`` for the current round."""
        return np.sqrt(2 * np.log(self._t) / self.m_nt)

    @property
    def _LCB_on_grid(self):
        """``r_n - sqrt(2 log(t) / m_nt)`` on grid indices, ``-inf`` elsewhere."""
        lcb = np.where(self._on_grid_mask(), 0.0, -np.inf)
        lcb += self.r_n
        lcb -= self._count_radius()
        return lcb

    @property
    def _UCB_on_grid(self):
        """``r_n + sqrt(2 log(t) / m_nt)`` on grid indices, ``+inf`` elsewhere."""
        ucb = np.where(self._on_grid_mask(), 0.0, np.inf)
        ucb += self.r_n
        ucb += self._count_radius()
        return ucb

    def _bound_terms(self, beta_fn, xi):
        """Return the three width terms shared by ``_LCB`` and ``_UCB``.

        Args:
            beta_fn: ``beta_plus`` or ``beta_minus``.
            xi: matching ``xi_plus`` / ``xi_minus`` array.

        Returns:
            Tuple ``(1 / sqrt(m), beta * sqrt(A), n * xi * A ** e)`` where
            ``m = m_nt[k_n]``, ``A = log(2 D / delta[t]) / m`` and
            ``e = (n + p - 1) / (k_n + p)``. ``D`` is a copy of ``self.D``
            with entry 0 set to 1, and ``beta_fn`` receives ``D - 1``.
        """
        p = self._p[0]
        n = np.arange(self._N + 1)
        D = deepcopy(self.D)
        D[0] = 1
        plays = self.m_nt[self.k_n]
        A = np.log(2 * D / self.delta[self._t]) / plays
        beta = beta_fn(n, p, D - 1, self.k_n)
        return (
            1 / np.sqrt(plays),
            beta * np.sqrt(A),
            n * xi * np.power(A, (n + p - 1) / (self.k_n + p)),
        )

    @property
    def _LCB(self):
        """Lower bound ``r_n`` minus the three width terms; index 0 is pinned to 0.

        NaN entries are replaced by ``-inf`` via ``np.nan_to_num``.
        """
        count_term, beta_term, xi_term = self._bound_terms(beta_plus, self.xi_plus)
        result = self.r_n - count_term - beta_term - xi_term
        result[0] = 0
        return np.nan_to_num(result, nan=-np.inf)

    @property
    def _UCB(self):
        """Upper bound ``r_n`` plus the three width terms; index 0 is pinned to 0.

        NaN entries are replaced by ``+inf`` via ``np.nan_to_num``.
        """
        count_term, beta_term, xi_term = self._bound_terms(beta_minus, self.xi_minus)
        result = self.r_n + count_term + beta_term + xi_term
        result[0] = 0
        return np.nan_to_num(result, nan=np.inf)

    def update(self, arms_with_observation, observed_gains, observed_costs, t):
        """Process one round and, when it counts, choose the next ``self.n[0]``.

        A round counts when ``arms_with_observation[0]`` is truthy or the
        current choice ``self.n[0]`` is 0. For a counted round the internal
        round ``_t`` is advanced, the estimator, play counts, ``l_n``, ``k_n``
        and ``D`` are refreshed, and a new choice is drawn if the current one
        is not flagged by ``B_n[l_hat, current]`` or has been played at least
        ``alpha * _t`` times.

        Args:
            arms_with_observation, observed_gains, observed_costs: forwarded
                to the reward estimator.
            t: unused; the internal counter ``self._t`` is used instead.
        """
        if arms_with_observation[0] or self.n[0] == 0:
            self._t += 1

            self.reward_estimator.update_estimator(
                self.n, arms_with_observation, observed_gains, observed_costs, self._t
            )
            self.update_m(arms_with_observation, observed_gains, observed_costs, self._t)
            self.update_l_n()
            self.update_k_n()
            self.D = np.ceil(np.arange(self._N + 1) * np.sqrt(self.m_nt[self.k_n])).astype(int)

            # LCB is expensive (it walks d_hat); nothing changes between its
            # two uses, so evaluate it once.
            lcb = self.LCB
            l_hat = np.argmax(lcb)
            nm1 = self.n[0]
            if not self.B_n[l_hat, nm1] or (self.m_nt[nm1] >= self.alpha * self._t):
                L_t = np.max(lcb)
                C_t = C_t_fn(self.UCB, self._S, self._N, l_hat, L_t)
                if C_t:
                    # TODO: change to round robin
                    self.n[0] = np.random.choice(C_t)
                else:
                    # NOTE(review): this multiplies the B_n[l_hat] row by the
                    # single estimate r_n[l_hat], so all flagged indices tie
                    # when that estimate is positive (and unflagged ones win
                    # when it is negative) -- confirm this is intended rather
                    # than masking the full r_n vector.
                    r_hat_next_to_l_hat = self.B_n[l_hat] * self.r_n[l_hat]
                    self.n[0] = np.random.choice(
                        np.flatnonzero(r_hat_next_to_l_hat == np.max(r_hat_next_to_l_hat))
                    )