import numpy as np
from numpy.linalg import inv, cholesky, solve
from scipy.linalg import cho_solve

def sigmoid(x):
    # 오버플로우/언더플로우 방지를 위해 clip 사용
    x = np.clip(x, -500, 500)
    return 1 / (1 + np.exp(-x))

class LogBanditAlg:
    def __init__(self, env, n, params):
        self.env = env
        self.rng = env.rng
        self.K = env.K
        self.d = env.d
        self.n = n
        # Hyperparameters
        self.S = params.get('S', env.norm_theta)
        self.delta = params.get('delta', 0.01)
        # Data history
        self.A = []  # list of past arms
        self.R = []  # list of past rewards
        # Current arm set and theta
        self.X = np.zeros((self.K, self.d))
        self.theta = np.zeros(self.d)

    def get_arms(self, X):
        """현재 라운드의 팔 정보를 업데이트합니다."""
        self.X = X

    def update(self, t, arm_idx, reward):
        """선택한 팔과 보상을 기록합니다. (자식 클래스에서 오버라이드 가능)"""
        x = self.X[arm_idx]
        self.A.append(x)
        self.R.append(reward)

    def get_arm(self, t):
        """팔을 선택하는 로직. 자식 클래스에서 반드시 구현해야 합니다."""
        raise NotImplementedError
    
import math
import numpy as np
import numpy.linalg as LA
from scipy.optimize import minimize

# LogBanditAlg는 위 섹션에서 정의했다고 가정합니다.
# from your_file import LogBanditAlg, sigmoid

def dsigmoid(x):
    """Sigmoid 함수의 도함수"""
    s = sigmoid(x)
    return s * (1 - s)

