import numpy as np
from math import log
from numpy.linalg import pinv, norm, solve
from src.utils import sigmoid, dsigmoid
from src.core.bandit import BanditAlgorithm
from scipy.optimize import minimize, NonlinearConstraint
from scipy.linalg import fractional_matrix_power

class GLB_BVD_GLM_UCB(BanditAlgorithm):
    """Discounted GLM-UCB style bandit with a sigmoid link.

    Past observations are down-weighted geometrically by ``gamma``. The
    parameter estimate is a weighted, L2-regularised logistic fit obtained
    with a few Newton steps; when that estimate's norm is not below ``S`` it
    is replaced by the solution of a constrained projection problem. Arms are
    chosen by an upper confidence bound built from the discounted design
    matrix.
    """

    def __init__(self, num_actions, horizon, d, delta, r_lambda, S, L, R, k_mu, c_mu, gamma=None):
        """Store hyper-parameters and initialise all running statistics.

        Args:
            num_actions: Forwarded to ``BanditAlgorithm.__init__``.
            horizon: Forwarded to ``BanditAlgorithm.__init__``.
            d: Dimension of the arm feature vectors / parameter vector.
            delta: Enters the confidence width through ``2 * log(1 / delta)``.
            r_lambda: Regularisation level; the design matrices start at
                ``r_lambda * I``.
            S: Norm threshold for the parameter estimate (see ``estimator``).
            L: Enters the log term of the confidence width (squared).
            R: Multiplies the square-root part of the confidence width.
            k_mu: Numerator of the exploration-bonus coefficient
                ``2 * k_mu / c_mu``.
            c_mu: Scales the regulariser (``r_lambda * c_mu``) and divides the
                exploration-bonus coefficient.
            gamma: Discount factor. When ``None`` it is chosen on the first
                call to ``select_arm`` via ``_auto_tune``.
        """
        super().__init__(num_actions, horizon)
        self.d = d
        self.delta = delta
        self.r_lambda = r_lambda
        self.S = S
        self.L = L
        self.R = R
        self.k_mu = k_mu
        self.c_mu = c_mu

        self.gamma = float(gamma) if gamma is not None else None

        # Constant pieces of the confidence width ``beta``.
        self.c_1 = np.sqrt(self.r_lambda) * self.c_mu * self.S
        self.c_2 = 2 * log(1 / self.delta)

        self.t = 0
        self.ctr = 0
        # After more than 50 observations the estimator is only refit on
        # every ``lazy_update_fr``-th call.
        self.lazy_update_fr = 10

        self.theta_hat = np.zeros(self.d)
        self.theta_tilde = np.zeros(self.d)

        # V is discounted by gamma, tilde_V by gamma**2 (see update_statistics).
        self.V = self.r_lambda * np.eye(self.d)
        self.inv_V = (1.0 / self.r_lambda) * np.eye(self.d)
        self.tilde_V = self.r_lambda * np.eye(self.d)

        self.arms_history = []
        self.rewards_history = []
        self.last_arms = None

        # Running product of gamma**2, one factor per update.
        self.gamma_2t = 1.0
        self.weights = None
        self.beta = 0.0
        self.params_set = self.gamma is not None

    def _auto_tune(self, P_T):
        """Derive ``gamma`` from ``P_T`` and clamp it to ``[0.5, 0.9999]``.

        Args:
            P_T: Scalar plugged into ``P_T / (sqrt(d) * T)``.
        """
        # NOTE(review): ``self.T`` is not assigned anywhere in this class;
        # presumably BanditAlgorithm sets it from ``horizon`` - confirm.
        val = P_T / (np.sqrt(self.d) * self.T)
        self.gamma = float(1.0 - (val ** (2.0 / 5.0)))
        self.gamma = max(0.5, min(0.9999, self.gamma))
        self.params_set = True

    def select_arm(self, arms, pt=None, **kwargs):
        """Pick the arm with the largest upper confidence bound.

        Args:
            arms: Sequence of feature vectors, one per arm. Remembered in
                ``self.last_arms`` for the following ``update_statistics``.
            pt: Value handed to ``_auto_tune`` when no ``gamma`` has been set
                yet; ``0.0`` is used when it is ``None``.
            **kwargs: Ignored.

        Returns:
            int: Index of the selected arm. Arms whose score is within
            ``1e-9`` of the maximum are tie-broken uniformly at random.
        """
        self.last_arms = arms
        if not self.params_set:
            self._auto_tune(pt if pt is not None else 0.0)

        # The 1e-9 keeps the denominator non-zero as gamma approaches 1.
        term_log = 1 + (1 - self.gamma_2t) * (self.L ** 2) / (
            self.d * self.r_lambda * (1 - self.gamma ** 2 + 1e-9)
        )
        self.beta = self.c_1 + self.R * np.sqrt(self.c_2 + self.d * np.log(term_log))
        coeff = 2 * self.k_mu / self.c_mu

        arms_arr = np.array(arms)
        # Row-wise quadratic form x^T inv_V x, clipped at zero before the
        # square root is taken.
        quad = np.sum((arms_arr @ self.inv_V) * arms_arr, axis=1)
        quad = np.maximum(quad, 0.0)

        pred = sigmoid(arms_arr @ self.theta_tilde)
        ucb_s = pred + coeff * self.beta * np.sqrt(quad)

        max_val = np.max(ucb_s)
        candidates = np.where(ucb_s >= max_val - 1e-9)[0]
        return int(np.random.choice(candidates))

    def update_statistics(self, arm, reward, **kwargs):
        """Record one observation and refresh matrices and estimates.

        Args:
            arm: Index into the arm set given to the last ``select_arm`` call.
            reward: Observed reward for that arm.
            **kwargs: Ignored.
        """
        x = self.last_arms[arm]
        self.arms_history.append(x)
        self.rewards_history.append(reward)

        # Fallback for a missing gamma; ``select_arm`` sets it whenever it
        # runs first. ``params_set`` is left untouched here.
        if self.gamma is None:
            self.gamma = 0.99

        aat = np.outer(x, x)
        self.V = self.gamma * self.V + aat + (1 - self.gamma) * self.r_lambda * np.eye(self.d)
        self.inv_V = pinv(self.V)

        self.tilde_V = (
            (self.gamma ** 2) * self.tilde_V
            + aat
            + (1 - self.gamma ** 2) * self.r_lambda * np.eye(self.d)
        )

        self.gamma_2t *= self.gamma ** 2
        self.estimator()
        self.t += 1

    def estimator(self):
        """Update observation weights and, when due, refit ``theta_hat``.

        The weight vector is always extended (older weights multiplied by
        ``gamma``, newest weight ``1.0``). The fit itself is skipped while
        fewer than ``d`` observations exist and, once more than 50 exist, on
        every call whose counter is not a multiple of ``lazy_update_fr``.

        The fit runs at most five Newton steps on the weighted logistic loss
        with ridge term ``r_lambda * c_mu``, warm-started at the previous
        ``theta_hat``. ``theta_tilde`` is ``theta_hat`` itself when its norm
        is below ``S``; otherwise it is the first ``d`` entries returned by
        ``projection``.
        """
        self.ctr += 1
        N = len(self.rewards_history)

        if self.weights is None:
            self.weights = np.array([1.0])
        else:
            self.weights = np.append(self.weights * self.gamma, 1.0)

        if N < self.d or (self.ctr % self.lazy_update_fr != 0 and N > 50):
            return

        theta = self.theta_hat.copy()
        X = np.array(self.arms_history)
        Y = np.array(self.rewards_history)
        W = self.weights
        lambda_eye = self.r_lambda * self.c_mu * np.eye(self.d)

        for _ in range(5):
            scores = X @ theta
            preds = sigmoid(scores)
            grad = self.r_lambda * self.c_mu * theta + (X.T * W) @ (preds - Y)
            sig_prime = preds * (1 - preds)
            H = (X.T * (W * sig_prime)) @ X + lambda_eye

            # Only the linear solve can raise LinAlgError; on failure keep
            # the iterate reached so far.
            try:
                delta = solve(H, grad)
            except np.linalg.LinAlgError:
                break
            theta -= delta
            if norm(delta) < 1e-4:
                break

        self.theta_hat = theta
        if norm(self.theta_hat) < self.S:
            self.theta_tilde = self.theta_hat
        else:
            self.theta_tilde = self.projection(self.theta_hat, self.d)[0:self.d]

    def gt_vector(self, theta, arms_arr):
        """Return ``sum_i w_i * sigmoid(x_i . theta) * x_i + r_lambda * c_mu * theta``.

        Args:
            theta: Parameter vector.
            arms_arr: Array of played arms, one row per observation; must
                have as many rows as ``self.weights`` has entries.

        Returns:
            The vector above; only the ridge term when ``arms_arr`` is empty.
        """
        if len(arms_arr) == 0:
            return self.r_lambda * self.c_mu * theta
        preds = sigmoid(arms_arr @ theta)
        return np.dot(self.weights * preds, arms_arr) + self.r_lambda * self.c_mu * theta

    def proj_fun(self, theta_eta, arms_arr, d, Vt=None, inv_V_sq=None, g_hat=None):
        """Objective minimised by ``projection``.

        With ``diff = g(theta) + beta * tilde_V^{1/2} @ eta - g(theta_hat)``
        (``g`` being ``gt_vector``), returns ``diff^T @ inv_V @ inv_V @ diff``.

        Args:
            theta_eta: Concatenation of ``theta`` (first ``d`` entries) and
                ``eta`` (remaining entries).
            arms_arr: Array of played arms passed on to ``gt_vector``.
            d: Length of the ``theta`` part.
            Vt: Optional precomputed ``tilde_V ** 0.5``.
            inv_V_sq: Optional precomputed ``inv_V @ inv_V``.
            g_hat: Optional precomputed ``gt_vector(theta_hat, arms_arr)``.
                These three do not depend on ``theta_eta``; each one left as
                ``None`` is computed here from the current state.

        Returns:
            The scalar objective value.
        """
        theta = theta_eta[0:d]
        eta = theta_eta[d:]

        if Vt is None:
            Vt = fractional_matrix_power(self.tilde_V, 0.5)
        if inv_V_sq is None:
            inv_V_sq = self.inv_V @ self.inv_V
        if g_hat is None:
            g_hat = self.gt_vector(self.theta_hat, arms_arr)

        g_theta = self.gt_vector(theta, arms_arr)

        diff = g_theta + self.beta * np.dot(Vt, eta) - g_hat
        return np.dot(diff, np.dot(inv_V_sq, diff))

    def projection(self, theta_init, d):
        """Minimise ``proj_fun`` under norm constraints with SLSQP.

        Constraints: ``||theta|| <= S`` on the first ``d`` coordinates and
        ``||eta|| <= 1`` on the rest. The start point is ``theta_init``
        followed by ``d`` zeros.

        Args:
            theta_init: Initial value for the ``theta`` part.
            d: Length of the ``theta`` part (and of the ``eta`` part).

        Returns:
            The optimiser's ``x`` of length ``2 * d``. The solver's success
            flag is not inspected.
        """
        arms_arr = np.array(self.arms_history)

        # None of these depend on the optimisation variable, so compute them
        # once instead of on every objective evaluation.
        Vt = fractional_matrix_power(self.tilde_V, 0.5)
        inv_V_sq = self.inv_V @ self.inv_V
        g_hat = self.gt_vector(self.theta_hat, arms_arr)

        def objective(t):
            return self.proj_fun(t, arms_arr, d, Vt=Vt, inv_V_sq=inv_V_sq, g_hat=g_hat)

        def theta_norm(t):
            return np.linalg.norm(t[0:d])

        def eta_norm(t):
            return np.linalg.norm(t[d:])

        constraint1 = NonlinearConstraint(theta_norm, 0, self.S)
        constraint2 = NonlinearConstraint(eta_norm, 0, 1.0)

        x0 = np.zeros(2 * d)
        x0[0:d] = theta_init

        res = minimize(objective, x0=x0, method='SLSQP', constraints=[constraint1, constraint2])
        return res.x

    def re_init(self):
        """Reset counters, estimates, matrices and history for a new run."""
        super().re_init()
        self.t = 0
        self.ctr = 0
        self.theta_hat.fill(0.0)
        self.theta_tilde.fill(0.0)
        self.V = self.r_lambda * np.eye(self.d)
        self.inv_V = (1.0 / self.r_lambda) * np.eye(self.d)
        self.tilde_V = self.r_lambda * np.eye(self.d)

        self.arms_history = []
        self.rewards_history = []
        self.gamma_2t = 1.0
        self.weights = None
        self.last_arms = None
        # NOTE(review): once ``_auto_tune`` has run, ``params_set`` is True,
        # so an auto-tuned gamma is kept across re-initialisations - confirm
        # this is intended.
        if not self.params_set:
            self.gamma = None

    def __str__(self):
        """Return a short label including gamma, or ``Auto`` when unset."""
        g_str = f"{self.gamma:.4f}" if self.gamma is not None else "Auto"
        return f'GLB-BVD-GLM-UCB(g={g_str})'