
import numpy as np
from BanditAlgorithm import BanditAlgorithm
import utils_new as util


from numba import njit

@njit
def kl_numba(p, q):
    """Return the Bernoulli KL divergence KL(p || q).

    Both arguments are clipped into [eps, 1 - eps] (eps = float64 machine
    epsilon) before the logarithms are taken, so boundary values such as
    ``q == 0`` or ``q == 1`` yield a large finite value instead of a
    division by zero / log of zero.

    Parameters
    ----------
    p, q : float
        Bernoulli means.

    Returns
    -------
    float
        ``0.0`` when ``p == q`` (checked before clipping), otherwise
        ``p*log(p/q) + (1-p)*log((1-p)/(1-q))`` on the clipped values.
    """
    if p == q:
        return 0.0
    eps = np.finfo(np.float64).eps
    p = min(max(p, eps), 1 - eps)
    # q must be clipped too: it appears in both denominators below.
    q = min(max(q, eps), 1 - eps)
    return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))

@njit
def klSG(p, q, var):
    """Return the quadratic divergence ``(p - q)**2 / (2 * var)``.

    Parameters
    ----------
    p, q : float
        The two means being compared.
    var : float
        Variance used to scale the squared gap (must be non-zero).
    """
    gap = p - q
    scale = 2 * var
    return gap ** 2 / scale

@njit
def beta_numba(n, delta):
    """Return the detection threshold ``log(4 * n**1.5 / delta)``.

    Parameters
    ----------
    n : int or float
        Number of samples the threshold is computed for.
    delta : float
        Confidence parameter dividing the polynomial term.
    """
    numerator = 4 * n ** 1.5
    return np.log(numerator / delta)

@njit
def change_detection(nb, sums, kl_numba, delta, noise_variance):
    """Scan all split points of a reward sequence for a shift in mean.

    For every split ``s`` in ``1 .. nb - 1`` the mean of the first ``s``
    samples and the mean of the remaining ``nb - s`` samples are each
    compared with the mean of all ``nb`` samples through the supplied
    divergence; the sample-count-weighted sum of the two divergences is
    tested against ``beta_numba(nb, delta)``.

    Parameters
    ----------
    nb : int
        Number of samples covered by ``sums``.
    sums : indexable of float
        Cumulative sums: ``sums[i]`` is the total of the first ``i + 1``
        samples.
    kl_numba : callable
        Divergence called as ``kl_numba(mean_a, mean_b, noise_variance)``.
        (The parameter name shadows the module-level function of the same
        name inside this body.)
    delta : float
        Confidence parameter forwarded to ``beta_numba``.
    noise_variance : float
        Third argument forwarded to the divergence.

    Returns
    -------
    int
        ``1`` as soon as some split exceeds the threshold, else ``0``
        (always ``0`` when ``nb <= 1`` because there is no split to test).
    """
    split = 1
    while split < nb:
        head_count = split
        tail_count = nb - split
        head_mean = sums[split - 1] / head_count
        tail_mean = (sums[nb - 1] - sums[split - 1]) / tail_count
        pooled_mean = sums[nb - 1] / nb
        statistic = (
            head_count * kl_numba(head_mean, pooled_mean, noise_variance)
            + tail_count * kl_numba(tail_mean, pooled_mean, noise_variance)
        )
        if statistic > beta_numba(nb, delta):
            # First exceedance ends the scan.
            return 1
        split += 1
    return 0