class RSGLinCB(LogBanditAlg):
    @staticmethod
    def print():
        return "RS-GLinCB"

    def __init__(self, env, n, params):
        # 1. 부모 클래스 초기화 (기존 코드 구조와 동일)
        super().__init__(env, n, params)
        
        # 2. RS-GLinCB 고유 파라미터 로드
        self.lazy_update_fr = params.get('lazy_update_fr', 1)
        self.tol = params.get('tol', 1e-7)

        # 3. 알고리즘 고유 파라미터 및 상태 변수 초기화
        self.param_norm_ub = self.S
        self.l2reg = self.d * np.log(self.n / self.delta)
        self.gamma = 25 * self.param_norm_ub * np.sqrt(self.d * np.log(self.n / self.delta))
        self.kappa = 3 + np.exp(self.param_norm_ub)
        self.warmup_threshold = 1 / (self.gamma**2 * self.kappa)
        
        # 상태 변수
        self.ctr = 0
        self.switch1 = False # Switching Criterion I 발동 여부

        # 데이터 저장 공간 (트리거/논-트리거 분리)
        self.triggered_arms = np.zeros((0, self.d))
        self.triggered_rewards = np.array([])
        self.nontriggered_arms = np.zeros((0, self.d))
        self.nontriggered_rewards = np.array([])

        # 행렬 및 추정치
        self.V = self.l2reg * np.eye(self.d)
        self.V_inv = (1 / self.l2reg) * np.eye(self.d)
        self.H = self.l2reg * np.eye(self.d)
        self.H_inv = (1 / self.l2reg) * np.eye(self.d)
        self.H_tau = self.l2reg * np.eye(self.d)
        self.H_tau_inv = (1 / self.l2reg) * np.eye(self.d)

        # theta 추정치
        self.theta_hat_o = self.rng.normal(0, 1, self.d)
        self.theta_hat_tau = self.rng.normal(0, 1, self.d)

    def _neg_log_likelihood(self, theta, arms, rewards):
        if arms.shape[0] == 0: return 0
        z = arms @ theta
        # 로그의 인자가 0이 되는 것을 방지하기 위해 z의 범위를 제한 (sigmoid와 유사)
        z = np.clip(z, -100, 100)
        return -np.sum(rewards * z - np.log(1 + np.exp(z)))

    def _neg_log_likelihood_J(self, theta, arms, rewards):
        if arms.shape[0] == 0: return np.zeros_like(theta)
        mu = sigmoid(arms @ theta)
        return -np.sum(arms * (rewards - mu)[:, np.newaxis], axis=0)

    def get_arm(self, t):
        self.ctr += 1
        
        # --- Switching Criterion I (Warmup Phase) ---
        warmup_scores = np.einsum('ij,ji->i', self.X @ self.V_inv, self.X.T)
        
        if np.max(warmup_scores) >= self.warmup_threshold:
            self.switch1 = True
            return np.argmax(warmup_scores)
        
        # --- Standard Phase ---
        # Switching Criterion II (Snapshot Update)
        if LA.det(self.H) > 2 * LA.det(self.H_tau):
            self.H_tau, self.H_tau_inv = np.copy(self.H), np.copy(self.H_inv)
            
            if self.nontriggered_arms.shape[0] > 0:
                cons = {'type': 'ineq',
                        'fun': lambda th: self.param_norm_ub**2 - th @ th,
                        'jac': lambda th: -2 * th}
                opt = minimize(self._neg_log_likelihood, x0=self.theta_hat_tau,
                               args=(self.nontriggered_arms, self.nontriggered_rewards),
                               method='SLSQP', jac=self._neg_log_likelihood_J,
                               constraints=cons, tol=self.tol)
                self.theta_hat_tau = opt.x

        # 1. Arm Elimination
        ucb_o = self.X @ self.theta_hat_o + self.gamma * np.sqrt(self.kappa * warmup_scores)
        lcb_o = self.X @ self.theta_hat_o - self.gamma * np.sqrt(self.kappa * warmup_scores)
        
        active_mask = ucb_o >= np.min(lcb_o)
        active_arms = self.X[active_mask]
        original_indices = np.where(active_mask)[0]

        if active_arms.shape[0] == 0: # 모든 팔이 제거된 경우
            active_arms, original_indices = self.X, np.arange(self.K)

        # 2. UCB-based Arm Selection from the active set
        ucb_bonus_H = np.einsum('ij,ji->i', active_arms @ self.H_tau_inv, active_arms.T)
        ucb_scores = (active_arms @ self.theta_hat_tau + 
                      150 * np.sqrt(ucb_bonus_H) * np.sqrt(self.d * np.log(self.n / self.delta)))
        
        chosen_arm_idx = original_indices[np.argmax(ucb_scores)]
        
        # 선택된 팔로 H 행렬 업데이트
        chosen_arm_vector = self.X[chosen_arm_idx]
        dmu_val = dsigmoid(chosen_arm_vector @ self.theta_hat_o) / math.e
        self.H += dmu_val * np.outer(chosen_arm_vector, chosen_arm_vector)
        
        # Sherman-Morrison 공식으로 H_inv 업데이트
        tmp = np.sqrt(dmu_val) * chosen_arm_vector
        self.H_inv -= np.outer(self.H_inv @ tmp, self.H_inv @ tmp) / (1 + tmp.T @ self.H_inv @ tmp)
        
        return chosen_arm_idx

    def update(self, t, arm_idx, reward):
        if self.ctr % self.lazy_update_fr != 0:
            return

        arm_vector = self.X[arm_idx]
        
        if self.switch1:
            # Triggered data update
            self.triggered_arms = np.vstack([self.triggered_arms, arm_vector])
            self.triggered_rewards = np.append(self.triggered_rewards, reward)

            # V, V_inv 업데이트
            self.V += np.outer(arm_vector, arm_vector)
            self.V_inv -= np.outer(self.V_inv @ arm_vector, self.V_inv @ arm_vector) / (1 + arm_vector.T @ self.V_inv @ arm_vector)
            
            # theta_hat_o 업데이트
            obj = lambda th: self._neg_log_likelihood(th, self.triggered_arms, self.triggered_rewards) + (self.l2reg / 2) * (th @ th)
            obj_J = lambda th: self._neg_log_likelihood_J(th, self.triggered_arms, self.triggered_rewards) + self.l2reg * th
            opt = minimize(obj, x0=self.theta_hat_o, method='L-BFGS-B', jac=obj_J, tol=self.tol)
            self.theta_hat_o = opt.x
            
            self.switch1 = False # 스위치 리셋
        else:
            # Non-triggered data update
            self.nontriggered_arms = np.vstack([self.nontriggered_arms, arm_vector])
            self.nontriggered_rewards = np.append(self.nontriggered_rewards, reward)

        # 이 알고리즘은 부모의 A, R 리스트를 사용하지 않으므로 super().update()를 호출하지 않습니다.


import numpy as np
from logbexp.utils.utils import sigmoid, dsigmoid, weighted_norm
from logbexp.utils.optimization import fit_online_logistic_estimate, fit_online_logistic_estimate_bar
# LogBanditAlg는 이전과 동일한 파일에 정의되어 있다고 가정합니다.
# from your_file import LogBanditAlg

