import numpy as np
from src.core.bandit import BanditAlgorithm
from src import utils as util
from numba import njit


@njit(fastmath=True, cache=True)
def fast_sw_compute_mu(Phi, gram_inv, feature_accumulator):
    """Return the estimated mean reward of every arm.

    The weight vector is ``gram_inv @ feature_accumulator``.
    Each arm's mean is the dot product of its feature row in ``Phi`` with those weights.
    """
    return Phi @ (gram_inv @ feature_accumulator)

@njit(fastmath=True, cache=True)
def fast_sw_update_all(gram_inv, feature_accumulator, current_sigma, Phi,
                       u_add, r_add, u_remove, r_remove, do_remove, lambda_):
    """Apply one sliding-window step to the GP-UCB statistics, in place.

    Adds the observation ``(u_add, r_add)``. When ``do_remove`` is true, it also
    drops the expired observation ``(u_remove, r_remove)``.
    The inverse Gram matrix is maintained with Sherman-Morrison rank-one
    updates, so no matrix is ever inverted explicitly.

    Parameters
    ----------
    gram_inv : (N, N) float array, modified in place
        Current inverse Gram matrix; the updates assume it is symmetric.
    feature_accumulator : (N,) float array, modified in place
        Running sum of ``u * r`` over the observations currently in the window.
    current_sigma : float array with one entry per row of ``Phi``, modified in place
        Per-arm variance term.
    Phi : (M, N) float array
        Row ``i`` is the feature vector of arm ``i``.
    u_add, r_add
        Feature vector and reward of the new observation.
    u_remove, r_remove
        Feature vector and reward of the expired observation; ignored
        unless ``do_remove`` is true.
    do_remove : bool
        Whether to remove ``(u_remove, r_remove)`` from the window.
    lambda_ : float
        Scale applied to every variance update.
        NOTE(review): callers initialise ``current_sigma`` to
        ``lambda_ * ||Phi[i]||^2`` with ``gram_inv = lambda_ * I``, i.e. to
        ``Phi[i] @ gram_inv @ Phi[i]``. The exact rank-one change of that
        quantity carries no extra ``lambda_`` factor. Confirm the intended scaling.
    """
    n_feat = gram_inv.shape[0]
    n_arms = Phi.shape[0]

    # Add: (A + u u^T)^-1 = A^-1 - (A^-1 u)(A^-1 u)^T / (1 + u^T A^-1 u)
    v = gram_inv @ u_add
    denom_add = 1.0 + np.dot(u_add, v)
    for i in range(n_feat):
        factor = v[i] / denom_add
        for j in range(n_feat):
            gram_inv[i, j] -= factor * v[j]

    w_add = Phi @ v
    for i in range(n_arms):
        update_val = (w_add[i] * w_add[i]) / denom_add
        current_sigma[i] -= lambda_ * update_val
        # Clip round-off so the caller's sqrt never sees a negative value.
        if current_sigma[i] < 0.0:
            current_sigma[i] = 0.0

    for i in range(n_feat):
        feature_accumulator[i] += u_add[i] * r_add

    if do_remove:
        # Remove: (A - u u^T)^-1 = A^-1 + (A^-1 u)(A^-1 u)^T / (1 - u^T A^-1 u)
        # A downdate needs a MINUS in the denominator. Using "+" here
        # (as an update would) leaves gram_inv and sigma inconsistent with the
        # observations actually in the window.
        v_rem = gram_inv @ u_remove
        # Positive whenever the pre-removal matrix stays positive definite,
        # which the regularised prior in gram_inv ensures.
        denom_rem = 1.0 - np.dot(u_remove, v_rem)

        for i in range(n_feat):
            factor = v_rem[i] / denom_rem
            for j in range(n_feat):
                gram_inv[i, j] += factor * v_rem[j]

        w_rem = Phi @ v_rem
        for i in range(n_arms):
            update_val = (w_rem[i] * w_rem[i]) / denom_rem
            current_sigma[i] += lambda_ * update_val

        for i in range(n_feat):
            feature_accumulator[i] -= u_remove[i] * r_remove

