import numpy as np
from math import log
from numpy.linalg import pinv, norm, solve
from src.utils import sigmoid, dsigmoid
from scipy.optimize import minimize, NonlinearConstraint
from src.core.bandit import BanditAlgorithm

class SCB_RestartUCB(BanditAlgorithm):
    """Restarting UCB policy for sigmoid-link (logistic-style) contextual bandits.

    Time is split into epochs of ``H`` rounds.  At the start of every epoch all
    learned statistics are discarded, so the policy can follow a parameter that
    drifts over time.  Within an epoch the policy

    * fits a regularised estimate ``theta_hat`` with Newton iterations on the
      rewards observed since the last restart,
    * maps it into the ball ``||theta|| <= S`` (``theta_tilde``), and
    * plays the arm maximising ``sigmoid(x @ theta_tilde)`` plus an exploration
      bonus proportional to ``sqrt(x^T V^{-1} x)``.

    If ``H`` is not supplied it is auto-tuned on the first ``select_arm`` call
    from the dimension, the horizon and the path-length hint ``pt``.

    NOTE(review): ``L`` and ``R`` are stored but never read by this class.
    ``sigmoid`` is assumed to be the logistic function: the Newton Hessian
    below uses ``s * (1 - s)`` as its derivative -- confirm against src.utils.
    """

    # While the current epoch holds fewer rewards than this, the estimator is
    # refit on every round regardless of ``lazy_update_fr``.
    _EAGER_FIT_WINDOW = 200

    def __init__(self, num_actions, horizon, d, delta, r_lambda, S, L, R, k_mu, c_mu, H=None,
                 lazy_update_fr=5):
        """Create the policy.

        Args:
            num_actions: forwarded to ``BanditAlgorithm``.
            horizon: total number of rounds ``T``; also forwarded to the base class.
            d: feature dimension.
            delta: confidence parameter; enters the width through ``log(1 / delta)``.
            r_lambda: ridge regularisation strength.
            S: radius of the parameter ball used by the projection step.
            L, R: stored only; not used by this class.
            k_mu, c_mu: constants entering the regulariser and the confidence width.
            H: restart period in rounds; ``None`` means auto-tune on the first
                ``select_arm`` call.
            lazy_update_fr: once an epoch holds at least ``_EAGER_FIT_WINDOW``
                rewards, the estimate is refit only on every ``lazy_update_fr``-th
                update (values below 1 are treated as 1).
        """
        super().__init__(num_actions, horizon)
        self.d = d
        self.delta = delta
        self.r_lambda = r_lambda
        self.S = S
        self.L = L
        self.R = R
        self.k_mu = k_mu
        self.c_mu = c_mu
        self.H = int(H) if H is not None else None
        self.T = horizon
        self.lazy_update_fr = max(1, int(lazy_update_fr))

        # Constant parts of the confidence radius beta_t.
        self.c_1 = np.sqrt(self.r_lambda * self.c_mu) * (self.S + 0.5)
        self.c_2 = (2 / np.sqrt(self.r_lambda * self.c_mu)) * log(1 / self.delta)

        self.internal_t = 1   # 1-indexed round counter
        self.j = 1            # 1-indexed epoch counter
        self.tau = 0          # number of rounds before the current epoch
        self.t_epoch = None   # last round (inclusive) of the current epoch

        self._reset_statistics()
        self.last_arms = None
        self.params_set = self.H is not None

    def _reset_statistics(self):
        """Reset estimates, design matrices and history to their regularised priors."""
        self.ctr = 0
        self.theta_hat = np.zeros(self.d)
        self.theta_tilde = np.zeros(self.d)
        self.V = self.r_lambda * np.eye(self.d)
        self.inv_V = (1.0 / self.r_lambda) * np.eye(self.d)
        self.Hessian = self.r_lambda * self.c_mu * np.eye(self.d)
        self.arms_history = []
        self.rewards_history = []

    def _epoch_end(self):
        """Return the last round (1-indexed, inclusive) of the current epoch.

        Epoch ``j`` covers rounds ``tau + 1 .. tau + H`` with ``tau = (j - 1) * H``,
        so every epoch contains exactly ``H`` rounds.
        """
        return self.tau + self.H

    def _auto_tune(self, P_T):
        """Set ``H = max(1, floor(d**0.25 * sqrt(T) / sqrt(1 + P_T)))`` and open the epoch."""
        h_val = (self.d ** 0.25) * (self.T ** 0.5) * ((1 + P_T) ** (-0.5))
        self.H = int(max(1, float(h_val)))
        self.t_epoch = self._epoch_end()
        self.params_set = True

    def select_arm(self, arms, pt=None, **kwargs):
        """Return the index of the arm with the largest optimistic estimate.

        Args:
            arms: sequence of feature vectors, one per arm.  It is remembered so
                that the following ``update_statistics`` call can look up the
                chosen arm's features.
            pt: path-length hint, used only to auto-tune ``H`` (``None`` -> 0).

        Returns:
            int index into ``arms``.  Scores within 1e-9 of the maximum are
            treated as ties and broken uniformly at random with ``np.random``.
        """
        self.last_arms = arms
        if not self.params_set:
            self._auto_tune(pt if pt is not None else 0.0)
        if self.t_epoch is None:
            self.t_epoch = self._epoch_end()

        reg = self.r_lambda * self.c_mu
        val = 1 + (self.k_mu * self.internal_t) / (reg * self.d)
        beta_t = self.c_1 + self.c_2 + (self.d / np.sqrt(reg)) * np.log(4 * val)
        bonus_multiplier = np.sqrt(4 + 8 * self.S) * self.k_mu / np.sqrt(self.c_mu)

        arms_arr = np.array(arms)
        # x^T V^{-1} x for every arm, clipped at 0 to absorb round-off.
        quad = np.maximum(np.sum((arms_arr @ self.inv_V) * arms_arr, axis=1), 0.0)
        ucb_s = sigmoid(arms_arr @ self.theta_tilde) + bonus_multiplier * beta_t * np.sqrt(quad)

        candidates = np.where(ucb_s >= np.max(ucb_s) - 1e-9)[0]
        return int(np.random.choice(candidates))

    def update_statistics(self, arm, reward, **kwargs):
        """Record ``reward`` for arm index ``arm`` from the last ``select_arm`` call.

        Updates ``V`` and, via Sherman-Morrison, ``inv_V``; refits the estimator;
        advances the round counter; and restarts once the current epoch is over.
        """
        # Copy so that later in-place changes to the caller's arms cannot alter history.
        x = np.array(self.last_arms[arm], dtype=float)
        self.arms_history.append(x)
        self.rewards_history.append(reward)

        self.V += np.outer(x, x)
        # (V + x x^T)^{-1} = V^{-1} - (V^{-1} x)(V^{-1} x)^T / (1 + x^T V^{-1} x)
        v_x = self.inv_V @ x
        self.inv_V -= np.outer(v_x, v_x) / (1.0 + x @ v_x)

        self.estimator()
        self.internal_t += 1

        if self.t_epoch is not None and self.internal_t > self.t_epoch:
            self._restart()

    def _restart(self):
        """Open the next epoch and discard every statistic learned so far."""
        if self.H is None:
            return
        self.j += 1
        self.tau = (self.j - 1) * self.H
        self.internal_t = self.tau + 1
        self.t_epoch = self._epoch_end()
        self._reset_statistics()

    def _newton_fit(self, max_iter=5, tol=1e-4):
        """Run up to ``max_iter`` Newton steps from the current ``theta_hat``.

        The step uses gradient ``r_lambda*c_mu*theta + X^T (sigmoid(X theta) - y)``
        and Hessian ``X^T diag(s (1 - s)) X + r_lambda*c_mu*I`` over the epoch's
        history.  Iteration stops early when a step is shorter than ``tol`` or the
        linear solve fails.  Stores the result in ``theta_hat`` and the last
        Hessian evaluated in ``Hessian``.
        """
        theta_hat = self.theta_hat.copy()
        hessian = self.Hessian

        X = np.array(self.arms_history)
        Y = np.array(self.rewards_history)
        reg = self.r_lambda * self.c_mu

        if len(X) > 0:
            lambda_eye = reg * np.eye(self.d)
            for _ in range(max_iter):
                coeffs = sigmoid(X @ theta_hat)
                grad = reg * theta_hat + X.T @ (coeffs - Y)
                hessian = (X.T * (coeffs * (1 - coeffs))) @ X + lambda_eye
                try:
                    step = solve(hessian, grad)
                except np.linalg.LinAlgError:
                    break
                theta_hat -= step
                if norm(step) < tol:
                    break

        self.theta_hat = theta_hat
        self.Hessian = hessian

    def estimator(self):
        """Refit ``theta_hat`` (lazily once the epoch is long) and refresh ``theta_tilde``.

        ``theta_tilde`` is a copy of ``theta_hat`` when it lies strictly inside the
        ``S``-ball; otherwise it is the result of ``projection``.
        """
        if (self.ctr % self.lazy_update_fr == 0
                or len(self.rewards_history) < self._EAGER_FIT_WINDOW):
            self._newton_fit()
        self.ctr += 1

        if norm(self.theta_hat) < self.S:
            self.theta_tilde = self.theta_hat.copy()
        else:
            self.theta_tilde = self.projection(self.theta_hat)

    def gt(self, theta, arms_arr):
        """Return ``arms_arr^T sigmoid(arms_arr @ theta) + r_lambda*c_mu*theta``.

        With an empty ``arms_arr`` only the regularisation term remains.
        """
        if len(arms_arr) == 0:
            return self.r_lambda * self.c_mu * theta
        coeffs = sigmoid(arms_arr @ theta)
        return np.dot(arms_arr.T, coeffs) + self.r_lambda * self.c_mu * theta

    def hessian_at(self, theta, arms_arr):
        """Return ``arms_arr^T diag(dsigmoid(arms_arr @ theta)) arms_arr + r_lambda*c_mu*I``.

        NOTE(review): assumes ``dsigmoid`` is the derivative of ``sigmoid``.
        """
        if len(arms_arr) == 0:
            return self.r_lambda * self.c_mu * np.eye(self.d)
        coeffs = dsigmoid(arms_arr @ theta)
        return (arms_arr.T * coeffs) @ arms_arr + self.r_lambda * self.c_mu * np.eye(self.d)

    def proj_fun(self, theta_target, arms_arr, inv_H=None, g_hat=None):
        """Return ``||gt(theta_target) - gt(theta_hat)||^2`` in the ``Hessian^{-1}`` metric.

        Args:
            theta_target: candidate parameter.
            arms_arr: array of past arm features.
            inv_H: optional precomputed ``pinv(self.Hessian)``.
            g_hat: optional precomputed ``gt(self.theta_hat, arms_arr)``.
                Both are constant during a projection solve, so passing them
                avoids recomputing them on every objective evaluation.
        """
        if inv_H is None:
            inv_H = pinv(self.Hessian)
        if g_hat is None:
            g_hat = self.gt(self.theta_hat, arms_arr)
        diff_gt = self.gt(theta_target, arms_arr) - g_hat
        return diff_gt @ (inv_H @ diff_gt)

    def _clip_to_ball(self, theta):
        """Return ``theta`` rescaled onto the ``S``-ball if its norm exceeds ``S``."""
        theta_norm = norm(theta)
        if theta_norm > self.S:
            return theta * (self.S / theta_norm)
        return theta

    def projection(self, theta_init):
        """Map ``theta_init`` into ``{theta : ||theta|| <= S}`` using the gradient geometry.

        Solves ``min proj_fun(theta)`` subject to ``||theta|| <= S`` with SLSQP,
        starting from ``theta_init``.  If the solver returns a non-finite point it
        falls back to ``theta_init``.  Any result still outside the ball (solver
        failure or tolerance slack) is rescaled radially onto it, so the returned
        vector is always feasible.
        """
        arms_arr = np.array(self.arms_history)
        inv_H = pinv(self.Hessian)
        g_hat = self.gt(self.theta_hat, arms_arr)

        constraint = NonlinearConstraint(norm, 0, self.S)
        opt = minimize(
            lambda t: self.proj_fun(t, arms_arr, inv_H=inv_H, g_hat=g_hat),
            x0=theta_init,
            method='SLSQP',
            constraints=constraint,
        )

        theta = opt.x
        if not np.all(np.isfinite(theta)):
            theta = np.array(theta_init, dtype=float)
        return self._clip_to_ball(theta)

    def re_init(self):
        """Reset to the freshly-constructed state for a new run.

        An ``H`` that was supplied or already auto-tuned is kept.  Auto-tuning is
        re-armed only while ``H`` is still ``None``.
        """
        super().re_init()
        self.internal_t = 1
        self.j = 1
        self.tau = 0
        self._reset_statistics()
        self.last_arms = None

        if self.H is not None:
            self.t_epoch = self._epoch_end()
        else:
            self.t_epoch = None
            self.params_set = False

    def __str__(self):
        return f'SCB-RestartUCB(H={self.H})'