import numpy as np
from src import utils as util
from src.core.bandit import BanditAlgorithm


class Scheduler:
    """Randomly drawn replay schedule for a single block.

    A block with index ``block_index`` spans ``E * 2**block_index`` time
    steps.  For every level ``j`` with ``j_star <= j < block_index`` the block
    is cut into ``2**(block_index - j)`` consecutive intervals of length
    ``E * 2**j``; each interval is independently flagged with probability
    ``sqrt(2**(j - block_index)) * prob_scale`` (flagged with certainty once
    that value exceeds 1).

    Attributes:
        block_index: index of the block this schedule belongs to.
        E: base interval length (number of steps of a level-0 interval).
        block_test: when true, the end of the whole block is also reported by
            ``get_intervals_ending_at`` as interval ``(block_index, 0)``.
        schedule: dict mapping level ``j`` to a boolean array of flags, one
            per level-``j`` interval.  Levels below ``j_star`` have no entry.
    """

    def __init__(self, block_index, E, j_star=0, prob_scale=1, block_test=False):
        self.block_index = block_index
        self.E = E
        self.block_test = block_test
        self.schedule = {}
        # Levels are visited in increasing order so the global NumPy random
        # stream is consumed in a fixed, reproducible sequence.
        for level in (j for j in range(block_index) if j >= j_star):
            threshold = prob_scale * np.sqrt(2 ** (level - block_index))
            interval_count = 2 ** (block_index - level)
            self.schedule[level] = np.random.rand(interval_count) < threshold

    def get_index(self, t):
        """Return the lowest level whose interval containing ``t`` is flagged.

        ``t`` is a step offset relative to the start of the block.  When no
        level is flagged at ``t`` (or the block index is 0) the block index
        itself is returned.
        """
        if self.block_index == 0:
            return 0
        assert 0 <= t < self.E * (2 ** self.block_index)
        scheduled_levels = (
            lvl for lvl in range(self.block_index) if lvl in self.schedule
        )
        for lvl in scheduled_levels:
            interval = int(t / self.E / (2 ** lvl))
            if self.schedule[lvl][interval]:
                return lvl
        return self.block_index

    def get_intervals_ending_at(self, t):
        """List the flagged intervals whose final step is ``t``.

        Returns a list of ``(level, interval_index)`` tuples ordered by
        increasing level.  With ``block_test`` enabled, ``(block_index, 0)``
        is appended when ``t`` is the last step of the block.
        """
        steps_done = t + 1
        finished = []
        for lvl in range(self.block_index):
            if lvl not in self.schedule:
                continue
            if steps_done % (self.E * (2 ** lvl)) != 0:
                # t is not the last step of a level-lvl interval.
                continue
            interval = int(steps_done / self.E / (2 ** lvl)) - 1
            if self.schedule[lvl][interval]:
                finished.append((lvl, interval))
        block_length = self.E * (2 ** self.block_index)
        if self.block_test and steps_done == block_length:
            finished.append((self.block_index, 0))
        return finished