class EcoLog(LogBanditAlg):
    @staticmethod
    def print():
        return "EcoLog"

    def __init__(self, env, n, params):
        # 1. 부모 클래스 초기화
        super().__init__(env, n, params)

        # 2. EcoLog 고유 파라미터 설정
        # 논문에 따라 l2reg를 차원 d로 설정 (주석 참고)
        self.l2reg = params.get('l2reg', self.d)
        
        # 3. 상태 변수 초기화
        self.vtilde_matrix = self.l2reg * np.eye(self.d)
        self.vtilde_matrix_inv = (1 / self.l2reg) * np.eye(self.d)
        self.theta = np.zeros(self.d)
        self.conf_radius = 0
        self.cum_loss = 0
        self.ctr = 0

    def get_arm(self, t):
        # 1. 신뢰 반경(UCB 보너스) 업데이트
        self._update_ucb_bonus()

        # 2. 모든 팔에 대해 UCB 점수 계산
        # compute_optimistic_reward 로직을 직접 구현
        norms = np.sqrt(np.einsum('ij,ji->i', self.X @ self.vtilde_matrix_inv, self.X.T))
        pred_rewards = sigmoid(self.X @ self.theta)
        bonuses = self.conf_radius * norms
        
        ucb_scores = pred_rewards + bonuses
        
        # 3. 가장 높은 점수를 가진 팔 선택
        chosen_arm_idx = np.argmax(ucb_scores)
        return chosen_arm_idx

    def update(self, t, arm_idx, reward):
        self.ctr += 1
        arm_vector = self.X[arm_idx]

        # 1. theta (온라인 추정치) 업데이트
        # fit_online_logistic_estimate 함수를 직접 호출
        self.theta = np.real_if_close(fit_online_logistic_estimate(
            arm=arm_vector,
            reward=reward,
            current_estimate=self.theta,
            vtilde_matrix=self.vtilde_matrix,
            vtilde_inv_matrix=self.vtilde_matrix_inv,
            constraint_set_radius=self.S,
            diameter=self.S,
            precision=1/self.ctr
        ))

        # 2. theta_bar (비교용 추정치) 계산
        theta_bar = np.real_if_close(fit_online_logistic_estimate_bar(
            arm=arm_vector,
            current_estimate=self.theta,
            vtilde_matrix=self.vtilde_matrix,
            vtilde_inv_matrix=self.vtilde_matrix_inv,
            constraint_set_radius=self.S,
            diameter=self.S,
            precision=1/self.ctr
        ))

        # 3. vtilde 행렬 업데이트 (Sherman-Morrison 공식 사용)
        sensitivity = dsigmoid(np.dot(self.theta, arm_vector))
        self.vtilde_matrix += sensitivity * np.outer(arm_vector, arm_vector)
        # Sherman-Morrison update for the inverse
        v_inv_x = self.vtilde_matrix_inv @ arm_vector
        self.vtilde_matrix_inv -= sensitivity * np.outer(v_inv_x, v_inv_x) / (1 + sensitivity * (arm_vector @ v_inv_x))

        # 4. 누적 손실(cumulative loss) 업데이트
        coeff_theta = sigmoid(np.dot(self.theta, arm_vector))
        loss_theta = -reward * np.log(coeff_theta) - (1 - reward) * np.log(1 - coeff_theta)
        
        coeff_bar = sigmoid(np.dot(theta_bar, arm_vector))
        loss_theta_bar = -reward * np.log(coeff_bar) - (1 - reward) * np.log(1 - coeff_bar)
        
        self.cum_loss += loss_theta_bar - loss_theta

    def _update_ucb_bonus(self):
        """
        UCB 보너스를 계산합니다. 원본 코드의 로직을 그대로 사용합니다.
        """
        D = self.S
        # ctr이 0인 경우를 방지
        safe_ctr = self.ctr if self.ctr > 0 else 1
        
        nu_paper = 0.5 + 2 * np.log(2 * np.sqrt(1 + safe_ctr / (4 * self.l2reg)) / self.delta)
        ub1 = (2 + D) * nu_paper / 4 + D**2 / (2 + D)
        
        res_square = 4 * self.S**2 + 2 * (2 + D) * (ub1 + self.cum_loss)
        res_square += 4 * np.log(1 + safe_ctr)

        self.conf_radius = np.sqrt(max(0, res_square)) # 음수가 되지 않도록 max(0, ..) 추가