import numpy as np
import scipy.linalg as spla
from src.core.bandit import BanditAlgorithm
from src import utils as util
from numba import njit


@njit(fastmath=True, cache=True)
def fast_weighted_outer_update(gram, u, w_t):
    """Accumulate the weighted outer product ``w_t * u u^T`` into ``gram`` in place.

    Args:
        gram: Square 2-D array that is modified in place.
        u: 1-D array; only its first ``gram.shape[0]`` entries are read.
        w_t: Scalar weight applied to the outer product.

    Returns:
        None (the update is written into ``gram``).
    """
    size = gram.shape[0]
    row = 0
    while row < size:
        # Scale once per row so the inner loop is a single multiply-add.
        scaled = w_t * u[row]
        for col in range(size):
            gram[row, col] += scaled * u[col]
        row += 1

@njit(fastmath=True, cache=True)
def fast_add_diagonal(M, gram, diag_val):
    """Write ``gram + diag_val * I`` into ``M`` in place.

    Args:
        M: Output 2-D array; assumed to be at least ``gram.shape[0]`` square
            (only that top-left square is written).
        gram: Square 2-D array that is copied (not modified unless it aliases ``M``).
        diag_val: Scalar added to every diagonal entry.

    Returns:
        None (the result is written into ``M``).
    """
    n = gram.shape[0]
    # First copy the full square, then shift the diagonal.
    for r in range(n):
        for c in range(n):
            M[r, c] = gram[r, c]
    for k in range(n):
        M[k, k] += diag_val