class GP_UCB_SW(BanditAlgorithm):
    """Sliding-window GP-UCB bandit over a finite set of arms.

    Arm features are the rows of ``Phi``, the Cholesky factor of the
    (jittered) kernel matrix of the arms, so ``Phi @ Phi.T`` reproduces it.
    When ``window`` is positive, each observation older than ``window`` steps
    is removed again via a rank-one downdate in ``fast_sw_update_all``.
    When ``window`` is 0, nothing is ever removed.
    """

    def __init__(self, num_actions, horizon, xi, B, config):
        super().__init__(num_actions, horizon)

        self.config = {
            'tol': 0.0001, 'window': 0, 'kernel': {'name': 'rbf', 'length_scale': 0.2},
            'lambda': 20, 'v': 0.01, 'seed': None,
        }
        self.config.update(config)

        self.xi = xi
        self.fixed_window = False
        self.horizon = horizon
        self.T = horizon
        self.tol = self.config['tol']
        # Overwritten in select_arm once PT is known.
        self.window = self.config['window']
        self.v = self.config['v']
        self.num_actions = num_actions

        # -1 marks a step for which no action has been recorded.
        self.action_history = np.full(horizon + 1, -1, dtype=int)
        self.reward_history = np.zeros(horizon + 1)
        self.set_arms = False

        self.init_params = {
            'num_actions': num_actions, 'horizon': horizon,
            'xi': self.xi, 'B': B, 'config': self.config
        }
        self.B = B
        self.R = 0.01
        # NOTE(review): config['lambda'] (default 20) is never read;
        # the regulariser below is hard-coded. Confirm this is intended.
        self.lambda_ = 0.01

        self.information_list = []
        self.set_information = False

        self.Phi = None
        self._cached_arms_hash = None
        self.feature_accumulator = None
        self.current_sigma = None

    def _prior_sigma(self):
        """Return the per-arm variance term before any observation.

        The value is ``lambda_ * ||Phi[i]||^2``, which equals ``lambda_ * K_mod[i, i]``.
        NOTE(review): fast_sw_update_all scales its variance updates by an
        extra ``lambda_``. Confirm that this prior and those updates use the
        same convention.
        """
        return self.lambda_ * np.sum(self.Phi**2, axis=1)

    def _init_arm_state(self, arms, PT):
        """Build (or reuse) arm features, reset the posterior, and size the window."""
        self.set_arms = True
        num_actions = len(arms)

        # Hash only here: select_arm runs every round, but the features only
        # need checking when the state is (re)initialised.
        arms_hash = hash(arms.data.tobytes()) if hasattr(arms, 'data') else id(arms)
        if self.Phi is None or self._cached_arms_hash != arms_hash:
            K = util.kernel_matrix(arms, self.config['kernel'])
            # Small diagonal jitter so the Cholesky factorisation succeeds.
            K_mod = K + 1e-8 * np.eye(num_actions)
            self.Phi = np.linalg.cholesky(K_mod)
            self._cached_arms_hash = arms_hash

        self.gram_inv = np.eye(num_actions) * self.lambda_
        self.z = np.zeros(num_actions)
        self.action_counts = np.zeros(num_actions)
        self.feature_accumulator = np.zeros(num_actions)
        # Reset sigma together with gram_inv so both describe the same
        # (empty-window) posterior, whether or not Phi was recomputed.
        self.current_sigma = self._prior_sigma()

        self.information_gain = util.InformationGain(self.Phi, self.horizon, self.lambda_)
        self.d = arms[0].size
        # Replaces the helper's own estimate with log(t)^(d+2).
        self.information_gain.get = lambda t: np.log(t)**(self.d+2)

        # PT presumably counts change points in the environment (assumed from
        # the T/PT window formula); PT == 0 disables the sliding window.
        if PT == 0:
            self.window = 0
        else:
            ig_T = self.information_gain.get(self.T)
            self.window = int(np.ceil((ig_T)**(1/4) * ((self.T / PT)**0.5)))

    def _init_beta_schedule(self):
        """Precompute the exploration coefficient; entry k is computed from step k + 1."""
        self.set_information = True
        steps = np.arange(1, self.T + 1)
        ig_vals = np.array([self.information_gain.get(x) for x in steps])
        self.information_list = self.B + self.R * np.sqrt(
            4 * (ig_vals + np.log(1 / self.tol))) / np.sqrt(self.lambda_)

    def select_arm(self, arms, PT):
        """Return the index of the arm with the highest upper confidence bound.

        Parameters
        ----------
        arms : sequence of arm feature vectors
            Each element must expose ``.size``.
        PT : number
            Sizes the sliding window; ``0`` disables it.
        """
        t = self.t

        if not self.set_arms:
            self._init_arm_state(arms, PT)

        if not self.set_information:
            self._init_beta_schedule()

        # With no recorded data feature_accumulator is all zeros, so this
        # yields exact zeros. Computing it unconditionally keeps mu in sync
        # with whatever observations have actually been added.
        mu = fast_sw_compute_mu(self.Phi, self.gram_inv, self.feature_accumulator)

        idx = int(t) if t < len(self.information_list) else -1
        beta_t = self.information_list[idx]

        ucb_values = mu + beta_t * np.sqrt(self.current_sigma)
        selected_index = np.argmax(ucb_values)
        self.last_selected_index = selected_index

        return selected_index

    def update_statistics(self, arm_index, reward):
        """Record ``(arm_index, reward)`` at step ``self.t`` and slide the window."""
        t = self.t
        self.action_history[t] = arm_index
        self.reward_history[t] = reward

        u_add = self.Phi[arm_index, :].copy()
        r_add = float(reward)

        do_remove = False
        u_remove = np.zeros_like(u_add)
        r_remove = 0.0

        if self.window and t > self.window:
            past_arm = self.action_history[t - self.window]
            # -1 means that step was never recorded. Phi[-1] would silently
            # downdate the last arm, so skip the removal instead.
            if past_arm >= 0:
                do_remove = True
                u_remove = self.Phi[past_arm, :].copy()
                r_remove = float(self.reward_history[t - self.window])

                self.action_counts[past_arm] -= 1
                self.z[past_arm] -= r_remove

        fast_sw_update_all(
            self.gram_inv,
            self.feature_accumulator,
            self.current_sigma,
            self.Phi,
            u_add, r_add,
            u_remove, r_remove,
            do_remove,
            self.lambda_
        )

        self.z[arm_index] += reward
        self.action_counts[arm_index] += 1

    def reset(self):
        """Clear all observations; the cached arm features (Phi) are kept."""
        self.arms = []
        self.rewards = []
        self.action_history = np.full(self.horizon + 1, -1, dtype=int)
        self.reward_history = np.zeros(self.horizon + 1)
        self.gram_inv = np.eye(self.num_actions) * self.lambda_
        self.z = np.zeros(self.num_actions)
        self.action_counts = np.zeros(self.num_actions)
        self.feature_accumulator = np.zeros(self.num_actions)

        if self.Phi is not None:
            self.current_sigma = self._prior_sigma()

        self.set_arms = False

    def re_init(self):
        """Reset everything, including cached arm features and the beta schedule."""
        super().re_init()
        self.action_history = np.full(self.horizon + 1, -1, dtype=int)
        self.reward_history = np.zeros(self.horizon + 1)
        self.gram_inv = np.eye(self.num_actions) * self.lambda_
        self.z = np.zeros(self.num_actions)
        self.action_counts = np.zeros(self.num_actions)
        self.set_arms = False
        self.Phi = None
        self._cached_arms_hash = None
        self.feature_accumulator = None
        self.current_sigma = None
        self.information_list = []
        self.set_information = False

    def __str__(self):
        return "GP-UCB-SW"