import numpy as np
from src.core.bandit import BanditAlgorithm
from src import utils as util
from numba import njit

@njit(fastmath=True, cache=True)
def fast_compute_mu(Phi, gram_inv, feature_accumulator):
    """Return ``Phi @ (gram_inv @ feature_accumulator)``.

    The inner product yields a weight vector; multiplying it by ``Phi``
    produces one value per row of ``Phi``.

    Args:
        Phi: 2-D feature matrix.
        gram_inv: 2-D matrix applied to ``feature_accumulator``.
        feature_accumulator: 1-D vector.

    Returns:
        1-D array with ``Phi.shape[0]`` entries.
    """
    weights = np.dot(gram_inv, feature_accumulator)
    return np.dot(Phi, weights)

@njit(fastmath=True, cache=True)
def fast_update_all(gram_inv, feature_accumulator, current_sigma, Phi, u, reward, lambda_):
    """Apply one observation to every running statistic, in place.

    With ``v = gram_inv @ u`` and ``denom = 1 + u . v`` this routine

    * subtracts ``outer(v, v) / denom`` from ``gram_inv`` (the
      Sherman-Morrison form of a rank-one inverse update; it relies on
      ``gram_inv`` being symmetric),
    * adds ``reward * u`` to ``feature_accumulator``,
    * subtracts ``lambda_ * (Phi @ v)**2 / denom`` from ``current_sigma``
      and clamps the result at zero.

    Args:
        gram_inv: square 2-D array, modified in place.
        feature_accumulator: 1-D array of length ``gram_inv.shape[0]``,
            modified in place.
        current_sigma: 1-D array with one entry per row of ``Phi``,
            modified in place.
        Phi: 2-D feature matrix whose column count matches ``gram_inv``.
        u: 1-D feature vector of the played arm.
        reward: observed scalar reward.
        lambda_: scalar multiplier applied to the variance decrement.

    Returns:
        None. All outputs are written into the array arguments.
    """
    # v must be computed from the *old* gram_inv before it is overwritten.
    v = gram_inv @ u
    denom = 1.0 + np.dot(u, v)

    dim = gram_inv.shape[0]
    for i in range(dim):
        factor = v[i] / denom
        for j in range(dim):
            gram_inv[i, j] -= factor * v[j]

    for i in range(dim):
        feature_accumulator[i] += u[i] * reward

    # One entry per row of Phi. Iterating over len(w) instead of the feature
    # dimension keeps this correct when Phi is not square.
    w = Phi @ v
    for k in range(w.shape[0]):
        update_val = (w[k] * w[k]) / denom
        current_sigma[k] -= lambda_ * update_val
        # Floating-point error can push the value slightly below zero; the
        # caller takes a square root of it.
        if current_sigma[k] < 0.0:
            current_sigma[k] = 0.0