class DAL_KB(BanditAlgorithm):
    """Kernelized bandit with round-based arm elimination and change restarts.

    Behaviour visible in this class:

    * ``select_arm`` plays a uniformly random arm from the current active
      set, except on periodic forced-exploration steps where it sweeps
      through every arm index in order.
    * ``update_statistics`` maintains a kernel ridge-regression style
      estimate (inverse Gram matrix ``gram_inv`` and per-arm reward sums
      ``z``), prunes the active set with upper/lower confidence bounds at
      the end of rounds whose length doubles each time, and runs
      ``change_detection`` on the played arm's cumulative rewards.
    * When a change is flagged, ``reset`` discards all statistics and
      restarts from the current time step.

    ``self.t`` and ``self.T`` are read but never assigned here; they are
    presumably provided by ``BanditAlgorithm`` (current step and horizon)
    -- confirm against the base class.
    """

    def __init__(self, num_actions, horizon, noise_variance, B, config):
        """Store hyper-parameters and create the per-episode state.

        Parameters
        ----------
        num_actions : int
            Number of arms.
        horizon : int
            Forwarded to the base class and to ``util.InformationGain``.
        noise_variance : float
            Passed to the divergence used by the change detector.
        B : float
            Constant entering the confidence-width terms.
        config : dict or None
            Overrides merged (shallowly) into the default configuration.
            ``None`` keeps the defaults.
        """
        super().__init__(num_actions, horizon)

        self.config = {
            'tol': 0.0001,
            'window': 0,
            'kernel': {
                'name': 'rbf',
                'length_scale': 0.2,
            },
            'lambda': 20,
            'v': 0.01,
            'seed': None,
        }
        if config is not None:
            # Shallow merge: a supplied 'kernel' dict replaces the default one.
            self.config.update(config)

        self.num_actions = num_actions
        self.horizon = horizon
        self.noise_variance = noise_variance
        self.B = B

        self.tol = self.config['tol']
        self.window = self.config['window']
        self.v = self.config['v']

        # The values below are the ones that were effective before this
        # clean-up; config['lambda'] is NOT used for lambda_.
        self.R = 0.1 * 0.1
        self.lambda_ = 0.01

        self.init_params = {
            'num_actions': num_actions,
            'horizon': horizon,
            'noise_variance': noise_variance,
            'B': B,
            'config': self.config,
        }

        self.tau = 0            # time step of the most recent restart
        self.k = 1              # 1 + number of changes detected so far
        self.d = None           # arms[0].size, filled in by select_arm
        self.set_info = False   # True once Phi / information_gain exist
        # NOTE(review): ChangePoints is not cleared by re_init(), so it
        # accumulates across runs -- confirm this is intended.
        self.ChangePoints = []

        # Initial exploration rate; select_arm recomputes it on every call.
        self.alpha = 0.05 * np.sqrt(
            self.T ** (-0.8) * self.num_actions * np.log(self.T)
        )
        self.delta = self.tol / (6 * (np.log2(self.T)))

        # Also creates gram_inv and z so that update_statistics does not
        # depend on re_init() having been called first.
        self._reset_episode_state()

    def _reset_episode_state(self):
        """Reset everything that is discarded on a restart or a new run."""
        # NOTE(review): the inverse of a regulariser lambda*I would be
        # I/lambda; the code has always started from lambda*I -- confirm.
        self.gram_inv = np.eye(self.num_actions) * self.lambda_
        self.z = np.zeros(self.num_actions)

        self.Nr = 20                      # length of the first round
        self.next_r = False               # round-end flag set by select_arm
        self.current_arm_indices = None   # active set, rebuilt by select_arm
        self.set_arms = False
        self.is_change = False

        # Per-arm pull counts, reward totals and cumulative-sum histories
        # consumed by change_detection.
        self.SUMS = {i: [] for i in range(self.num_actions)}
        self.TotalNumber = {i: 0 for i in range(self.num_actions)}
        self.TotalSum = {i: 0 for i in range(self.num_actions)}

        self.forced_exploration = False
        self.explor_freq = int(np.ceil(self.num_actions / self.alpha))
        self.chosen_arm = 0

    def select_arm(self, arms, K=None):
        """Return the index of the arm to play at the current step.

        Parameters
        ----------
        arms : sequence of arrays
            Arm feature vectors; ``arms[0].size`` is stored as ``self.d``.
        K : optional
            Accepted for interface compatibility; not used.

        Returns
        -------
        int
            A uniformly random member of the active set, or the forced
            exploration index when the step falls in a sweep.
        """
        elapsed = self.t - self.tau  # steps since the last restart

        # Exploration rate depends on the number of detected segments k.
        self.alpha = 0.001 * np.sqrt(
            self.k * self.num_actions * np.log(self.T) / self.T
        )
        self.explor_freq = int(np.ceil(self.num_actions / self.alpha))

        if not self.set_info:
            self.set_info = True
            kernel = util.kernel_matrix(arms, self.config['kernel'])
            # Rows of the lower Cholesky factor act as per-arm features;
            # the 1e-8 diagonal term is added before factorising.
            self.Phi = np.linalg.cholesky(
                kernel + 1e-8 * np.eye(self.num_actions)
            )
            self.information_gain = util.InformationGain(
                self.Phi, self.horizon, self.lambda_
            )
            self.d = arms[0].size

        if self.current_arm_indices is None:
            # Fresh start or restart: every arm is active again.
            self.current_arm_indices = np.arange(0, self.num_actions).flatten()
            self.all_arms = arms
            self.d = arms[0].size

        if elapsed >= self.Nr:
            # Round finished; update_statistics will prune the active set.
            self.next_r = True

        selected_index = np.random.choice(self.current_arm_indices)
        # NOTE(review): this records the random draw even when the forced
        # exploration below overrides the returned index -- confirm.
        self.last_selected_index = selected_index

        # The first num_actions steps of every explor_freq-long period
        # play arms 0..num_actions-1 in order, regardless of the active set.
        sweep_slot = int(elapsed % self.explor_freq)
        if sweep_slot < self.num_actions:
            selected_index = sweep_slot

        return selected_index

    def update_statistics(self, arm_index, reward):
        """Incorporate one observation, then prune arms / detect changes.

        Parameters
        ----------
        arm_index : int
            Index of the arm that was played.
        reward : float
            Observed reward.
        """
        # Sherman-Morrison rank-one update of gram_inv. Grouped as
        # (A u)(u^T A) so only matrix-vector products are formed.
        u = self.Phi[[arm_index], :].T
        left = self.gram_inv @ u
        right = u.T @ self.gram_inv
        self.gram_inv -= (left @ right) / (1 + right @ u)
        self.z[arm_index] += reward

        if self.next_r:
            self._eliminate_arms()

        self.chosen_arm = arm_index
        self.TotalNumber[arm_index] += 1
        self.TotalSum[arm_index] += reward
        self.SUMS[arm_index].append(self.TotalSum[arm_index])

        # Confidence level of the change detector (distinct from self.delta).
        delta = 1 / (self.T ** (1 / 6))

        nb = self.TotalNumber[arm_index]
        sums = np.array(self.SUMS[arm_index], dtype=np.float64)
        if change_detection(nb, sums, klSG, delta, self.noise_variance) > 0:
            self.ChangePoints.append(self.t)
            self.is_change = True
            self.k += 1

        if self.is_change:
            self.reset()

    def _eliminate_arms(self):
        """Close the current round and drop arms whose UCB is too low."""
        self.next_r = False
        self.Nr = 2 * self.Nr

        mu = self.Phi @ self.gram_inv @ self.Phi.T @ self.z
        sigma = np.zeros(self.num_actions)
        for a in range(self.num_actions):
            u = self.Phi[a, :]
            sigma[a] = self.lambda_ * self.gram_inv.dot(u).dot(u)

        alpha = self.B + self.R * np.sqrt(
            (2 / self.lambda_) * (np.log((np.log(self.T)) / self.delta))
        )
        c = 2 * self.B / self.T + self.R * np.sqrt(
            (2 / (self.T * self.lambda_) * np.log(4 * self.T / self.delta))
        )

        # Replaces the information-gain lookup with log(t)**(d + 1).
        self.information_gain.get = lambda t: np.log(t) ** (self.d + 1)
        # Uses the already doubled round length Nr.
        alpha *= np.sqrt(self.information_gain.get(self.Nr) / self.Nr) * self.lambda_

        ucb = mu + alpha * sigma + c
        lcb = mu - alpha * sigma - c

        # Keep active arms whose UCB reaches the best LCB among active arms.
        active = self.current_arm_indices
        survivors = ucb[active] >= np.max(lcb[active])
        self.current_arm_indices = active[survivors].flatten()

    def reset(self):
        """Restart after a detected change: forget statistics, keep Phi and k."""
        self.arms = []
        self.rewards = []
        self.tau = self.t
        self._reset_episode_state()

    def re_init(self):
        """Return to the initial state for a fresh run."""
        super().re_init()

        self.tau = 0
        self.k = 1
        self.set_info = False  # forces select_arm to rebuild Phi
        self._reset_episode_state()

