import numpy as np
from math import log
from numpy.linalg import pinv, norm, solve
from src.utils import sigmoid
from src.core.bandit import BanditAlgorithm

class GLB_SWGLUCB(BanditAlgorithm):
    """Sliding-window GLM-UCB bandit with a sigmoid link.

    Keeps only the most recent ``w`` (arm, reward) pairs, fits a
    ridge-regularised logistic-style MLE on them with Newton's method, and
    plays the arm maximising::

        sigmoid(x @ theta_tilde) + (2 * k_mu / c_mu) * zeta_t * sqrt(x^T inv_V x)

    The Newton step uses Hessian weights ``p * (1 - p)``, which is the
    derivative of the logistic function; this assumes ``sigmoid`` from
    ``src.utils`` is the logistic function (TODO confirm).

    Args:
        num_actions: Forwarded to ``BanditAlgorithm``.
        horizon: Forwarded to ``BanditAlgorithm``. ``_auto_tune`` reads the
            horizon back as ``self.T``, which is assumed to be set by the base
            class.
        d: Feature dimension.
        delta: Confidence parameter; enters ``zeta_t`` via ``2 * log(1/delta)``.
        r_lambda: Ridge regularisation strength.
        S: Norm bound; the estimate is projected onto the radius-``S`` ball.
        L: Feature-norm constant used in ``zeta_t``.
        R: Noise-scale constant multiplying the log term of ``zeta_t``.
        k_mu, c_mu: Link-function constants (presumably upper/lower bounds on
            the link derivative). They scale the exploration bonus, and
            ``c_mu`` also scales the regulariser.
        w: Window length. If ``None``, it is auto-tuned from the ``pt`` keyword
            on the first ``select_arm`` call.
    """

    # Smallest window length ``_auto_tune`` may choose.
    _MIN_WINDOW = 20
    # Newton's method budget and step-norm convergence tolerance.
    _NEWTON_ITERS = 5
    _NEWTON_TOL = 1e-4
    # While the window holds at most this many samples the estimator refits on
    # every update; beyond it, only every ``lazy_update_fr``-th update.
    _EAGER_REFIT_SAMPLES = 50

    def __init__(self, num_actions, horizon, d, delta, r_lambda, S, L, R, k_mu, c_mu, w=None):
        super().__init__(num_actions, horizon)
        self.d = d
        self.delta = delta
        self.r_lambda = r_lambda
        self.S = S
        self.L = L
        self.R = R
        self.k_mu = k_mu
        self.c_mu = c_mu

        self.w = int(w) if w is not None else None
        # True only while ``w`` holds a value chosen by ``_auto_tune`` rather
        # than one supplied by the caller; ``re_init`` uses it to decide
        # whether the window must be re-tuned for the next run.
        self._w_auto_tuned = False

        # Window-independent parts of the exploration radius ``zeta_t``.
        self.c_1 = np.sqrt(self.r_lambda) * self.c_mu * self.S
        self.c_2 = 2 * log(1 / self.delta)

        self.lazy_update_fr = 10
        self.ctr = 0  # number of estimator() calls, drives lazy refitting

        self.theta_hat = np.zeros(self.d)
        self.theta_tilde = np.zeros(self.d)
        self.inv_V = (1.0 / self.r_lambda) * np.eye(self.d)

        self.arms_history = []
        self.rewards_history = []
        self.last_arms = None
        self.params_set = self.w is not None
        self.zeta_t = None

    def _auto_tune(self, P_T):
        """Choose the window length from ``P_T`` and recompute ``zeta_t``.

        ``P_T`` is presumably the number of breakpoints (or variation budget)
        of the environment. The rule ``w = (sqrt(d) * T / P_T) ** (2/3)`` is
        clipped to ``[_MIN_WINDOW, T]``. A window longer than the horizon can
        never be filled, and oversizing it would only inflate ``zeta_t``
        through ``log(w)``.
        """
        val = (np.sqrt(self.d) * self.T) / (P_T + 1e-9)
        raw_w = float(val ** (2.0 / 3.0))
        self.w = int(max(self._MIN_WINDOW, min(raw_w, self.T)))
        self._w_auto_tuned = True
        self.params_set = True
        self._calc_zeta()

    def _calc_zeta(self):
        """Set ``zeta_t = c_1 + R * sqrt(c_2 + d * log(1 + w L^2 / (r_lambda d)))``."""
        val = 1 + (self.w * self.L**2) / (self.r_lambda * self.d)
        self.zeta_t = self.c_1 + self.R * np.sqrt(self.c_2 + self.d * log(val))

    def select_arm(self, arms, pt=None, **kwargs):
        """Return the index of the arm with the largest UCB score.

        Args:
            arms: Sequence of feature vectors, one per arm. It is stacked into
                a 2-D array and remembered for ``update_statistics``.
            pt: Forwarded to ``_auto_tune`` when no window is set yet; treated
                as 0 if omitted.

        Returns:
            Index of a maximising arm. Scores within 1e-9 of the maximum are
            tied and broken uniformly at random.
        """
        self.last_arms = arms
        if not self.params_set:
            self._auto_tune(pt if pt is not None else 0.0)
        if self.zeta_t is None:
            self._calc_zeta()

        arms_arr = np.array(arms)
        # x^T inv_V x per arm; clamp round-off negatives before the sqrt.
        quad = np.maximum(np.sum((arms_arr @ self.inv_V) * arms_arr, axis=1), 0.0)
        bonus = (2 * self.k_mu / self.c_mu) * self.zeta_t * np.sqrt(quad)
        ucb_s = sigmoid(arms_arr @ self.theta_tilde) + bonus

        candidates = np.flatnonzero(ucb_s >= np.max(ucb_s) - 1e-9)
        return int(np.random.choice(candidates))

    def update_statistics(self, arm, reward, **kwargs):
        """Record ``reward`` for ``arm``, slide the window, and refit.

        Args:
            arm: Index into the ``arms`` passed to the last ``select_arm``.
            reward: Observed reward for that arm.
        """
        # Copy so later in-place changes to the caller's arms cannot corrupt
        # the stored history.
        x = np.array(self.last_arms[arm], dtype=float)
        self.arms_history.append(x)
        self.rewards_history.append(reward)

        if self.w is not None:
            excess = len(self.arms_history) - self.w
            if excess > 0:
                del self.arms_history[:excess]
                del self.rewards_history[:excess]

        self.estimator()
        self.t += 1

    def estimator(self):
        """Refit the regularised MLE on the current window with Newton's method.

        The refit is skipped while fewer than ``d`` samples are held. Once more
        than ``_EAGER_REFIT_SAMPLES`` samples are held, it is also skipped on
        all but every ``lazy_update_fr``-th call.

        A refit updates ``theta_hat`` and its projection ``theta_tilde`` onto
        the radius-``S`` ball. It also sets ``inv_V`` to the pseudo-inverse of
        the last regularised Hessian used for a successful Newton step.
        """
        self.ctr += 1
        n_samples = len(self.rewards_history)
        if n_samples < self.d:
            return
        if n_samples > self._EAGER_REFIT_SAMPLES and self.ctr % self.lazy_update_fr != 0:
            return

        X = np.array(self.arms_history)
        Y = np.array(self.rewards_history)
        reg = self.r_lambda * self.c_mu
        reg_eye = reg * np.eye(self.d)

        theta = self.theta_hat.copy()  # warm start from the previous fit
        hessian = None
        for _ in range(self._NEWTON_ITERS):
            preds = sigmoid(X @ theta)
            grad = reg * theta + X.T @ (preds - Y)
            weights = preds * (1 - preds)
            H = (X.T * weights) @ X + reg_eye
            try:
                step = solve(H, grad)
            except np.linalg.LinAlgError:
                break
            hessian = H
            theta -= step
            if norm(step) < self._NEWTON_TOL:
                break

        # Refresh the confidence matrix even if Newton hit its iteration
        # budget, so inv_V stays consistent with the theta_hat stored below.
        if hessian is not None:
            self.inv_V = pinv(hessian)

        self.theta_hat = theta
        n_theta = norm(theta)
        if n_theta > self.S:
            self.theta_tilde = theta * (self.S / n_theta)
        else:
            # Copy so theta_tilde never aliases theta_hat.
            self.theta_tilde = theta.copy()

    def re_init(self):
        """Reset all learned state for a fresh run.

        A caller-supplied window is kept. An auto-tuned window is discarded so
        that the next ``select_arm`` call re-tunes it from that run's ``pt``.
        """
        super().re_init()
        self.t = 0
        self.ctr = 0
        self.theta_hat = np.zeros(self.d)
        self.theta_tilde = np.zeros(self.d)
        self.inv_V = (1.0 / self.r_lambda) * np.eye(self.d)
        self.arms_history = []
        self.rewards_history = []
        self.last_arms = None
        self.zeta_t = None
        if self._w_auto_tuned:
            self.w = None
            self.params_set = False
            self._w_auto_tuned = False

    def __str__(self):
        w_str = str(self.w) if self.w is not None else "Auto"
        return f'GLB-SWGLUCB(w={w_str})'
