from BanditAlgorithm import BanditAlgorithm
from utils import sigmoid, dsigmoid
import numpy as np
from math import log
from numpy.linalg import pinv
from scipy.optimize import minimize, NonlinearConstraint
from numba import njit

@njit
def kl_numba(p, q, var):
    """Return the Bernoulli KL divergence kl(p, q).

    Both arguments are clamped to [1e-12, 1 - 1e-12] first so that the
    logarithms and ratios stay finite at the boundaries.

    ``var`` is not used; it is accepted so this function has the same
    three-argument signature as ``klSG`` and can be passed to
    ``change_detection`` interchangeably.
    """
    floor = 1e-12
    ceiling = 1 - floor
    p_safe = min(max(p, floor), ceiling)
    q_safe = min(max(q, floor), ceiling)

    success_term = p_safe * np.log(p_safe / q_safe)
    failure_term = (1 - p_safe) * np.log((1 - p_safe) / (1 - q_safe))
    return success_term + failure_term

@njit
def klSG(p, q, var):
    """Return the quadratic divergence (p - q)^2 / (2 * var).

    This equals the KL divergence between two Gaussians that share the
    variance ``var`` and have means ``p`` and ``q``.
    """
    gap = p - q
    return (gap ** 2) / (2 * var)

@njit
def beta_numba(n, delta):
    """Return the detection threshold log(n^1.5 / delta).

    ``n`` is the number of samples being tested and ``delta`` the
    confidence parameter; a smaller ``delta`` yields a larger threshold.
    """
    ratio = (n ** 1.5) / delta
    return np.log(ratio)

@njit
def change_detection(nb, sums, kl_numba, delta, noise_variance):
    """Test one reward stream for a shift in its mean.

    For every split point ``s`` in ``1 .. nb - 1`` the samples are cut into a
    prefix of length ``s`` and a suffix of length ``nb - s``. The statistic

        s * d(mean_prefix, mean_all) + (nb - s) * d(mean_suffix, mean_all)

    is compared against ``beta_numba(nb, delta)``; the scan stops at the
    first split that exceeds the threshold.

    Parameters
    ----------
    nb : int
        Number of samples to consider.
    sums : array
        Cumulative reward sums; ``sums[i]`` is the sum of the first ``i + 1``
        samples, so ``sums[nb - 1]`` is the total.
    kl_numba : callable
        Divergence ``d(p, q, var)`` (e.g. ``kl_numba`` or ``klSG``). Note that
        this parameter shadows the module-level function of the same name.
    delta : float
        Confidence parameter forwarded to ``beta_numba``.
    noise_variance : float
        Forwarded as the third argument of the divergence.

    Returns
    -------
    int
        1 if some split exceeds the threshold, otherwise 0.
    """
    # With fewer than two samples there is no split into two non-empty parts.
    if nb < 2:
        return 0

    # These depend only on nb / delta, so compute them once rather than on
    # every split point.
    total = sums[nb - 1]
    mu = total / nb
    threshold = beta_numba(nb, delta)

    for s in range(1, nb):
        prefix_sum = sums[s - 1]
        suffix_len = nb - s
        mu1 = prefix_sum / s
        mu2 = (total - prefix_sum) / suffix_len
        kl_val = s * kl_numba(mu1, mu, noise_variance) + suffix_len * kl_numba(mu2, mu, noise_variance)
        if kl_val > threshold:
            return 1
    return 0

