import numpy as np
from typing import Optional
from math import log
from src.core.bandit import BanditAlgorithm
from src.core.detection import detect_change, kl_divergence


def _gaussian_divergence(p: float, q: float, variance: Optional[float] = None) -> float:
    """Return the quadratic-mode divergence between ``p`` and ``q``.

    Thin wrapper that forwards to ``kl_divergence`` with ``mode="quadratic"``
    and insists that a variance is supplied.

    Args:
        p: First value handed to ``kl_divergence``.
        q: Second value handed to ``kl_divergence``.
        variance: Variance forwarded to ``kl_divergence``; mandatory despite
            the ``None`` default.

    Returns:
        Whatever ``kl_divergence`` returns for these arguments.

    Raises:
        ValueError: If ``variance`` is ``None``.
    """
    if variance is not None:
        return kl_divergence(p, q, mode="quadratic", variance=variance)
    raise ValueError("variance must be provided for quadratic divergence.")

class DAL_LB(BanditAlgorithm):
    """Linear UCB policy with forced exploration and change-point restarts.

    The policy keeps a regularised least-squares estimate
    (``V``, ``inv_V``, ``z``, ``theta_hat``) and plays the arm with the largest
    upper-confidence index, except on periodic "forced exploration" rounds
    where it plays one of the arms stored in ``indep_arms``. After every
    observation the cumulative rewards of the played arm are handed to a
    change detector; when it fires, the estimator and the per-arm statistics
    are cleared and the exploration schedule restarts at the current round.

    NOTE(review): ``self.T``, ``self.t``, ``self.num_actions``, ``self.N_e``
    and ``get_indep_arms()`` are not defined in this class; they are assumed
    to be provided by ``BanditAlgorithm`` -- confirm against the base class.

    Args:
        num_actions: Forwarded to ``BanditAlgorithm.__init__``.
        horizon: Forwarded to ``BanditAlgorithm.__init__``.
        noise_variance: Variance passed to the divergence and the detector.
        d: Dimension of the arm feature vectors (size of ``z`` and ``V``).
        delta: Confidence parameter; enters the radius as ``2 * log(1 / delta)``.
        r_lambda: Ridge regularisation; ``V`` starts at ``r_lambda * I``.
        S: Scales the constant part of the radius, ``sqrt(r_lambda) * S``
            (presumably a bound on the parameter norm -- TODO confirm).
        L: Enters the logarithmic part of the radius as ``L ** 2``
            (presumably a bound on the arm norm -- TODO confirm).
        R: Multiplies the square-root part of the radius
            (presumably the noise scale -- TODO confirm).
        detector: Optional replacement for ``detect_change``; called as
            ``detector(n, cumulative_sums, delta, divergence=..., variance=...)``.
    """

    def __init__(self, num_actions, horizon, noise_variance, d, delta, r_lambda, S, L, R, *, detector=None):
        super().__init__(num_actions, horizon)

        self.d = d
        self.delta = delta
        self.r_lambda = r_lambda
        self.S = S
        self.L = L
        self.R = R
        self.noise_variance = noise_variance

        # Constructor arguments kept verbatim so the instance can be re-created.
        self.init_params = {
            'num_actions': num_actions,
            'horizon': horizon,
            'noise_variance': noise_variance,
            'd': d,
            'delta': delta,
            'r_lambda': r_lambda,
            'S': S,
            'L': L,
            'R': R,
            'detector': detector,
        }

        # Round-independent pieces of the confidence radius.
        self.c_1 = np.sqrt(self.r_lambda) * self.S
        self.c_2 = 2 * log(1 / self.delta)
        self.inv_r_lambda = 1 / self.r_lambda

        self._reset_regression_state()

        self.is_change = False
        self.tau = 0  # round at which the current stationary segment started
        self.ChangePoints = []

        self.k = 1  # 1 + number of changes detected so far
        self.alpha = self._initial_alpha()
        self._reset_arm_statistics()
        self.change_detector = detector or detect_change

        # select_arm() tests ``self.indep_arms is None`` on the first round.
        # Only create the attribute when the base class has not already done
        # so, so that nothing it set up is overwritten.
        if not hasattr(self, "indep_arms"):
            self.indep_arms = None

    def _initial_alpha(self):
        """Return the exploration rate used before the first call to select_arm."""
        return 0.05 * np.sqrt(self.T ** (-0.8) * self.num_actions * np.log(self.T))

    def _reset_regression_state(self):
        """Set the least-squares estimator back to its regularised prior."""
        self.z = np.zeros(self.d)
        self.V = self.r_lambda * np.identity(self.d)
        self.inv_V = self.inv_r_lambda * np.identity(self.d)
        self.theta_hat = np.zeros(self.d)

    def _reset_arm_statistics(self):
        """Clear the per-arm reward history consumed by the change detector."""
        self.SUMS = {i: [] for i in range(self.num_actions)}
        self.TotalNumber = {i: 0 for i in range(self.num_actions)}
        self.TotalSum = {i: 0 for i in range(self.num_actions)}
        self.forced_exploration = False
        self.explor_freq = int(np.ceil(self.num_actions / self.alpha))
        self.chosen_arm = 0

    def select_arm(self, arms):
        """Pick the index of the arm to play this round.

        Args:
            arms: Sequence of arm feature vectors, one row per arm.

        Returns:
            int: Row index into ``arms``. It is the forced-exploration arm
            when the current round falls in the first ``N_e`` slots of an
            exploration period, otherwise the arm with the largest UCB index
            (ties broken uniformly at random).

        Raises:
            IndexError: If the scheduled forced-exploration arm does not
                appear (exactly) among the rows of ``arms``.
        """
        self.all_arms = np.array(arms)
        self.arms = np.array(arms)
        if self.indep_arms is None:
            self.get_indep_arms()

        # The exploration rate depends on the number of detected changes.
        self.alpha = 0.001 * np.sqrt(self.k * self.N_e * np.log(self.T) / self.T)
        self.explor_freq = int(np.ceil(self.N_e / self.alpha))

        self.theta_hat = self.inv_V @ self.z

        beta = self.c_1 + self.R * np.sqrt(
            self.c_2 + self.d * log(1 + (self.t * self.L ** 2) / (self.r_lambda * self.d))
        )

        # x^T V^{-1} x for every arm at once.
        arms_inv_V = self.arms @ self.inv_V
        widths = np.sqrt(np.sum(self.arms * arms_inv_V, axis=1))
        ucb_s = self.arms @ self.theta_hat + beta * widths

        # lexsort uses the last key (ucb_s) as primary and the random key to
        # break ties. The draw is made on every round, forced or not, so the
        # random stream is consumed at a fixed rate.
        mixer = np.random.random(ucb_s.size)
        ucb_indices = np.lexsort((mixer, ucb_s))
        chosen_arm = int(ucb_indices[-1])

        # Position of this round inside the current exploration period.
        slot = int((self.t - self.tau) % self.explor_freq)
        # The flag describes the current round only; it used to stay True
        # until the next reset once any forced round had happened.
        self.forced_exploration = bool(slot < self.N_e)
        if self.forced_exploration:
            target_vec = self.indep_arms[slot]
            # Compare against the ndarray copy so the element-wise match also
            # works when ``arms`` was passed as a plain list.
            matches = (self.arms == target_vec).all(axis=1)
            chosen_arm = int(np.flatnonzero(matches)[0])
        return chosen_arm

    def update_statistics(self, x, y):
        """Fold the reward ``y`` observed for arm index ``x`` into the state.

        Updates the least-squares statistics, appends the arm's cumulative
        reward to its history and, once the arm has more than two samples
        since the last reset, runs the change detector on that history. A
        detection records ``self.t`` in ``ChangePoints``, increments ``k`` and
        triggers ``reset()``.

        Args:
            x: Index of the played arm in the array stored by ``select_arm``.
            y: Observed reward.
        """
        self.chosen_arm = x
        x_arm = self.arms[x].reshape(-1, 1)

        self.V += x_arm @ x_arm.T
        self.update_inv_V(x_arm)

        self.z += y * x_arm.flatten()

        self.TotalNumber[self.chosen_arm] += 1
        self.TotalSum[self.chosen_arm] += y
        self.SUMS[self.chosen_arm].append(self.TotalSum[self.chosen_arm])

        # Detector confidence level; separate from the UCB ``self.delta``.
        detection_delta = 1 / (self.T) ** 3
        nb = self.TotalNumber[self.chosen_arm]

        if nb > 2:
            changed = self.change_detector(
                nb,
                self.SUMS[self.chosen_arm],
                detection_delta,
                # The ``var`` argument is ignored: the divergence always uses
                # the configured noise variance.
                divergence=lambda a, b, var=None: _gaussian_divergence(
                    a, b, variance=self.noise_variance
                ),
                variance=self.noise_variance,
            )
            if changed:
                self.ChangePoints.append(self.t)
                self.is_change = True
                self.k += 1

        if self.is_change:
            self.reset()

    def update_inv_V(self, x_arm):
        """Apply the rank-one Sherman-Morrison update to ``inv_V`` in place.

        Args:
            x_arm: Column vector of shape ``(d, 1)`` that was just added to
                ``V`` as ``x_arm @ x_arm.T``.
        """
        inv_V_x = self.inv_V @ x_arm
        denominator = 1.0 + (x_arm.T @ inv_V_x)[0, 0]
        numerator = inv_V_x @ inv_V_x.T
        self.inv_V -= numerator / denominator

    def reset(self):
        """Restart learning after a detected change.

        Clears the estimator and the per-arm statistics and marks the current
        round as the start of a new segment. ``k``, ``alpha``, ``ChangePoints``
        and ``indep_arms`` are kept.
        """
        self._reset_regression_state()
        self._reset_arm_statistics()
        self.is_change = False
        self.tau = self.t

    def re_init(self):
        """Return the instance to its freshly constructed state for a new run."""
        super().re_init()

        self._reset_regression_state()

        self.ChangePoints = []
        self.k = 1
        # Restore the construction-time rate first so that ``explor_freq``
        # does not inherit the value left over from the previous run.
        self.alpha = self._initial_alpha()
        self._reset_arm_statistics()
        self.is_change = False
        self.tau = 0
        # Forces select_arm() to call get_indep_arms() again.
        self.indep_arms = None

    def __str__(self):
        return 'LB-StaticUCB'
