import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Import NeuralUCBModel. 
# We try relative import first (for package usage), then absolute (for script usage).
try:
    from .neural_ucb import NeuralUCBModel
except ImportError:
    from neural_ucb import NeuralUCBModel

class MixedUCB:
    """Bandit scorer mixing a neural reward estimate with a LinUCB-style bonus.

    The reward estimate comes from a ``NeuralUCBModel``; the exploration bonus
    is ``alpha * sqrt(x^T A^{-1} x)`` computed on the raw input features, where
    ``A`` starts at ``lambda_ * I`` and accumulates ``x x^T`` on every
    ``update`` call. ``A`` is stored either as a full matrix or as its
    diagonal only.

    This class is implemented independently of NeuralUCB to avoid memory
    overhead from NeuralUCB's NTK matrix initialization.
    """

    def __init__(self, input_dim, hidden_dims=(32, 32), lambda_=1.0, beta=1.0,
                 use_diag_z=True, reset_theta_each_train=False):
        """
        Build the network and the initial LinUCB statistics.

        Args:
            input_dim: Input dimension (also the size of the LinUCB statistics).
            hidden_dims: Sequence of hidden layer sizes. Only the last entry
                (used as the hidden size) and the number of entries (used as
                the depth) are passed on to ``NeuralUCBModel``.
            lambda_: Regularization parameter; initial value of the LinUCB
                matrix diagonal and weight of the training regularizer.
            beta: Exploration parameter (stored as ``alpha``, LinUCB's alpha).
            use_diag_z: If True, keep only the diagonal of the LinUCB matrix.
            reset_theta_each_train: If True, restore the initial network
                parameters at the start of every ``train`` call.

        Raises:
            ValueError: If ``hidden_dims`` is empty.
        """
        # Copy so that neither a shared default nor the caller's object is kept.
        hidden_dims = list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one layer size")

        self.input_dim = input_dim
        self.model_hidden_size = hidden_dims[-1]

        # Initialize Neural Network Model
        self.model = NeuralUCBModel(input_dim, self.model_hidden_size, depth=len(hidden_dims))
        self.loss_fn = nn.MSELoss()

        self.lambda_ = lambda_
        self.alpha = beta  # Use beta as the exploration weight (LinUCB's alpha)
        self.use_diag = use_diag_z
        self.reset_theta_each_train = reset_theta_each_train

        # Initialize LinUCB matrices (A and A_inv)
        if self.use_diag:
            self.A_diag = torch.ones(input_dim, dtype=torch.float32) * self.lambda_
            self.A_inv_diag = 1.0 / self.A_diag
        else:
            self.A = torch.eye(input_dim, dtype=torch.float32) * self.lambda_
            self.A_inv = torch.eye(input_dim, dtype=torch.float32) / float(self.lambda_)

        # Snapshot of the initial parameters; `train` regularizes towards it.
        self.theta0 = [p.clone().detach() for p in self.model.parameters()]

    def calc_ucb(self, x_tensor):
        """
        Calculate UCB using Neural Network for prediction and LinUCB for exploration.

        Args:
            x_tensor: Input tensor fed to the model as-is. It is flattened for
                the bonus, so it must hold exactly one context, and the model
                output must be a single-element tensor.

        Returns:
            Tuple ``(ucb, prediction, bonus)`` of Python floats, where
            ``ucb = prediction + bonus``.
        """
        # 1. Neural Network Prediction
        # Kept from the original implementation: clears the model's gradients.
        self.model.zero_grad()
        with torch.no_grad():
            pred = self.model(x_tensor)

        # 2. LinUCB Exploration Bonus
        # detach(): the bonus is only reported as a float, no graph is needed.
        # reshape(): unlike view(), also accepts non-contiguous inputs.
        x = x_tensor.detach().reshape(-1)

        if self.use_diag:
            # Statistics follow the device of the incoming context.
            if self.A_inv_diag.device != x.device:
                self.A_inv_diag = self.A_inv_diag.to(x.device)
                self.A_diag = self.A_diag.to(x.device)

            # x^T diag(A)^-1 x
            val = torch.sum((x ** 2) * self.A_inv_diag)
        else:
            if self.A_inv.device != x.device:
                self.A_inv = self.A_inv.to(x.device)
                self.A = self.A.to(x.device)

            # x^T A_inv x, with x as (D,) and A_inv as (D, D)
            val = (x.unsqueeze(0) @ self.A_inv @ x.unsqueeze(1)).squeeze()

        # Clamp guards sqrt against tiny negative values from round-off.
        bonus = self.alpha * torch.sqrt(torch.clamp(val, min=0.0))

        # Read each scalar once; every .item() is a host synchronization.
        pred_value = float(pred.item())
        bonus_value = float(bonus.item())
        return pred_value + bonus_value, pred_value, bonus_value

    def update(self, context_embedding, reward):
        """
        Update LinUCB matrices with the new context.
        The Neural Network is updated separately via the `train` method.

        Args:
            context_embedding: Context as a tensor or anything accepted by
                ``torch.tensor``; it is flattened to a vector.
            reward: Not used here; accepted for interface compatibility.
        """
        # The statistics decide the target device.
        device = self.A_inv_diag.device if self.use_diag else self.A_inv.device

        if isinstance(context_embedding, torch.Tensor):
            # detach(): otherwise a context that requires grad would pull the
            # statistics into an autograd graph that grows with every update.
            x_tensor = context_embedding.detach().to(device=device, dtype=torch.float32)
        else:
            x_tensor = torch.tensor(context_embedding, dtype=torch.float32, device=device)
        x = x_tensor.reshape(-1)

        if self.use_diag:
            self.A_diag += x ** 2
            eps = 1e-12  # keeps the reciprocal finite if an entry is zero
            self.A_inv_diag = 1.0 / (self.A_diag + eps)
        else:
            # Sherman-Morrison update
            # A_inv = A_inv - (A_inv x x^T A_inv) / (1 + x^T A_inv x)
            v = x.unsqueeze(1)  # (D, 1)
            Av = self.A_inv @ v
            denom = max((1.0 + (v.t() @ Av)).item(), 1e-12)

            # Keep A in step with A_inv, as the diagonal branch does.
            self.A = self.A + v @ v.t()
            self.A_inv = self.A_inv - (Av @ Av.t()) / denom

    def train(self, contexts, rewards, local_training_iter=30, lr=0.0005):
        """
        Train the neural network using the stored contexts and rewards.
        This follows the NeuralUCB training logic with regularization.

        Each iteration minimizes
        ``MSE(model(contexts), rewards) + 0.5 * lambda_ * ||theta - theta0||^2``
        with a fresh Adam optimizer.

        Args:
            contexts: Sequence of contexts, convertible with ``np.array``.
            rewards: Matching rewards; reshaped to a column ``(N, 1)``.
            local_training_iter: Number of optimizer steps.
            lr: Adam learning rate.

        Returns:
            The loss of the last iteration (evaluated before its optimizer
            step) as a float, or None if there was no data or no iteration.
        """
        # Nothing to fit: leave the model untouched.
        if len(contexts) == 0:
            return None

        params = list(self.model.parameters())
        device = params[0].device

        # The model may have been moved since construction; keep the reference
        # parameters on the same device (no-op when they already are).
        self.theta0 = [p0.to(p.device) for p, p0 in zip(params, self.theta0)]
        theta0 = self.theta0

        if self.reset_theta_each_train:
            with torch.no_grad():
                for p, p0 in zip(params, theta0):
                    p.copy_(p0)

        contexts_tensor = torch.tensor(np.array(contexts), dtype=torch.float32, device=device)
        # reshape(-1, 1) gives a column for both (N,) and (N, 1) inputs.
        rewards_tensor = torch.tensor(
            np.array(rewards), dtype=torch.float32, device=device
        ).reshape(-1, 1)

        optimizer = optim.Adam(params, lr=lr)
        reg_weight = 0.5 * float(self.lambda_)

        loss = None
        for _ in range(local_training_iter):
            optimizer.zero_grad()
            predictions = self.model(contexts_tensor)
            loss = self.loss_fn(predictions, rewards_tensor)

            # Regularization term: sum((p - p0)^2)
            # NeuralUCB uses this to stay close to initialization (lazy training regime)
            reg = sum(((p - p0) ** 2).sum() for p, p0 in zip(params, theta0))
            loss = loss + reg_weight * reg

            loss.backward()
            optimizer.step()

        return float(loss.item()) if loss is not None else None