class GP_UCB_W(BanditAlgorithm):
    """Weighted GP-UCB bandit over a finite set of arms.

    Arms are embedded through the lower Cholesky factor ``Phi`` of their kernel
    matrix (``K + 1e-8 I = Phi Phi^T``), so the feature vector of arm ``a`` is
    ``Phi[a, :]``. The observation from round ``t`` is weighted by
    ``w_t = eta ** (-t)``; when ``eta < 1`` this gives recent rounds larger
    weights than older ones.

    NOTE(review): ``w_t`` grows geometrically with ``t`` and may overflow
    float64 for long horizons when ``eta`` is far from 1.
    """

    def __init__(self, num_actions, horizon, xi, B, config):
        """Store hyper-parameters and allocate the round-indexed histories.

        Args:
            num_actions: Number of arms.
            horizon: Number of rounds ``T``.
            xi: Stored on the instance and in ``init_params`` (not read
                elsewhere in this class).
            B: Constant offset of the confidence width ``beta_t``.
            config: Dict of overrides for the defaults below; ``None`` is
                treated as "no overrides".
        """
        super().__init__(num_actions, horizon)

        self.config = {
            'tol': 0.0001, 'window': 0, 'kernel': {'name': 'rbf', 'length_scale': 0.2},
            'lambda': 20, 'v': 0.01, 'seed': None,
        }
        self.config.update(config or {})

        self.xi = xi
        self.fixed_window = False
        self.horizon = horizon
        self.T = horizon
        self.tol = self.config['tol']
        self.window = self.config['window']
        self.lambda_ = self.config['lambda']
        self.v = self.config['v']
        self.num_actions = num_actions

        # Indexed by round; sized horizon + 1 so index `horizon` is valid.
        self.action_history = np.full(horizon + 1, -1, dtype=int)
        self.reward_history = np.zeros(horizon + 1)
        self.set_arms = False

        self.init_params = {
            'num_actions': num_actions, 'horizon': horizon,
            'xi': self.xi, 'B': B, 'config': self.config
        }
        self.B = B
        self.R = 0.01
        # NOTE(review): this overrides config['lambda'] read above, so the
        # configured value is ignored — confirm this is intended.
        self.lambda_ = 0.01
        self.power = 0.8
        self.eta = None

        self.information_list = []
        self.set_information = False

        # Kernel features (rows = arms) and their transpose, Fortran-ordered
        # for LAPACK.
        self.Phi = None
        self._cached_arms_hash = None

        # Weighted sufficient statistics: feature-space Gram matrix, per-arm
        # weighted reward sums, and per-arm pull counts.
        self.gram = None
        self.z = None
        self.action_counts = None

        # Scratch buffers reused every round to avoid reallocations.
        self.M_buffer = None
        self.V_buffer = None
        self.Phi_T = None

    def _setup_arms(self, arms, PT, arms_hash, num_actions):
        """(Re)build kernel features, statistic buffers and ``eta`` for ``arms``."""
        K = util.kernel_matrix(arms, self.config['kernel'])
        # Jitter keeps the Cholesky factorization stable; sizing it from K
        # guarantees the shapes agree even if len(arms) != self.num_actions.
        jitter = 1e-8 * np.eye(K.shape[0])
        self.Phi = np.array(np.linalg.cholesky(K + jitter), order='F')
        self.Phi_T = np.array(self.Phi.T, order='F')
        self._cached_arms_hash = arms_hash

        # Statistics live in the feature basis of the old arms, so they are
        # discarded whenever the arms change.
        if self.gram is not None:
            self.reset()

        shape = (num_actions, num_actions)
        if self.gram is None or self.gram.shape != shape:
            self.gram = np.zeros(shape, order='F')
            self.z = np.zeros(num_actions)
            self.action_counts = np.zeros(num_actions)
            self.M_buffer = np.zeros(shape, order='F')
            self.V_buffer = np.zeros(shape, order='F')

        # Must come after reset(), which clears this flag; otherwise every
        # subsequent round would re-run setup and wipe the statistics again.
        self.set_arms = True

        new_d = arms[0].size
        if new_d != getattr(self, 'd', None):
            # The confidence widths depend on d, so rebuild them lazily.
            self.set_information = False
        self.d = new_d
        self.information_gain = util.InformationGain(self.Phi, self.horizon, self.lambda_)
        self.information_gain.get = lambda t: np.log(t) ** (self.d * self.power)

        ig_T = self.information_gain.get(self.T)
        pt_val = PT if PT > 0 else 1e-9
        self.eta = 1 - ig_T ** (-1 / 4) * (pt_val / self.T) ** 0.5

    def _build_information_list(self):
        """Precompute the confidence width ``beta`` for rounds 1..T."""
        steps = np.arange(1, self.T + 1)
        ig_vals = np.array([self.information_gain.get(x) for x in steps])
        self.information_list = self.B + self.R * np.sqrt(
            2 * (ig_vals + np.log(1 / self.tol))) / np.sqrt(self.lambda_)
        self.set_information = True

    def select_arm(self, arms, PT):
        """Return the index of the arm with the largest weighted UCB score.

        Args:
            arms: Candidate arms; passed to ``util.kernel_matrix`` and
                ``arms[0].size`` is used as the input dimension ``d``.
            PT: Budget used to derive ``eta``; non-positive values are
                replaced by ``1e-9``.

        Returns:
            Index (into ``arms``) of the selected arm.
        """
        # Presumably the current round maintained by BanditAlgorithm — confirm.
        t = self.t
        num_actions = len(arms)

        arms_hash = hash(arms.data.tobytes()) if hasattr(arms, 'data') else id(arms)
        if not self.set_arms or self._cached_arms_hash != arms_hash:
            self._setup_arms(arms, PT, arms_hash, num_actions)

        if not self.set_information:
            self._build_information_list()

        # w_t = eta ** (-t)
        w_t = np.exp(-t * np.log(self.eta))

        # M = gram + lambda * sqrt(w_t) * I
        diag_val = self.lambda_ * np.sqrt(w_t)
        fast_add_diagonal(self.M_buffer, self.gram, diag_val)

        try:
            # Use the returned factor: SciPy only says the input *may* be
            # overwritten, so M_buffer is not guaranteed to hold it.
            chol = spla.cholesky(self.M_buffer, lower=True, overwrite_a=True,
                                 check_finite=False)

            # z is accumulated per arm, while gram is in feature space (rows
            # of Phi); map z into feature space before solving.
            rhs = self.Phi_T @ self.z
            alpha = spla.cho_solve((chol, True), rhs, check_finite=False)
            mu = self.Phi @ alpha

            # Column a of V is chol^{-1} Phi[a, :]^T, so its squared norm is
            # Phi[a, :] M^{-1} Phi[a, :]^T.
            np.copyto(self.V_buffer, self.Phi_T)
            V = spla.solve_triangular(chol, self.V_buffer, lower=True,
                                      overwrite_b=True, check_finite=False)
            sigma_diag = np.sum(V ** 2, axis=0)
            sigma = (self.lambda_ * w_t) * sigma_diag

        except np.linalg.LinAlgError:
            # Factorization failed: fall back to an all-zero score.
            mu = np.zeros(num_actions)
            sigma = np.zeros(num_actions)

        idx = int(t) if t < len(self.information_list) else -1
        beta_t = self.information_list[idx]

        ucb_values = mu + beta_t * np.sqrt(sigma)
        selected_index = np.argmax(ucb_values)
        self.last_selected_index = selected_index

        return selected_index

    def update_statistics(self, arm_index, reward):
        """Record ``reward`` for ``arm_index`` and fold it into the weighted statistics.

        Args:
            arm_index: Index of the pulled arm (a row of ``Phi``).
            reward: Observed reward.
        """
        t = self.t
        self.action_history[t] = arm_index
        self.reward_history[t] = reward

        w_t = np.exp(-t * np.log(self.eta))

        u = self.Phi[arm_index, :]

        # gram += w_t * u u^T (feature space).
        fast_weighted_outer_update(self.gram, u, w_t)

        # z stays indexed by arm; select_arm maps it through Phi^T.
        self.z[arm_index] += reward * w_t
        self.action_counts[arm_index] += 1

    def reset(self):
        """Zero the accumulated statistics and force arm setup on the next ``select_arm``."""
        self.arms = []
        self.rewards = []

        if self.gram is not None:
            self.gram.fill(0.0)
            self.z.fill(0.0)
            self.action_counts.fill(0.0)

        self.set_arms = False

    def re_init(self):
        """Return to the freshly constructed state: clear histories, buffers and caches."""
        super().re_init()
        self.action_history = np.full(self.horizon + 1, -1, dtype=int)
        self.reward_history = np.zeros(self.horizon + 1)

        self.gram = None
        self.z = None
        self.action_counts = None
        self.M_buffer = None
        self.V_buffer = None
        self.Phi_T = None

        self.set_arms = False
        self.Phi = None
        self._cached_arms_hash = None
        self.information_list = []
        self.set_information = False

    def __str__(self):
        return "GP-UCB-W"