class OPKB(BanditAlgorithm):
    """Block-based kernel bandit with optional replay-driven restarts.

    Time is split into blocks; block ``m`` lasts ``E * 2**m`` rounds.  At the
    end of each block the reward gaps are re-estimated from the history of the
    current epoch and a new sampling distribution is appended to ``self.ps``.
    When ``config['change_detection']`` is enabled, a ``Scheduler`` replays
    earlier distributions inside the block, and a disagreement between a
    replay's gap estimate and a stored one restarts the epoch.

    Arm-dependent quantities (``Phi``, ``pi``, ``gamma_T``, ``E``, ``alpha``)
    are built lazily on the first ``select_arm`` call.

    NOTE(review): ``self.t`` is read but never assigned in this class; it is
    presumably maintained by ``BanditAlgorithm`` — confirm.
    """

    name = 'ADA-OPKB (Lin)'

    def __init__(self, num_actions, horizon, B):
        """Store the problem size and the fixed hyper-parameter configuration.

        Args:
            num_actions: number of arms.
            horizon: total number of rounds ``T``.
            B: only recorded in ``init_params``; not used otherwise here.
        """
        super().__init__(num_actions, horizon)
        self.T = horizon
        self.num_actions = num_actions

        # Placeholder; replaced by the real arm dimension in ``select_arm``.
        self.d = 5
        self.config = {
            'kernel': {
                'name': 'linear',
            },
            'tol': 0.0001,
            'gamma': 1,
            'beta_scale': 5,
            'mu_scale': 1 / 10,
            'e_scale': 0.1,
            'mixture_probability': 0.5,
            'epsilon_scale': 2.5,
            'j_star': 4,
            'g_scale': 0.5,
            'change_detection': False,
            'seed': None,
        }

        self.detection_count = 0
        self.set_arms = False

        self.init_params = {
            'num_actions': num_actions,
            'horizon': horizon,
            'B': B,
        }

    def _initialize(self, t):
        """Start a fresh epoch at time ``t``, discarding all learned state.

        Requires the arm-dependent attributes set up by ``select_arm``.
        """
        C = self.config

        self.tau = t                # start time of the current block
        self.epoch_start_time = t   # start time of the current epoch
        self.m = 0                  # current block index
        self.beta = self._compute_beta(0)
        self.mu = self._compute_mu(0)
        self.gram_inv = {}          # block index -> inverse design matrix
        self.deltas = {}            # block index -> estimated gaps
        self.reward_estimates = {}  # block index -> estimated rewards
        self.history = []           # (policy_index, arm, reward) per round
        self.p = self.pi
        self.ps = [self.pi]
        self.gram_inv[0] = util.s_inv(util.S(self.p, self.Phi, C['gamma'] / self.T))
        self.block_end_time = {}

        self.scheduler = Scheduler(0, self.E, j_star=int(C['j_star']))

    def _compute_beta(self, m):
        """Return ``2 * gamma_T / epsilon(m)`` scaled by ``beta_scale``."""
        return 2 * self.gamma_T / self._compute_epsilon(m) * self.config['beta_scale']

    def _compute_mu(self, m):
        """Return ``0.5 * sqrt(2**-m)`` scaled by ``mu_scale``."""
        return 0.5 * np.sqrt(1 / (2 ** m)) * self.config['mu_scale']

    def _compute_epsilon(self, m):
        """Return ``(40 + 16 * sqrt(alpha)) * 0.5 * sqrt(2**-m)``.

        Uses the unscaled ``mu`` (no ``mu_scale`` factor).
        """
        mu = 0.5 * np.sqrt(1 / (2 ** m))
        return (40 + 16 * np.sqrt(self.alpha)) * mu

    def _compute_reward_estimate(self, history, gram_inv=None):
        """Estimate the reward of every arm from ``history``.

        Args:
            history: iterable of ``(policy_index, arm, reward)`` tuples.
            gram_inv: mapping from policy index to the matrix applied to that
                policy's reward-weighted feature sum; defaults to
                ``self.gram_inv``.

        Returns:
            ``Phi @ theta_hat`` where ``theta_hat`` is the sum of the
            per-policy terms divided by ``len(history)``.
        """
        if gram_inv is None:
            gram_inv = self.gram_inv
        # Reward-weighted feature sums, grouped by the policy that was played.
        feature_sums = {}
        for policy_index, a, r in history:
            if policy_index not in feature_sums:
                feature_sums[policy_index] = np.zeros(self.num_actions)
            feature_sums[policy_index] += self.Phi[a, :] * r
        total = np.zeros(self.num_actions)
        for policy_index, z in feature_sums.items():
            total += gram_inv[policy_index].dot(z)
        theta_hat = total / len(history)
        return self.Phi @ theta_hat

    def _compute_delta(self, history, gram_inv=None):
        """Return ``(R_hat, gaps)`` with ``gaps = max(R_hat) - R_hat`` flattened."""
        R_hat = self._compute_reward_estimate(history, gram_inv)
        return R_hat, (np.max(R_hat) - R_hat).ravel()

    def _end_of_block_update(self, t, a, r):
        """Close block ``m`` at time ``t`` and prepare block ``m + 1``.

        ``a`` and ``r`` are accepted for interface compatibility but unused.
        """
        C = self.config

        # Gaps are estimated from the whole history of the current epoch.
        R_hat, delta = self._compute_delta(self.history)
        self.reward_estimates[self.m] = R_hat
        self.deltas[self.m] = delta

        self.beta = self._compute_beta(self.m + 1)
        self.mu = self._compute_mu(self.m + 1)
        x = util.OP(self.Phi, delta, self.beta, self.gamma)
        # Arms whose estimated gap falls below the scaled threshold.
        G = delta <= 2 * self.alpha * self.gamma_T / self.beta * C['beta_scale'] * C['g_scale']

        num_good = np.count_nonzero(G)
        if num_good == self.num_actions:
            pG = self.pi
        elif num_good == 1:
            pG = G.astype(float)
        else:
            Phi = self.Phi[G, :][:, G]
            p = kernel_optimal_design(Phi, C['gamma'] / self.T)
            pG = np.zeros(self.num_actions)
            pG[G] = p

        # New distribution: mix of x and pG, then mixed with pi by weight mu.
        mixture = C['mixture_probability']
        self.p = (1 - self.mu) * (x * (1 - mixture) + pG * mixture) + self.mu * self.pi
        self.ps.append(self.p)

        self.tau = t + 1
        self.m += 1
        self.gram_inv[self.m] = util.s_inv(util.S(self.p, self.Phi, self.gamma))
        self.scheduler = Scheduler(self.m, self.E, j_star=int(C['j_star']))

    def _policy_index(self, block_t):
        """Return the index into ``self.ps`` to use at block offset ``block_t``.

        With change detection the scheduler decides; otherwise it is always
        the current block index.
        """
        if self.config['change_detection']:
            return self.scheduler.get_index(block_t)
        return self.m

    def action(self, t):
        """Sample an arm index for time ``t`` from the active distribution."""
        p = self.ps[self._policy_index(t - self.tau)]
        return np.random.choice(p.size, p=p)

    def _setup_arms(self, arms):
        """Build all arm-dependent quantities and start the first epoch."""
        C = self.config

        K = util.kernel_matrix(arms, C['kernel'])
        # Small diagonal jitter is added before the Cholesky factorisation.
        self.Phi = np.linalg.cholesky(K + 1e-9 * np.eye(self.num_actions))
        self.pi = kernel_optimal_design(self.Phi, C['gamma'] / self.T)

        self.gamma = C['gamma'] / self.T

        self.information_gain = util.InformationGain(self.Phi, self.T, C['gamma'])
        self.gamma_T = self.information_gain.get_exact(self.T)
        log_term = np.log(8 * self.num_actions * self.T * np.log2(self.T) / C['tol'])
        self.E = int(np.ceil(4 * self.gamma_T * log_term * C['e_scale']))
        self.alpha = C['gamma'] / (4 * np.log(8 * self.T * np.log2(self.T) * self.num_actions / C['tol']))

        self.d = arms[0].size

        self._initialize(0)

    def select_arm(self, arms, change_points):
        """Return the arm index to play at the current time ``self.t``.

        On the first call (or the first call after ``re_init``) the
        arm-dependent state is built from ``arms``.  ``change_points`` is
        accepted but not used.
        """
        t = self.t
        if not self.set_arms:
            self.set_arms = True
            self._setup_arms(arms)
        return self.action(t)

    def update_statistics(self, arm_index, reward):
        """Record the outcome of the current round ``self.t``."""
        self._update(self.t, arm_index, reward)

    def _end_of_replay_change_detected(self, m, block_t):
        """Check every replay interval ending at ``block_t`` for a change.

        The gaps estimated on a finished replay interval are compared with the
        stored gaps of blocks ``j_star .. m - 1``.  Returns True as soon as a
        normalised discrepancy exceeds 1.
        """
        C = self.config
        assert C['change_detection']
        for j, replay_index in self.scheduler.get_intervals_ending_at(block_t):
            # Position of the replay interval inside the epoch history.
            start_index = self.tau + self.E * (2 ** j) * replay_index - self.epoch_start_time
            end_index = start_index + self.E * (2 ** j)
            replay_history = self.history[start_index:end_index]
            assert len(replay_history) == self.E * (2 ** j)
            assert all([policy_index <= j for policy_index, _, _ in replay_history])
            R_hat, delta = self._compute_delta(replay_history)
            for k in range(int(C['j_star']), m):
                epsilon = (self._compute_epsilon(k) + self._compute_epsilon(j)) / 2
                diff = max(
                    np.max(self.deltas[k] - 4 * delta),
                    np.max(delta - 4 * self.deltas[k]),
                ) / (4 * epsilon * C['epsilon_scale'])
                if diff > 1:
                    return True
        return False

    def _update(self, t, a, r):
        """Record ``(a, r)`` at time ``t`` and advance block/epoch state.

        A detected change restarts the epoch at ``t + 1``, but only while
        fewer than 100 detections have been counted; from the 100th detection
        onwards the normal end-of-block check is applied instead.
        """
        C = self.config

        block_t = t - self.tau
        self.history.append((self._policy_index(block_t), a, r))

        if C['change_detection'] and self._end_of_replay_change_detected(self.m, block_t):
            self.detection_count += 1
            if self.detection_count < 100:
                self._initialize(t + 1)
                return

        if t - self.tau + 1 >= (2 ** self.m) * self.E:
            self._end_of_block_update(t, a, r)

    def re_init(self):
        """Reset the detection counter and force arm setup on the next call."""
        self.detection_count = 0
        self.set_arms = False



def kernel_optimal_design(Phi, gamma):
    """Call ``util.OP`` on ``Phi`` with all-zero gaps and a weight of 1.

    Args:
        Phi: 2-D array; one zero gap is supplied per row.
        gamma: passed through unchanged as the last argument of ``util.OP``.

    Returns:
        Whatever ``util.OP`` returns (presumably a distribution over the rows
        of ``Phi`` — confirm against ``util.OP``).
    """
    row_count, _col_count = Phi.shape
    zero_gaps = np.zeros(row_count)
    return util.OP(Phi, zero_gaps, 1, gamma)

