import numpy as np
from math import log
from numpy.linalg import pinv, norm, solve
from src.utils import sigmoid, dsigmoid
from scipy.optimize import minimize, NonlinearConstraint
from src.core.bandit import BanditAlgorithm

class SCB_dGLUCB(BanditAlgorithm):
    """Discounted UCB bandit with a sigmoid (logistic-style) reward model.

    The algorithm keeps an exponentially discounted design matrix ``V`` and a
    discounted, ridge-regularised maximum-likelihood estimate ``theta_hat``
    fitted by a few Newton steps.  Arms are scored with
    ``sigmoid(x @ theta_tilde)`` plus an exploration bonus proportional to
    ``sqrt(x^T V^{-1} x)``, where ``theta_tilde`` is ``theta_hat`` projected
    onto the ball of radius ``S`` when it falls outside of it.

    Parameters
    ----------
    num_actions, horizon
        Forwarded unchanged to ``BanditAlgorithm.__init__``.
    d
        Dimension of the arm feature vectors / parameter vector.
    delta
        Confidence parameter; enters the bonus through ``log(1 / delta)``.
    r_lambda
        Ridge regularisation strength.
    S
        Radius of the admissible parameter ball.
    L, R
        Stored on the instance; not used by any method in this class.
    k_mu, c_mu
        Constants of the link function used in the confidence width
        (presumably upper/lower bounds on its derivative -- confirm with
        the caller).
    gamma
        Discount factor.  When ``None`` it is chosen by ``_auto_tune`` on the
        first call to ``select_arm``.
    """

    def __init__(self, num_actions, horizon, d, delta, r_lambda, S, L, R, k_mu, c_mu, gamma=None):
        super().__init__(num_actions, horizon)
        self.d = d
        self.delta = delta
        self.r_lambda = r_lambda
        self.S = S
        self.L = L
        self.R = R
        self.k_mu = k_mu
        self.c_mu = c_mu

        # Kept locally so `_auto_tune` still works if the base class does not
        # expose the horizon as `self.T`.
        self._horizon = horizon
        # Remember whether gamma has to be tuned, so `re_init` can re-arm the
        # tuning step even after `self.gamma` has been overwritten.
        self._auto_gamma = gamma is None
        self.gamma = float(gamma) if gamma is not None else None

        # Time-independent parts of the confidence width beta_t.
        self.c_1 = np.sqrt(self.r_lambda * self.c_mu) * (self.S + 0.5)
        self.c_2 = (2 / np.sqrt(self.r_lambda * self.c_mu)) * log(1 / self.delta)

        self.t = 0
        self.ctr = 0
        # After 50 observations the estimator is refitted only every
        # `lazy_update_fr` updates.
        self.lazy_update_fr = 10

        self.theta_hat = np.zeros(self.d)
        self.theta_tilde = np.zeros(self.d)

        self.V = self.r_lambda * np.identity(self.d)
        self.inv_V = (1.0 / self.r_lambda) * np.identity(self.d)

        self.Hessian = self.r_lambda * self.c_mu * np.eye(self.d)

        self.arms_history = []
        self.rewards_history = []
        self.weights = None
        self.last_arms = None
        self.params_set = self.gamma is not None

    def _auto_tune(self, P_T):
        """Choose ``gamma`` from the variation budget ``P_T`` and the horizon.

        Sets ``gamma = 1 - max(1/T, sqrt(P_T / (d * T)))`` clipped to
        ``[0.5, 0.9999]`` and marks the parameters as set.
        """
        # Prefer the base-class horizon attribute when it exists; otherwise
        # use the value given to the constructor.
        horizon = getattr(self, "T", self._horizon)
        val = np.sqrt(P_T / (self.d * horizon))
        gamma = float(1.0 - max(1.0 / horizon, val))
        self.gamma = max(0.5, min(0.9999, gamma))
        self.params_set = True

    def select_arm(self, arms, pt=None, **kwargs):
        """Return the index of the arm with the largest upper confidence bound.

        Parameters
        ----------
        arms
            Sequence of feature vectors, one per candidate arm.  It is stored
            and later indexed by ``update_statistics``.
        pt
            Optional variation budget used for auto-tuning ``gamma`` on the
            first call; treated as ``0.0`` when omitted.
        **kwargs
            Ignored.

        Returns
        -------
        int
            Index into ``arms``; ties (within ``1e-9``) are broken uniformly
            at random.
        """
        self.last_arms = arms
        if not self.params_set:
            self._auto_tune(pt if pt is not None else 0.0)

        # Effective sample size under discounting.  With gamma >= 1 there is
        # no forgetting (and 1 / (1 - gamma) is undefined or negative), so
        # the plain round count is used.
        if self.gamma < 1.0:
            eff_t = min(self.t + 1, 1.0 / (1.0 - self.gamma))
        else:
            eff_t = self.t + 1

        reg = self.r_lambda * self.c_mu
        log_term = np.log(4 * (1 + ((self.k_mu * eff_t) / (reg * self.d))))
        beta_t = self.c_1 + self.c_2 + (self.d / np.sqrt(reg)) * log_term

        bonus_coeff = (np.sqrt(4 + 8 * self.S) * self.k_mu / np.sqrt(self.c_mu))

        arms_arr = np.array(arms)

        # Row-wise x^T V^{-1} x; clipped because round-off in the
        # pseudo-inverse can make it slightly negative.
        quad = np.sum((arms_arr @ self.inv_V) * arms_arr, axis=1)
        quad = np.maximum(quad, 0.0)

        pred = sigmoid(arms_arr @ self.theta_tilde)
        ucb_s = pred + bonus_coeff * beta_t * np.sqrt(quad)

        max_val = np.max(ucb_s)
        candidates = np.where(ucb_s >= max_val - 1e-9)[0]
        return int(np.random.choice(candidates))

    def update_statistics(self, arm, reward, **kwargs):
        """Record the outcome of playing ``arm`` and refresh the estimates.

        ``arm`` indexes the arm set passed to the most recent ``select_arm``
        call.  The discounted design matrix and its pseudo-inverse are
        updated, then ``estimator`` is invoked and the round counter advanced.
        """
        x = self.last_arms[arm]
        self.arms_history.append(x)
        self.rewards_history.append(reward)

        # Fallback when an update arrives before any tuning has happened.
        if self.gamma is None:
            self.gamma = 0.99

        aat = np.outer(x, x)
        self.V = self.gamma * self.V + aat + (1 - self.gamma) * self.r_lambda * np.identity(self.d)
        self.inv_V = pinv(self.V)

        self.estimator()
        self.t += 1

    def estimator(self):
        """Refit ``theta_hat`` by discounted, regularised Newton iterations.

        Always extends the per-sample discount weights.  The fit itself is
        skipped while fewer than ``d`` samples exist and, once more than 50
        samples exist, on all but every ``lazy_update_fr``-th call.  After the
        fit ``theta_tilde`` is set to ``theta_hat`` or to its projection onto
        the ball of radius ``S``.
        """
        self.ctr += 1
        N = len(self.rewards_history)

        # Newest sample gets weight 1, every older one is discounted once more.
        if self.weights is None:
            self.weights = np.array([1.0])
        else:
            self.weights = np.append(self.weights * self.gamma, 1.0)

        if N < self.d or (self.ctr % self.lazy_update_fr != 0 and N > 50):
            return

        theta = self.theta_hat.copy()
        X = np.array(self.arms_history)
        Y = np.array(self.rewards_history)
        W = self.weights
        reg = self.r_lambda * self.c_mu
        lambda_eye = reg * np.eye(self.d)

        def weighted_hessian(preds):
            # Weighted curvature of the loss plus the ridge term.
            sig_prime = preds * (1 - preds)
            return (X.T * (W * sig_prime)) @ X + lambda_eye

        for _ in range(5):
            preds = sigmoid(X @ theta)
            grad = reg * theta + (X.T * W) @ (preds - Y)
            H = weighted_hessian(preds)

            try:
                delta_step = solve(H, grad)
            except np.linalg.LinAlgError:
                break
            theta -= delta_step
            if norm(delta_step) < 1e-4:
                break

        self.theta_hat = theta
        # Store the Hessian evaluated at the final iterate; the matrix used
        # inside the loop belongs to the point before the last step.
        self.Hessian = weighted_hessian(sigmoid(X @ theta))

        if norm(self.theta_hat) < self.S:
            self.theta_tilde = self.theta_hat
        else:
            self.theta_tilde = self.projection(self.theta_hat)

    def gt(self, theta, arms):
        """Return ``sum_i w_i * sigmoid(x_i @ theta) * x_i + r_lambda * c_mu * theta``.

        ``arms`` is the 2-D array of played feature vectors, one row per
        entry of ``self.weights``.
        """
        coeffs = sigmoid(arms @ theta)[:, None]
        gt_val = np.sum(self.weights[:, None] * arms * coeffs, axis=0) + self.r_lambda * self.c_mu * theta
        return gt_val

    def hessian_at(self, theta, arms):
        """Return ``sum_i w_i * dsigmoid(x_i @ theta) * x_i x_i^T + r_lambda * c_mu * I``."""
        coeffs = dsigmoid(arms @ theta)[:, None]
        ht = np.dot(arms.T, self.weights[:, None] * coeffs * arms) + self.r_lambda * self.c_mu * np.eye(self.d)
        return ht

    def proj_fun(self, theta_target, arms):
        """Projection objective: squared ``Hessian^{-1}``-norm of ``gt(theta_target) - gt(theta_hat)``."""
        inv_H = pinv(self.Hessian)

        g_theta = self.gt(theta_target, arms)
        g_hat = self.gt(self.theta_hat, arms)

        diff_gt = g_theta - g_hat
        return np.dot(diff_gt, np.dot(inv_H, diff_gt))

    def projection(self, theta_init):
        """Project onto ``{theta : ||theta|| <= S}`` by minimising the ``proj_fun`` objective.

        Uses SLSQP started at ``theta_init`` and returns the optimiser's final
        point (its success flag is not inspected).
        """
        arms = np.array(self.arms_history)

        # These two quantities do not depend on the optimisation variable, so
        # they are computed once instead of on every objective evaluation.
        inv_H = pinv(self.Hessian)
        g_hat = self.gt(self.theta_hat, arms)

        def objective(theta):
            diff_gt = self.gt(theta, arms) - g_hat
            return np.dot(diff_gt, np.dot(inv_H, diff_gt))

        norm_cons = lambda t: np.linalg.norm(t)
        constraint = NonlinearConstraint(norm_cons, 0, self.S)

        opt = minimize(objective, x0=theta_init, method='SLSQP', constraints=constraint)
        return opt.x

    def re_init(self):
        """Reset all learned state so the instance can be reused for a new run."""
        super().re_init()
        self.t = 0
        self.ctr = 0
        self.theta_hat.fill(0.0)
        self.theta_tilde.fill(0.0)
        self.V = self.r_lambda * np.identity(self.d)
        self.inv_V = (1.0 / self.r_lambda) * np.identity(self.d)
        self.Hessian = self.r_lambda * self.c_mu * np.eye(self.d)

        self.arms_history = []
        self.rewards_history = []
        self.weights = None
        self.last_arms = None
        # When gamma was not supplied by the caller, tune it again on the next
        # `select_arm`.  `self.gamma` keeps its last value until then.
        if self._auto_gamma or self.gamma is None:
            self.params_set = False

    def __str__(self):
        g_str = f"{self.gamma:.4f}" if self.gamma is not None else "Auto"
        return f'SCB-dGLUCB(g={g_str})'