class DAL_GLB(BanditAlgorithm):
    """Generalized linear bandit with UCB selection, forced exploration and
    change-detection-triggered restarts.

    Each round ``select_arm`` scores every arm with
    ``sigmoid(x . theta_tilde)`` plus an exploration bonus, except on
    periodic forced-exploration rounds where it plays one of
    ``self.indep_arms``. ``update_statistics`` folds the observed reward into
    the design matrix and the parameter estimate, then runs
    ``change_detection`` on the played arm's reward stream; on detection all
    per-segment state is discarded via ``reset``.

    The attributes ``self.T``, ``self.num_actions``, ``self.N_e`` and the
    method ``get_indep_arms`` are used here but not defined in this class;
    they are assumed to come from ``BanditAlgorithm`` (not visible here).

    NOTE(review): ``self.t`` is initialised to 0 and read in several places
    but never advanced in this class -- presumably the base class or the
    experiment runner increments it; confirm.
    """

    def __init__(self, num_actions, horizon, noise_variance, d, delta, r_lambda, S, L, R, k_mu, c_mu):
        """Store the hyper-parameters and build the initial learner state.

        Parameters
        ----------
        num_actions, horizon
            Forwarded to ``BanditAlgorithm.__init__``.
        noise_variance
            Scales the exploration bonus and is passed to the divergence used
            for change detection.
        d
            Dimension of the arm feature vectors / parameter vector.
        delta
            Confidence level used in the exploration radius (``c_2``).
        r_lambda
            Ridge regularisation weight.
        S
            Radius of the ball that ``theta_tilde`` is projected onto.
        L
            Stored on the instance; not otherwise used in this class.
        R
            Multiplier of the square-root term in the exploration radius.
        k_mu, c_mu
            Link-function constants; enter the bonus as ``2 * k_mu / c_mu``
            and the regulariser as ``r_lambda * c_mu``.
        """
        super().__init__(num_actions, horizon)

        self.d = d
        self.delta = delta
        self.r_lambda = r_lambda
        self.S = S
        self.L = L
        self.R = R
        self.k_mu = k_mu
        self.c_mu = c_mu
        self.noise_variance = noise_variance

        self.init_params = {
            'num_actions': num_actions,
            'horizon': horizon,
            'noise_variance': noise_variance,
            'd': d,
            'delta': delta,
            'r_lambda': r_lambda,
            'S': S,
            'L': L,
            'R': R,
            'k_mu': k_mu,
            'c_mu': c_mu
        }

        # Constants of the exploration radius, computed once.
        self.c_1 = np.sqrt(self.r_lambda) * c_mu * self.S
        self.c_2 = 2 * log(1 / self.delta)
        self.const_k_mu_over_c_mu = (2 * self.k_mu) / self.c_mu
        self.inv_r_lambda = 1 / self.r_lambda

        self.t = 0
        self.tau = 0            # round at which the current segment started
        self.ChangePoints = []  # values of self.t at which a change was flagged
        self.k = 1              # 1 + number of detected changes
        self.alpha = 0.05 * np.sqrt(self.T ** (-0.8) * self.num_actions * np.log(self.T))

        # select_arm tests `self.indep_arms is None`; make sure the attribute
        # exists even if the base class does not create it. A value already
        # set by the base class is left untouched.
        self.indep_arms = getattr(self, "indep_arms", None)

        self._reset_segment_state()

    def _reset_segment_state(self):
        """Reinitialise everything that is learned within one segment.

        Shared by ``__init__``, ``reset`` and ``re_init``. Reads
        ``self.alpha``, so it must be called after ``alpha`` has been set.
        """
        ridge = self.r_lambda * self.c_mu

        # Estimator state.
        self.ctr = 0
        self.lazy_update_fr = 5
        self.theta_hat = np.zeros(self.d)
        self.theta_tilde = np.zeros(self.d)
        self.V = self.r_lambda * np.identity(self.d)
        self.inv_V = self.inv_r_lambda * np.identity(self.d)
        self.H = ridge * np.eye(self.d)
        self.arms = []
        self.rewards = []

        # Per-arm reward bookkeeping used by the change detector.
        self.SUMS = {i: [] for i in range(self.num_actions)}
        self.TotalNumber = {i: 0 for i in range(self.num_actions)}
        self.TotalSum = {i: 0 for i in range(self.num_actions)}

        self.forced_exploration = False
        self.explor_freq = int(np.ceil(self.num_actions / self.alpha))
        self.chosen_arm = 0
        self.is_change = False

    def select_arm(self, arms, changepoints):
        """Return the index of the arm to play this round.

        Parameters
        ----------
        arms
            Collection of arm feature vectors; converted with ``np.array``
            and stored as ``self.all_arms``. Assumed to be 2-D with one row
            per arm -- TODO confirm against the caller.
        changepoints
            Stored as ``self.my_changepoints``; not otherwise used here.

        Returns
        -------
        int
            Row index into ``arms``.
        """
        self.my_changepoints = changepoints
        self.all_arms = np.array(arms)

        if self.indep_arms is None:
            self.get_indep_arms()

        # The exploration rate depends on the number of detected changes
        # (through k), so the forced-exploration period is refreshed each round.
        self.alpha = 0.001 * np.sqrt(self.k * self.N_e * np.log(self.T) / self.T)
        self.explor_freq = int(np.ceil(self.N_e / self.alpha))

        beta_t = self.c_1 + self.R * np.sqrt(self.c_2 + self.d * np.log(1 + self.t / (self.r_lambda * self.d)))

        # Predicted mean reward of every arm under the projected estimate.
        dot_products = np.dot(self.all_arms, self.theta_tilde)
        sigmoid_values = sigmoid(dot_products)

        # x^T V^{-1} x for every arm row, without forming the full n x n product.
        arms_inv_V = self.all_arms @ self.inv_V
        quadratic_forms = np.sum(self.all_arms * arms_inv_V, axis=1)
        sqrt_terms = np.sqrt(quadratic_forms * self.noise_variance)
        ucb_s = sigmoid_values + self.const_k_mu_over_c_mu * beta_t * sqrt_terms

        # lexsort uses its last key as the primary one: sort by UCB and break
        # ties with a random secondary key; the final index is the maximiser.
        # The random draw happens every round, including forced ones.
        mixer = np.random.random(ucb_s.size)
        ucb_indices = np.lexsort((mixer, ucb_s))
        chosen_arm = ucb_indices[-1]

        # The first N_e rounds of every explor_freq-long period (counted from
        # the last restart tau) play the arms in self.indep_arms in order.
        phase = int((self.t - self.tau) % self.explor_freq)
        if phase < self.N_e:
            target_vec = self.indep_arms[phase]
            # Compare against the normalised array so that any sequence type
            # passed as `arms` behaves the same.
            matches = (self.all_arms == target_vec).all(axis=1)
            chosen_arm = int(np.flatnonzero(matches)[0])
            # NOTE(review): this flag is only cleared by reset()/re_init(),
            # never on a regular UCB round -- confirm that is intended.
            self.forced_exploration = True
        return chosen_arm

    def update_statistics(self, x, y):
        """Incorporate the reward ``y`` obtained for arm index ``x``.

        Updates the design matrix and its inverse, refreshes the parameter
        estimate, appends to the played arm's cumulative reward sums and runs
        the change detector on that arm. If a change is flagged, the current
        ``self.t`` is recorded, ``k`` is incremented and ``reset`` is called.
        """
        self.chosen_arm = x
        feature = self.all_arms[x]
        self.arms.append(feature)
        self.rewards.append(y)

        column = feature.reshape(-1, 1)
        self.V += column @ column.T
        self.update_inv_V(column)

        self.estimator()

        arm = self.chosen_arm
        self.TotalNumber[arm] += 1
        self.TotalSum[arm] += y
        self.SUMS[arm].append(self.TotalSum[arm])

        # The detector uses its own confidence level 1/sqrt(T), not self.delta.
        delta = 1 / np.sqrt(self.T)

        # nb is at least 1 here because the count was just incremented.
        nb = self.TotalNumber[arm]
        sums = np.array(self.SUMS[arm], dtype=np.float64)
        check = change_detection(nb, sums, klSG, delta, self.noise_variance)

        if check > 0:
            self.ChangePoints.append(self.t)
            self.is_change = True
            self.k += 1

        if self.is_change:
            self.reset()

    def update_inv_V(self, x):
        """Apply the Sherman-Morrison rank-one update to ``self.inv_V``.

        ``x`` must be a column vector of shape ``(d, 1)``; ``inv_V`` is
        modified in place so that it stays the inverse of ``V + x x^T``.
        The numerator is written as ``(inv_V x)(inv_V x)^T``, which relies on
        ``inv_V`` being symmetric.
        """
        inv_V_x = self.inv_V @ x
        denominator = 1.0 + (x.T @ inv_V_x)[0, 0]
        numerator = inv_V_x @ inv_V_x.T
        self.inv_V -= numerator / denominator

    def estimator(self):
        """Refresh ``theta_hat`` (lazily) and recompute ``theta_tilde``.

        ``theta_hat`` is re-fitted on every call while fewer than 200 rewards
        have been stored, and afterwards only on every ``lazy_update_fr``-th
        call. The fit runs five Newton steps on the ridge-regularised loss
        whose gradient is
        ``r_lambda * c_mu * theta + sum_i (sigmoid(a_i . theta) - r_i) a_i``
        (the Hessian below is the matching one when ``utils.sigmoid`` is the
        logistic function -- assumed, not visible here).

        ``theta_tilde`` is ``theta_hat`` when its norm is below ``S`` and the
        radial projection onto the ball of radius ``S`` otherwise.
        """
        if self.ctr % self.lazy_update_fr == 0 or len(self.rewards) < 200:
            arms = np.array(self.arms)
            rewards = np.array(self.rewards)
            ridge = self.r_lambda * self.c_mu

            # Warm start from the previous estimate. The steps below update
            # this array in place (it is the same object as self.theta_hat).
            theta_hat = self.theta_hat
            hessian = self.H
            for _ in range(5):
                coeffs = sigmoid(arms @ theta_hat)[:, None]
                y = coeffs - rewards[:, None]
                grad = ridge * theta_hat + np.sum(y * arms, axis=0)
                hessian = arms.T @ (coeffs * (1 - coeffs) * arms) + ridge * np.eye(self.d)
                theta_hat -= np.linalg.solve(hessian, grad)
            self.theta_hat = theta_hat
            self.H = hessian
        self.ctr += 1

        if np.linalg.norm(self.theta_hat) < self.S:
            # No copy is made: theta_tilde refers to the same array as theta_hat.
            self.theta_tilde = self.theta_hat
        else:
            self.theta_tilde = self.projection_new()

    def projection_new(self):
        """Return ``theta_hat`` projected onto the Euclidean ball of radius ``S``.

        If the norm is already at most ``S`` the vector is returned as is;
        otherwise it is rescaled to have norm ``S``.
        """
        norm_theta = np.linalg.norm(self.theta_hat)
        if norm_theta <= self.S:
            return self.theta_hat
        else:
            return (self.theta_hat / norm_theta) * self.S

    def reset(self):
        """Restart learning after a detected change.

        Clears all per-segment state and marks the current ``self.t`` as the
        start of the new segment. ``k``, ``ChangePoints``, ``alpha`` and
        ``indep_arms`` are kept.
        """
        self._reset_segment_state()
        self.tau = self.t

    def re_init(self):
        """Return the learner to a fresh state for a new run.

        Calls the base-class ``re_init``, clears the per-segment state and
        additionally resets ``tau``, ``k`` and ``indep_arms``.
        ``ChangePoints`` and ``alpha`` are not touched here.
        """
        super().re_init()
        self._reset_segment_state()
        self.tau = 0
        self.k = 1
        self.indep_arms = None

    def __str__(self):
        # NOTE(review): the label does not match the class name; left as is
        # because callers may key results on this string.
        return 'LB-StaticUCB'
