import numpy as np
from math import log
from numpy.linalg import pinv, norm, solve
from src.utils import sigmoid
from src.core.bandit import BanditAlgorithm

class GLB_dGLUCB(BanditAlgorithm):
    """Discounted UCB policy for a generalized linear bandit.

    The policy keeps an exponentially discounted design matrix ``V`` and a
    discounted, L2-regularised maximum-likelihood estimate ``theta_hat``
    (fitted with a few Newton steps, using the project's ``sigmoid`` helper
    as the link function).  Arms are scored by the predicted reward plus an
    exploration bonus proportional to ``sqrt(x^T V^{-1} x)``.

    If ``gamma`` is not supplied it is derived on the first call to
    ``select_arm`` from the ``pt`` argument (see ``_auto_tune``).
    """

    def __init__(self, num_actions, horizon, d, delta, r_lambda, S, L, R, k_mu, c_mu, gamma=None):
        """Store the problem constants and initialise the running statistics.

        Args:
            num_actions: Forwarded to ``BanditAlgorithm.__init__``.
            horizon: Forwarded to ``BanditAlgorithm.__init__``.
            d: Dimension of the arm feature vectors / parameter vector.
            delta: Confidence parameter; enters the bonus as ``2*log(1/delta)``.
            r_lambda: Regularisation strength (``V`` starts at ``r_lambda * I``).
            S: Norm bound; ``theta_hat`` is rescaled onto the ball of radius
                ``S`` to obtain ``theta_tilde``.
            L: Enters the confidence width as ``L**2`` (presumably a bound on
                the arm norm -- confirm against the caller).
            R: Multiplier of the square-root confidence term (presumably the
                noise scale -- confirm against the caller).
            k_mu: Numerator of the bonus coefficient ``2 * k_mu / c_mu``.
            c_mu: Denominator of the bonus coefficient; also scales the
                regulariser used in the estimator.
            gamma: Discount factor.  ``None`` means "tune automatically".
        """
        super().__init__(num_actions, horizon)
        self.d = d
        self.delta = delta
        self.r_lambda = r_lambda
        self.S = S
        self.L = L
        self.R = R
        self.k_mu = k_mu
        self.c_mu = c_mu

        self.gamma = float(gamma) if gamma is not None else None

        # Constant pieces of the confidence radius beta_t (see select_arm).
        self.c_1 = np.sqrt(self.r_lambda) * self.c_mu * self.S
        self.c_2 = 2 * log(1 / self.delta)

        self.t = 0
        # Once more than 50 samples are stored, the estimator is refitted
        # only on every `lazy_update_fr`-th call (see estimator()).
        self.lazy_update_fr = 10
        self.ctr = 0

        self.theta_hat = np.zeros(self.d)
        self.theta_tilde = np.zeros(self.d)

        self.V = self.r_lambda * np.eye(self.d)
        self.inv_V = (1.0 / self.r_lambda) * np.eye(self.d)

        self.arms_history = []
        self.rewards_history = []
        self.last_arms = None
        # True once gamma is final, either user-supplied or auto-tuned.
        self.params_set = self.gamma is not None
        self.weights = None

    def _auto_tune(self, P_T):
        """Derive ``gamma`` from ``P_T`` and mark the parameters as set.

        ``gamma = 1 - (P_T / (sqrt(d) * T)) ** (2/5)``, clamped to
        ``[0.5, 0.9999]``.

        NOTE(review): ``self.T`` is not assigned anywhere in this class; it is
        presumably provided by ``BanditAlgorithm`` (from ``horizon``) -- confirm.
        """
        # A negative P_T would make the fractional power NaN (with a NumPy
        # RuntimeWarning).  Clamping at zero gives the same final gamma
        # (the 0.9999 upper clamp) without relying on NaN comparison order.
        budget = max(P_T, 0.0)
        val = budget / (np.sqrt(self.d) * self.T)
        self.gamma = float(1.0 - (val ** (2.0 / 5.0)))
        self.gamma = max(0.5, min(0.9999, self.gamma))
        self.params_set = True

    def select_arm(self, arms, pt=None, **kwargs):
        """Return the index of the arm with the highest upper confidence bound.

        Args:
            arms: Sequence of feature vectors, one per arm.  It is remembered
                in ``self.last_arms`` so ``update_statistics`` can look up the
                played arm.
            pt: Value handed to ``_auto_tune`` when ``gamma`` still has to be
                tuned (``0.0`` is used if omitted).
            **kwargs: Ignored.

        Returns:
            int: Index into ``arms``; ties (within ``1e-9`` of the maximum)
            are broken uniformly at random via ``np.random.choice``.
        """
        self.last_arms = arms
        if not self.params_set:
            self._auto_tune(pt if pt is not None else 0.0)

        # The 1e-9 keeps the denominator strictly positive.
        log_term = np.log(1 + self.L**2 / (self.d * self.r_lambda * (1 - self.gamma**2 + 1e-9)))
        beta_t = self.c_1 + self.R * np.sqrt(self.c_2 + self.d * log_term)
        coeff = (2 * self.k_mu / self.c_mu)

        # asarray avoids a copy when the caller already passes an ndarray.
        arms_arr = np.asarray(arms)
        # Row-wise quadratic form x^T V^{-1} x for every arm.
        quad = np.sum((arms_arr @ self.inv_V) * arms_arr, axis=1)
        # Guard against tiny negative values from round-off before the sqrt.
        quad = np.maximum(quad, 0.0)

        pred = sigmoid(arms_arr @ self.theta_tilde)
        ucb_s = pred + coeff * beta_t * np.sqrt(quad)

        max_val = np.max(ucb_s)
        candidates = np.where(ucb_s >= max_val - 1e-9)[0]
        return int(np.random.choice(candidates))

    def update_statistics(self, arm, reward, **kwargs):
        """Record the outcome of playing ``arm`` and refresh the statistics.

        Args:
            arm: Index into the arm set passed to the latest ``select_arm``.
            reward: Observed reward for that arm.
            **kwargs: Ignored.

        Raises:
            RuntimeError: If no arm set is available because ``select_arm``
                has not been called since construction / ``re_init``.
        """
        if self.last_arms is None:
            raise RuntimeError(
                "update_statistics() called before select_arm(): no arm set available"
            )

        x = self.last_arms[arm]
        self.arms_history.append(x)
        self.rewards_history.append(reward)

        # Fallback discount when statistics are updated before gamma was
        # tuned; params_set stays False, so select_arm may still auto-tune.
        if self.gamma is None:
            self.gamma = 0.99

        aat = np.outer(x, x)
        # Discount the old matrix, add the new sample, and top the
        # regulariser back up so it stays at r_lambda * I overall.
        self.V = self.gamma * self.V + aat + (1 - self.gamma) * self.r_lambda * np.eye(self.d)
        self.inv_V = pinv(self.V)

        self.estimator()
        self.t += 1

    def estimator(self):
        """Refit ``theta_hat`` / ``theta_tilde`` on the discounted history.

        Every call extends the per-sample weights (older weights are
        multiplied by ``gamma``, the newest sample gets weight 1).  The fit
        itself is skipped while fewer than ``d`` samples exist, and, once more
        than 50 samples exist, on all but every ``lazy_update_fr``-th call.

        The fit runs at most five Newton steps on the weighted, regularised
        loss, warm-started from the current ``theta_hat``.
        """
        self.ctr += 1
        N = len(self.rewards_history)

        if self.weights is None:
            self.weights = np.array([1.0])
        else:
            self.weights = np.append(self.weights * self.gamma, 1.0)

        if N < self.d or (self.ctr % self.lazy_update_fr != 0 and N > 50):
            return

        theta = self.theta_hat.copy()
        X = np.array(self.arms_history)
        Y = np.array(self.rewards_history)
        W = self.weights
        lambda_eye = self.r_lambda * self.c_mu * np.eye(self.d)

        for _ in range(5):
            scores = X @ theta
            preds = sigmoid(scores)

            # Gradient of the regularised, weighted loss.
            grad = self.r_lambda * self.c_mu * theta + (X.T * W) @ (preds - Y)

            # Hessian: weighted Gram matrix scaled by preds * (1 - preds).
            sig_prime = preds * (1 - preds)
            H = (X.T * (W * sig_prime)) @ X + lambda_eye

            # Only the linear solve can raise LinAlgError; on a singular
            # Hessian keep the best iterate found so far.
            try:
                step = solve(H, grad)
            except np.linalg.LinAlgError:
                break
            theta -= step
            if norm(step) < 1e-4:
                break

        self.theta_hat = theta

        # Rescale onto the ball of radius S when the estimate is too large.
        n_theta = norm(self.theta_hat)
        if n_theta > self.S:
            self.theta_tilde = self.theta_hat * (self.S / n_theta)
        else:
            self.theta_tilde = self.theta_hat

    def re_init(self):
        """Reset all running statistics so the policy can start a new run.

        A user-supplied or already auto-tuned ``gamma`` is kept; ``gamma`` is
        cleared only when it was never finalised (``params_set`` is False).
        """
        super().re_init()
        self.t = 0
        self.ctr = 0
        # Allocate fresh vectors instead of zeroing in place: theta_tilde may
        # be the same object as theta_hat, and callers may still hold a
        # reference to the previous run's estimate.
        self.theta_hat = np.zeros(self.d)
        self.theta_tilde = np.zeros(self.d)
        self.V = self.r_lambda * np.eye(self.d)
        self.inv_V = (1.0 / self.r_lambda) * np.eye(self.d)
        self.arms_history = []
        self.rewards_history = []
        self.weights = None
        self.last_arms = None
        if not self.params_set:
            self.gamma = None

    def __str__(self):
        """Return a short label that includes ``gamma`` (or ``Auto`` if unset)."""
        g_str = f"{self.gamma:.4f}" if self.gamma is not None else "Auto"
        return f'GLB-dGLUCB(g={g_str})'