import numpy as np
import torch

class LinUCB:
    """Linear upper-confidence-bound (LinUCB) model for contextual bandits.

    The model keeps ridge-regression statistics ``A = lambda_ * I + sum(x x^T)``
    and ``b = sum(reward * x)`` and scores a context ``x`` with
    ``theta . x + alpha * sqrt(x^T A^{-1} x)`` where ``theta = A^{-1} b``.

    Two storage modes are available:

    - ``use_diag=True``: only the diagonal of ``A`` is tracked
      (``A_diag`` / ``A_inv_diag``), so every operation is O(d).
    - ``use_diag=False``: the full matrix ``A`` and its inverse ``A_inv`` are
      tracked; ``A_inv`` is maintained incrementally with the
      Sherman-Morrison formula, so every operation is O(d^2).
    """

    # Added to the diagonal of A before dividing so that the diagonal mode
    # stays finite even when an entry of A_diag is zero (e.g. lambda_ == 0).
    _EPS = 1e-12

    def __init__(self, input_dim, alpha=1.0, lambda_=1.0, use_diag=True):
        """
        LinUCB algorithm for contextual bandits.

        Parameters:
        - input_dim: Dimension of input features
        - alpha: Exploration parameter controlling the confidence interval size
        - lambda_: Regularization parameter (initial value of the diagonal of A)
        - use_diag: Whether to use diagonal approximation for computational efficiency
        """
        self.input_dim = input_dim
        self.alpha = alpha
        self.lambda_ = lambda_
        self.use_diag = use_diag

        if self.use_diag:
            # Diagonal version: vectors hold the diagonals of A and A^{-1}.
            self.A_diag = lambda_ * np.ones(input_dim)
            self.A_inv_diag = 1.0 / self.A_diag
        else:
            # Full matrix version: A = lambda_ * I + sum(x x^T) and its inverse.
            self.A = lambda_ * np.eye(input_dim)
            self.A_inv = np.eye(input_dim) / lambda_

        self.b = np.zeros(input_dim)      # b = sum(reward * x)
        self.theta = np.zeros(input_dim)  # theta = A^{-1} b

    def _as_context(self, x):
        """Return ``x`` as a flat float vector of length ``input_dim``.

        Raises:
            ValueError: if the flattened context does not have exactly
                ``input_dim`` entries. Without this check a length-1 context
                would silently broadcast over every feature in diagonal mode.
        """
        vec = np.asarray(x, dtype=float).reshape(-1)
        if vec.shape[0] != self.input_dim:
            raise ValueError(
                f"expected a context with {self.input_dim} features, "
                f"got {vec.shape[0]}"
            )
        return vec

    def calc_ucb(self, x):
        """Score a context.

        Parameters:
        - x: Array-like context; it is flattened to a vector of ``input_dim`` entries

        Returns:
        A tuple ``(ucb, pred, bonus)`` where ``pred = theta . x``,
        ``bonus = alpha * sqrt(x^T A^{-1} x)`` and ``ucb = pred + bonus``.

        Raises:
        ValueError if the context does not have ``input_dim`` entries.
        """
        x = self._as_context(x)

        # Predicted reward under the current linear estimate.
        pred = np.dot(self.theta, x)

        if self.use_diag:
            # Diagonal version: x^T A^{-1} x = sum(x_i^2 * (A^{-1})_ii)
            quad = np.sum((x ** 2) * self.A_inv_diag)
            bonus = self.alpha * np.sqrt(quad)
        else:
            # Full matrix version. Round-off in the incremental inverse can
            # push the quadratic form slightly below zero, so clamp it before
            # taking the square root.
            quad = float(x @ self.A_inv @ x)
            bonus = self.alpha * np.sqrt(max(quad, 0))

        ucb = pred + bonus
        return ucb, pred, bonus

    def update(self, x, reward):
        """Incorporate one observed ``(context, reward)`` pair.

        Parameters:
        - x: Array-like context; it is flattened to a vector of ``input_dim`` entries
        - reward: Scalar reward observed for that context

        Raises:
        ValueError if the context does not have ``input_dim`` entries; the
        model state is left untouched in that case.
        """
        # Validate before touching any state so a bad context cannot leave the
        # model half-updated.
        x = self._as_context(x)

        if self.use_diag:
            # Diagonal version: only the diagonal of x x^T is accumulated.
            self.A_diag += x ** 2
            self.A_inv_diag = 1.0 / (self.A_diag + self._EPS)
        else:
            # Keep A itself in step with A_inv so the attribute always holds
            # lambda_ * I + sum(x x^T), as it is initialised to represent.
            self.A += np.outer(x, x)

            # Sherman-Morrison rank-one update:
            # (A + x x^T)^{-1} = A^{-1} - (A^{-1} x)(A^{-1} x)^T / (1 + x^T A^{-1} x)
            A_inv_x = self.A_inv @ x
            denom = 1.0 + np.dot(x, A_inv_x)

            # For a positive-definite A^{-1} the denominator is >= 1; the guard
            # only protects against division by (near) zero in degenerate
            # configurations, in which case A_inv is left unchanged.
            if denom > self._EPS:
                self.A_inv = self.A_inv - np.outer(A_inv_x, A_inv_x) / denom

        # Update b
        self.b += reward * x

        # Update theta
        if self.use_diag:
            # Diagonal version: theta_i = b_i / A_ii
            self.theta = self.b / (self.A_diag + self._EPS)
        else:
            # Full matrix version: theta = A^{-1} b
            self.theta = self.A_inv @ self.b

    def train(self, contexts, rewards, **kwargs):
        """No-op: all learning happens incrementally in ``update``.

        The arguments are accepted and ignored; the method always returns None.
        """
        return None
