import numpy as np
from math import log
from numpy.linalg import norm, solve
from src.utils import sigmoid
from scipy.optimize import minimize, NonlinearConstraint
from src.core.bandit import BanditAlgorithm

class GLB_RestartUCB(BanditAlgorithm):
    """Restarting UCB for a logistic generalized linear bandit.

    Time is split into epochs of ``H`` rounds (rounds ``tau+1 .. tau+H``).
    Within an epoch the agent fits an L2-regularized logistic maximum-likelihood
    estimate on that epoch's data only, projects it onto the ball of radius
    ``S``, and plays the arm maximizing
    ``sigmoid(x @ theta_tilde) + (2 * k_mu / c_mu) * beta_t * ||x||_{V^-1}``.
    When an epoch ends, every statistic (estimates, ``V^-1``, history) is
    discarded.

    If ``H`` is not supplied it is tuned on the first ``select_arm`` call from
    the ``pt`` value (``P_T``; 0 when omitted) as
    ``H = max(1, int(d**0.25 * T**0.5 * (1 + P_T)**-0.5))``, and re-tuned after
    every ``re_init``.
    """

    # Newton's method settings used by estimator().
    _NEWTON_MAX_ITERS = 5
    _NEWTON_TOL = 1e-4

    def __init__(self, num_actions, horizon, d, delta, r_lambda, S, L, R, k_mu, c_mu, H=None):
        """Create the agent.

        Args:
            num_actions: forwarded to ``BanditAlgorithm``.
            horizon: forwarded to ``BanditAlgorithm``.
            d: feature dimension.
            delta: confidence parameter; enters the width via ``2*log(1/delta)``.
            r_lambda: regularization strength (``V`` starts at ``r_lambda * I``).
            S: norm bound; the MLE is projected onto the ``S``-ball.
            L: constant in the log term of ``beta_t`` (presumably an upper
                bound on arm norms -- confirm with the caller).
            R: scale of the stochastic part of ``beta_t``.
            k_mu, c_mu: link constants; the bonus is scaled by ``2*k_mu/c_mu``
                and ``c_mu`` also scales the regularizer.
            H: epoch length; ``None`` means auto-tune (see class docstring).
        """
        super().__init__(num_actions, horizon)
        self.d = d
        self.delta = delta
        self.r_lambda = r_lambda
        self.S = S
        self.L = L
        self.R = R
        self.k_mu = k_mu
        self.c_mu = c_mu

        # _auto_tune overwrites self.H, so remember separately whether H has
        # to be (re-)tuned; re_init relies on this flag.
        self._auto_H = H is None
        self.H = None if H is None else int(H)

        # Parts of the confidence width that do not depend on t.
        self.c_1 = np.sqrt(self.r_lambda) * c_mu * self.S
        self.c_2 = 2 * log(1 / self.delta)

        self.internal_t = 1  # global round counter (1-based)
        self.j = 1           # current epoch index (1-based)
        self.tau = 0         # rounds completed before the current epoch

        # Lazy refitting: once an epoch holds more than `lazy_warmup` samples,
        # the MLE is refit only on every `lazy_update_fr`-th update.
        self.lazy_update_fr = 10
        self.lazy_warmup = 50
        self.ctr = 0

        self.theta_hat = np.zeros(self.d)    # unconstrained regularized MLE
        self.theta_tilde = np.zeros(self.d)  # theta_hat projected onto S-ball

        self.inv_V = (1.0 / self.r_lambda) * np.eye(self.d)

        self.arms_history = []
        self.rewards_history = []
        self.last_arms = None

        self.t_epoch = None  # last round index of the current epoch
        self.params_set = not self._auto_H

    def _epoch_end(self):
        """Return the last round of the current epoch, i.e. ``tau + H``."""
        return self.tau + self.H

    def _auto_tune(self, P_T):
        """Set ``H`` from the horizon and ``P_T`` and schedule the epoch end."""
        # NOTE(review): assumes BanditAlgorithm.__init__ stores the horizon as
        # self.T -- confirm against the base class.
        h_val = (self.d ** 0.25) * (self.T ** 0.5) * ((1 + P_T) ** (-0.5))
        self.H = int(max(1, float(h_val)))
        self.t_epoch = self._epoch_end()
        self.params_set = True

    def select_arm(self, arms, pt=None, **kwargs):
        """Return the index of the arm with the largest UCB.

        Args:
            arms: sequence of feature vectors, one per arm (stacked into a
                2-D array; presumably shape ``(K, d)``). Kept for
                ``update_statistics``.
            pt: ``P_T`` used to tune ``H`` when it has not been set yet.

        Ties (within 1e-9) are broken uniformly at random via ``np.random``.
        """
        self.last_arms = arms
        if not self.params_set:
            self._auto_tune(pt if pt is not None else 0.0)
        if self.t_epoch is None:
            self.t_epoch = self._epoch_end()

        # Rounds elapsed in the current epoch (1-based) drive the width.
        eff_t = self.internal_t - self.tau
        log_arg = 1 + (eff_t * self.L ** 2) / (self.r_lambda * self.d)
        beta_t = self.c_1 + self.R * np.sqrt(self.c_2 + self.d * log(log_arg))
        coeff = 2 * self.k_mu / self.c_mu

        arms_arr = np.asarray(arms)
        # Row-wise x^T V^-1 x; clamp tiny negatives caused by round-off.
        quad_form = np.sum((arms_arr @ self.inv_V) * arms_arr, axis=1)
        quad_form = np.maximum(quad_form, 0.0)

        ucb_s = sigmoid(arms_arr @ self.theta_tilde) + coeff * beta_t * np.sqrt(quad_form)

        candidates = np.flatnonzero(ucb_s >= np.max(ucb_s) - 1e-9)
        return int(np.random.choice(candidates))

    def update_statistics(self, arm, reward, **kwargs):
        """Record the reward of arm index ``arm`` and advance one round.

        Updates ``V^-1``, refits the estimate (lazily), and restarts once the
        current epoch's ``H`` rounds are used up.
        """
        x = np.asarray(self.last_arms[arm], dtype=float)
        self.arms_history.append(x)
        self.rewards_history.append(reward)

        # Sherman-Morrison rank-one update of V^-1 for V += x x^T.
        v_x = self.inv_V @ x
        self.inv_V -= np.outer(v_x, v_x) / (1.0 + x @ v_x)

        self.estimator()
        self.internal_t += 1

        # Epoch covers rounds tau+1 .. tau+H; after round tau+H, restart.
        if self.t_epoch is not None and self.internal_t > self.t_epoch:
            self._restart()

    def _reset_epoch_state(self):
        """Discard every per-epoch statistic (estimates, V^-1, history)."""
        self.ctr = 0
        self.theta_hat.fill(0.0)
        self.theta_tilde.fill(0.0)
        self.inv_V = (1.0 / self.r_lambda) * np.eye(self.d)
        self.arms_history = []
        self.rewards_history = []

    def _restart(self):
        """Start the next epoch; a no-op while ``H`` is still unknown."""
        if self.H is None:
            return
        self.j += 1
        self.tau = (self.j - 1) * self.H
        # Equals the already-incremented counter, since epochs are H rounds.
        self.internal_t = self.tau + 1
        self.t_epoch = self._epoch_end()
        self._reset_epoch_state()

    def estimator(self):
        """Refit the regularized logistic MLE and its projection onto the S-ball.

        Minimizes ``sum_i logloss(x_i, y_i) + (r_lambda * c_mu / 2) * ||theta||^2``
        with a few Newton steps warm-started from the previous estimate.
        Skipped while the epoch has fewer than ``d`` samples, and skipped on
        all but every ``lazy_update_fr``-th call once it has more than
        ``lazy_warmup`` samples.
        """
        self.ctr += 1
        n_obs = len(self.rewards_history)

        if n_obs < self.d or (self.ctr % self.lazy_update_fr != 0 and n_obs > self.lazy_warmup):
            return

        X = np.array(self.arms_history)
        Y = np.array(self.rewards_history)
        reg = self.r_lambda * self.c_mu
        reg_eye = reg * np.eye(self.d)

        theta = self.theta_hat.copy()
        for _ in range(self._NEWTON_MAX_ITERS):
            preds = sigmoid(X @ theta)
            grad = reg * theta + X.T @ (preds - Y)
            # Logistic Hessian: X^T diag(p(1-p)) X plus the ridge term.
            hessian = (X.T * (preds * (1 - preds))) @ X + reg_eye
            try:
                step = solve(hessian, grad)
            except np.linalg.LinAlgError:
                break  # keep the last successful iterate
            theta -= step
            if norm(step) < self._NEWTON_TOL:
                break

        self.theta_hat = theta

        n_theta = norm(theta)
        if n_theta > self.S:
            self.theta_tilde = theta * (self.S / n_theta)
        else:
            # Separate array so in-place resets never alias theta_hat.
            self.theta_tilde = theta.copy()

    def re_init(self):
        """Reset to the initial state for a fresh run.

        An auto-tuned ``H`` is re-tuned on the next ``select_arm`` call; a
        user-supplied ``H`` is kept.
        """
        super().re_init()
        self.internal_t = 1
        self.j = 1
        self.tau = 0
        self.t_epoch = None
        self.last_arms = None
        self._reset_epoch_state()
        # self.H is non-None after auto-tuning, so test the remembered flag.
        self.params_set = not self._auto_H

    def __str__(self):
        return f'GLB-RestartUCB(H={self.H})'