class GP_UCB_R(BanditAlgorithm):
    """Kernelised UCB bandit with periodic restarts ("GP-UCB-R").

    The arm set is turned into a finite feature matrix ``Phi`` (the Cholesky
    factor of the jittered kernel matrix). An inverse Gram matrix, a
    reward-weighted feature sum and a per-arm variance vector are then
    maintained incrementally by ``fast_update_all``. Every ``H`` rounds the
    statistics are wiped by ``reset``.

    ``self.t`` and ``self.T`` are read but never assigned here; they are
    expected to be provided by ``BanditAlgorithm`` (not visible in this
    file -- confirm).
    """

    def __init__(self, num_actions, horizon, xi, B, config):
        """Store hyper-parameters and create empty state.

        Args:
            num_actions: number of arms.
            horizon: number of rounds; sizes the history buffers.
            xi: stored on the instance and in ``init_params``; not otherwise
                used in this class.
            B: additive constant of the confidence width.
            config: dict of overrides for the default configuration; ``None``
                is treated as "no overrides".
        """
        super().__init__(num_actions, horizon)

        self.config = {
            'tol': 0.0001, 'window': 0, 'kernel': {'name': 'rbf', 'length_scale': 0.2},
            'lambda': 20, 'v': 0.01, 'seed': None,
        }
        if config is not None:
            self.config.update(config)

        self.horizon = horizon
        self.tol = self.config['tol']
        self.window = self.config['window']
        self.v = self.config['v']
        self.num_actions = num_actions

        self.action_history = np.full(horizon + 1, -1, dtype=int)
        self.reward_history = np.zeros(horizon + 1)
        self.set_arms = False
        self.xi = xi

        self.init_params = {
            'num_actions': num_actions, 'horizon': horizon,
            'xi': self.xi, 'B': B, 'config': self.config
        }

        self.is_change = False
        self.tau = 0
        self.B = B
        self.R = 0.01
        # NOTE(review): the regulariser is pinned to 0.01; config['lambda'] is
        # kept in self.config but is not applied. The original code assigned
        # it and then immediately overwrote it -- confirm which is intended.
        self.lambda_ = 0.01
        self.information_list = []
        self.set_information = False

        self.Phi = None
        self._cached_arms_hash = None
        self.feature_accumulator = None
        self.current_sigma = None
        self.cached_initial_sigma = None

        # Same initial state as re_init(), so reset() is safe to call before
        # the first select_arm().
        self.gram_inv = np.eye(num_actions) * self.lambda_
        self.z = np.zeros(num_actions)
        self.action_counts = np.zeros(num_actions)

    def _prepare_arms(self, arms, arms_hash, PT):
        """Build features, statistics and the confidence schedule for ``arms``."""
        num_actions = len(arms)
        self.set_arms = True

        K = util.kernel_matrix(arms, self.config['kernel'])
        # Small diagonal jitter so the Cholesky factorisation succeeds.
        K_mod = K + 1e-9 * np.eye(num_actions)
        self.Phi = np.linalg.cholesky(K_mod)
        self._cached_arms_hash = arms_hash

        self.cached_initial_sigma = self.lambda_ * np.sum(self.Phi**2, axis=1)

        first_time = self.current_sigma is None
        if first_time or self.current_sigma.shape != self.cached_initial_sigma.shape:
            # (Re)allocate whenever there is no state yet or the number of
            # arms changed; reset() can only refill same-sized buffers.
            self.current_sigma = self.cached_initial_sigma.copy()
            self.gram_inv = np.eye(num_actions) * self.lambda_
            self.feature_accumulator = np.zeros(num_actions)
            self.z = np.zeros(num_actions)
            self.action_counts = np.zeros(num_actions)
        if not first_time:
            # A new arm set invalidates everything learned so far.
            self.reset()

        self.information_gain = util.InformationGain(self.Phi, self.horizon, self.lambda_)
        self.d = arms[0].size
        # The helper's own `get` is replaced by a closed-form expression.
        self.information_gain.get = lambda t: np.log(t)**(self.d+2)
        ig_T = self.information_gain.get(self.T)

        horizon_scale = self.T / PT if PT != 0 else self.T
        restart_period = int(np.ceil(ig_T ** (1 / 4) * horizon_scale ** 0.5))
        # A period of 0 (e.g. T == 1, where log(T) == 0) would make the
        # modulo in select_arm divide by zero.
        self.H = max(1, restart_period)

        steps = np.arange(1, self.T + 1)
        ig_vals = self.information_gain.get(steps)
        self.information_list = self.B + self.R * np.sqrt(
            4 * (ig_vals + 1 + np.log(1 / self.tol))) / np.sqrt(self.lambda_)
        self.set_information = True

    def select_arm(self, arms, PT):
        """Return the index of the arm with the largest upper confidence bound.

        Args:
            arms: the arm set. Objects exposing ``.data`` (e.g. NumPy arrays)
                are fingerprinted by content; anything else by ``id``, so a
                freshly built list counts as a new arm set on every call.
            PT: divides the horizon in the restart-period formula; ``0``
                means the undivided horizon is used. (Presumably a variation
                budget -- confirm against the caller.)

        Returns:
            Index of the selected arm.
        """
        num_actions = len(arms)

        arms_hash = hash(arms.data.tobytes()) if hasattr(arms, 'data') else id(arms)

        if not self.set_arms or self._cached_arms_hash != arms_hash:
            self._prepare_arms(arms, arms_hash, PT)

        # Computed after the setup above because a reset moves self.tau.
        t = self.t - self.tau

        if t > 1:
            mu = fast_compute_mu(self.Phi, self.gram_inv, self.feature_accumulator)
        else:
            mu = np.zeros(num_actions)

        sigma = self.current_sigma

        # Past the end of the schedule, fall back to its last entry.
        idx = int(t) if t < len(self.information_list) else -1
        ucb_values = mu + self.information_list[idx] * np.sqrt(sigma)

        selected_index = np.argmax(ucb_values)
        self.last_selected_index = selected_index

        # Flag a restart; it is carried out at the end of update_statistics.
        if (self.t) % self.H == 1:
            self.is_change = True

        return selected_index

    def update_statistics(self, arm_index, reward):
        """Record the outcome of a pull and update all running statistics.

        Args:
            arm_index: index of the arm that was played.
            reward: observed reward.
        """
        t = self.t - self.tau
        self.action_history[t] = arm_index
        self.reward_history[t] = reward

        u = self.Phi[arm_index, :]

        fast_update_all(
            self.gram_inv,
            self.feature_accumulator,
            self.current_sigma,
            self.Phi,
            u,
            float(reward),
            self.lambda_
        )

        self.z[arm_index] += reward
        self.action_counts[arm_index] += 1

        if self.is_change:
            self.reset()

    def reset(self):
        """Wipe the learned statistics and restart the local clock at ``self.t``.

        Buffers that have not been allocated yet are skipped, so the method
        can be called at any point in the object's life.
        """
        self.is_change = False
        self.tau = self.t

        if self.gram_inv is not None:
            self.gram_inv.fill(0.0)
            np.fill_diagonal(self.gram_inv, self.lambda_)

        for buffer in (self.z, self.action_counts, self.feature_accumulator):
            if buffer is not None:
                buffer.fill(0.0)

        if self.cached_initial_sigma is not None and self.current_sigma is not None:
            np.copyto(self.current_sigma, self.cached_initial_sigma)

    def re_init(self):
        """Return the instance to its freshly constructed state for a new run."""
        super().re_init()
        self.action_history = np.full(self.horizon + 1, -1, dtype=int)
        self.reward_history = np.zeros(self.horizon + 1)
        self.gram_inv = np.eye(self.num_actions) * self.lambda_
        self.z = np.zeros(self.num_actions)
        self.action_counts = np.zeros(self.num_actions)
        self.tau = 0

        # Dropping the cache forces select_arm to rebuild everything.
        self.set_arms = False
        self.Phi = None
        self._cached_arms_hash = None
        self.feature_accumulator = None
        self.current_sigma = None
        self.cached_initial_sigma = None
        self.information_list = []
        self.set_information = False

    def __str__(self):
        return "GP-UCB-R"