# GP_UCB.py

import numpy as np
from BanditAlgorithm import BanditAlgorithm
from sklearn.gaussian_process.kernels import RBF
import utils_new as util 




class GP_UCB_W(BanditAlgorithm):
    """Weighted kernel UCB policy over a finite set of arms.

    Arm features are the rows of the Cholesky factor of the arms' kernel
    matrix.  Each observation made at step ``t`` is weighted by
    ``w_t = eta ** (-t)`` when it is added to the Gram matrix and to the
    reward accumulator, and the regulariser is scaled by the same weight
    when the posterior mean and width are computed.

    The step counter ``self.t`` is read but never written here; it is
    expected to be maintained by ``BanditAlgorithm`` (TODO confirm).
    """

    def __init__(self, num_actions, horizon, xi, B, config=None):
        """Store the hyper-parameters and allocate the history buffers.

        Args:
            num_actions: Number of arms; sizes the per-arm arrays.
            horizon: Number of rounds ``T``; sizes the history buffers and
                the exploration schedule.
            xi: Stored as ``self.xi``; not used by any method in this class.
            B: Additive constant of the exploration coefficient.
            config: Optional mapping merged (shallowly) over the defaults
                below.  A ``'kernel'`` entry therefore replaces the whole
                default kernel dict.  ``None`` keeps the defaults.
        """
        super().__init__(num_actions, horizon)

        self.config = {
            'tol': 0.0001,
            'window': 0,
            'kernel': {
                'name': 'rbf',
                'length_scale': 0.2,
            },
            'lambda': 20,
            'v': 0.01,
            'seed': None,
        }
        self.config.update(config or {})

        self.B = B
        self.xi = xi
        self.fixed_window = False

        self.horizon = horizon
        self.T = horizon
        self.tol = self.config['tol']
        self.window = self.config['window']
        self.v = self.config['v']
        self.num_actions = num_actions

        # NOTE(review): these two values are fixed constants.  Earlier code
        # derived them from ``B`` and ``config['lambda']`` but immediately
        # overwrote the result, so ``config['lambda']`` has no effect.
        self.R = 0.1 * 0.1
        self.lambda_ = 0.01

        self.action_history = np.full(horizon + 1, -1, dtype=int)
        self.reward_history = np.zeros(horizon + 1)

        self.set_arms = False

        self.init_params = {
            'num_actions': num_actions,
            'horizon': horizon,
            'xi': self.xi,
            'B': B,
            'config': self.config,
        }

        # Set on the first ``select_arm`` call, once ``PT`` is known.
        self.eta = None

        self.information_list = []
        self.set_information = False

    def _init_arm_features(self, arms, PT):
        """Build the feature matrix and zero the statistics for ``arms``.

        Also fixes ``self.eta`` from ``PT`` and the horizon.  The expression
        is not validated: if it yields ``eta <= 0`` the later ``np.log(eta)``
        is not finite.
        """
        num_actions = len(arms)
        K = util.kernel_matrix(arms, self.config['kernel'])
        # Small diagonal jitter before the Cholesky factorisation.
        self.Phi = np.linalg.cholesky(K + 1e-8 * np.eye(num_actions))
        self.gram = np.zeros((num_actions, num_actions))
        self.z = np.zeros(num_actions)
        self.action_counts = np.zeros(num_actions)
        self.information_gain = util.InformationGain(
            self.Phi, self.horizon, self.lambda_
        )
        self.d = arms[0].size

        # The object's ``get`` is replaced by the closed form log(t)^(d+1).
        self.information_gain.get = lambda t: np.log(t) ** (self.d + 1)

        gain_T = self.information_gain.get(self.T)
        self.eta = 1 - gain_T ** (-1 / 4) * (PT / self.T) ** 0.5
        self.set_arms = True

    def _build_beta_schedule(self):
        """Precompute the exploration coefficient for steps ``1..T``."""
        scale = 1

        def beta(step):
            gain = self.information_gain.get(step)
            width = np.sqrt(2 * (gain / 2 + np.log(1 / self.tol)))
            return scale * self.B + scale * self.R * width

        self.information_list = [beta(step) for step in range(1, self.T + 1)]
        self.set_information = True

    def select_arm(self, arms, PT):
        """Return the index of the arm with the largest upper confidence bound.

        Args:
            arms: Sequence of arm feature arrays.  Its length is expected to
                equal ``num_actions`` given at construction.
            PT: Used only on the first call after construction/reset, to set
                the weighting base ``eta``.

        Returns:
            Index (into ``arms``) of the selected arm; also stored in
            ``self.last_selected_index``.
        """
        t = self.t
        num_actions = len(arms)

        if not self.set_arms:
            self._init_arm_features(arms, PT)

        w_t = np.exp(-t * np.log(self.eta))
        self.gram_inv = np.linalg.inv(
            self.gram + np.eye(num_actions) * self.lambda_ * w_t
        )

        if not self.set_information:
            self._build_beta_schedule()

        mu = np.zeros(num_actions)
        sigma = np.zeros(num_actions)
        if t > 1:
            # Evaluate right-to-left so every product is matrix-vector.
            mu = self.Phi @ (self.gram_inv @ (self.Phi.T @ self.z))
            # Row-wise quadratic form u_a^T gram_inv u_a for every arm a.
            quad = np.sum((self.Phi @ self.gram_inv) * self.Phi, axis=1)
            # Clip round-off negatives so the square root below stays real.
            sigma = np.maximum(self.lambda_ * w_t * quad, 0.0)

        # The schedule has T entries (indices 0..T-1); clamp so that a step
        # counter equal to the horizon does not index past the end.
        beta_index = min(t, len(self.information_list) - 1)
        ucb_values = mu + self.information_list[beta_index] * np.sqrt(sigma)
        selected_index = np.argmax(ucb_values)
        self.last_selected_index = selected_index

        return selected_index

    def update_statistics(self, arm_index, reward):
        """Record the outcome of pulling ``arm_index`` at the current step.

        Must be called after ``select_arm`` has initialised the arm features.

        Args:
            arm_index: Index of the arm that was played.
            reward: Observed reward.
        """
        t = self.t
        self.action_history[t] = arm_index
        self.reward_history[t] = reward

        u = self.Phi[arm_index, :]
        w_t = np.exp(-t * np.log(self.eta))

        self.gram += np.outer(u, u) * w_t
        self.z[arm_index] += reward * w_t
        self.action_counts[arm_index] += 1

    def _clear_learning_state(self):
        """Reset histories, statistics and cached schedules for a new run."""
        self.action_history = np.full(self.horizon + 1, -1, dtype=int)
        self.reward_history = np.zeros(self.horizon + 1)
        # Placeholder only: ``select_arm`` recomputes ``gram_inv`` before use.
        self.gram_inv = np.eye(self.num_actions) * self.lambda_
        self.z = np.zeros(self.num_actions)
        self.action_counts = np.zeros(self.num_actions)
        # Forces the next ``select_arm`` to rebuild features, ``gram``, ``eta``.
        self.set_arms = False
        # The schedule depends on the arm dimension, so rebuild it as well.
        self.information_list = []
        self.set_information = False

    def reset(self):
        """Clear all learned state without touching the base class.

        Does not modify ``self.t``.
        """
        self.arms = []
        self.rewards = []
        self._clear_learning_state()

    def re_init(self):
        """Re-initialise the base class, then clear all learned state."""
        super().re_init()
        self._clear_learning_state()

