import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class NeuralUCBModel(nn.Module):
    """Fully connected ReLU network with a single scalar output.

    The network is split into two parts: ``embedding_layer`` (a stack of
    ``Linear -> ReLU`` pairs, every one of width ``hidden_size``) and
    ``output_layer`` (a ``Linear`` mapping ``hidden_size`` to 1).

    Args:
        input_dim: Number of input features fed to the first linear layer.
        hidden_size: Width shared by every hidden layer.
        depth: Number of ``Linear -> ReLU`` pairs. At least one pair is
            always built, so values below 1 behave like ``depth=1``.
    """

    def __init__(self, input_dim, hidden_size=32, depth=2):
        super().__init__()
        # Input width of each hidden Linear: the first sees the raw features,
        # every following one sees the previous hidden activation.
        fan_ins = [input_dim] + [hidden_size] * max(depth - 1, 0)
        stack = []
        for fan_in in fan_ins:
            stack.extend((nn.Linear(fan_in, hidden_size), nn.ReLU()))
        self.embedding_layer = nn.Sequential(*stack)
        self.output_layer = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return the scalar prediction(s) for ``x`` (last dim of size 1)."""
        return self.output_layer(self.embedding_layer(x))

class NeuralUCB:
    """Upper-confidence-bound scorer built on a small neural reward model.

    The score of a context is ``prediction + beta * bonus`` where the bonus
    is ``sqrt(g^T Z^{-1} g)``, ``g`` being the gradient of the prediction
    with respect to all network parameters. ``Z`` starts at ``lambda_ * I``
    and accumulates ``g g^T`` on every :meth:`update`. Either the full
    matrix or only its diagonal is tracked, depending on ``use_diag_z``.

    Args:
        input_dim: Number of features in a context vector.
        hidden_dims: Non-empty sequence describing the network. Only its
            length (number of hidden layers) and its last entry (width used
            for every hidden layer) are read.
        lambda_: Initial diagonal value of ``Z`` and weight of the
            ``||theta - theta0||^2`` penalty in :meth:`train`. It is used as
            a divisor when initialising the inverse, so it should be
            non-zero.
        beta: Multiplier applied to the exploration bonus.
        use_diag_z: Track only the diagonal of ``Z`` (O(m) memory) instead
            of the full ``m x m`` matrix and its inverse.
        reset_theta_each_train: When true, :meth:`train` restores the
            initial weights before optimising.
    """

    def __init__(self, input_dim, hidden_dims=(32, 32), lambda_=1.0, beta=1.0,
                 use_diag_z=True, reset_theta_each_train=False):
        self.input_dim = input_dim
        self.model_hidden_size = hidden_dims[-1]
        self.model = NeuralUCBModel(input_dim, self.model_hidden_size, depth=len(hidden_dims))
        self.loss_fn = nn.MSELoss()

        self.lambda_ = lambda_
        self.beta = beta
        self.use_diag_z = use_diag_z
        self.reset_theta_each_train = reset_theta_each_train

        # Length of the flattened parameter-gradient vector.
        self.m = self._get_gradient_dim()

        if self.use_diag_z:
            self.Z_diag = torch.ones(self.m, dtype=torch.float32) * self.lambda_
            self.Z_inv_diag = 1.0 / self.Z_diag
        else:
            self.Z = torch.eye(self.m, dtype=torch.float32) * self.lambda_
            self.Z_inv = torch.eye(self.m, dtype=torch.float32) / float(self.lambda_)

        # Snapshot of the freshly initialised weights: the anchor of the
        # regulariser in train() and the restore point when
        # reset_theta_each_train is set.
        self.theta0 = [p.clone().detach() for p in self.model.parameters()]

    def _get_gradient_dim(self):
        """Return the total number of scalar parameters in the model."""
        # Counting elements directly avoids a throw-away forward/backward
        # pass on random data, which would advance the global RNG.
        return sum(p.numel() for p in self.model.parameters())

    def _flat_grad(self, x_tensor):
        """Return ``(prediction, g)`` for ``x_tensor``.

        ``g`` is the detached, flattened gradient of the summed prediction
        with respect to every model parameter, in ``parameters()`` order.
        """
        self.model.zero_grad()
        pred = self.model(x_tensor)
        pred.backward(torch.ones_like(pred))
        # A parameter without a gradient contributes zeros so that the
        # vector length always equals self.m and lines up with Z.
        pieces = [
            (p.grad if p.grad is not None else torch.zeros_like(p)).reshape(-1)
            for p in self.model.parameters()
        ]
        return pred, torch.cat(pieces).detach()

    def calc_ucb(self, x_tensor):
        """Score a single context.

        Args:
            x_tensor: Float tensor for one context. The model must produce
                exactly one output element for it, because the prediction
                is read with ``.item()``.

        Returns:
            Tuple ``(ucb, prediction, bonus)`` of Python floats where
            ``ucb = prediction + beta * bonus``.
        """
        pred, g = self._flat_grad(x_tensor)

        if self.use_diag_z:
            bonus = torch.sqrt(torch.sum((g ** 2) * self.Z_inv_diag))
        else:
            g_col = g.view(-1, 1)
            val = (g_col.t() @ self.Z_inv @ g_col).squeeze()
            # Round-off in Z_inv can push the quadratic form slightly below
            # zero; clamp before taking the square root.
            bonus = torch.sqrt(torch.clamp(val, min=0.0))

        mean = float(pred.item())
        width = float(bonus.item())
        return mean + float(self.beta) * width, mean, width

    def update(self, context_embedding, reward):
        """Fold the gradient at ``context_embedding`` into ``Z``.

        Args:
            context_embedding: One context vector (sequence, NumPy array or
                tensor); a leading batch dimension of 1 is added here.
            reward: Accepted for interface compatibility; not used by the
                covariance update.
        """
        if isinstance(context_embedding, torch.Tensor):
            # Avoid torch.tensor(tensor), which warns and copies needlessly.
            x_tensor = context_embedding.detach().to(torch.float32)
        else:
            x_tensor = torch.tensor(context_embedding, dtype=torch.float32)
        _, g = self._flat_grad(x_tensor.unsqueeze(0))

        if self.use_diag_z:
            self.Z_diag += (g ** 2)
            eps = 1e-12
            self.Z_inv_diag = 1.0 / (self.Z_diag + eps)
        else:
            # Sherman-Morrison rank-one update of the inverse:
            # (Z + v v^T)^-1 = Z^-1 - (Z^-1 v)(Z^-1 v)^T / (1 + v^T Z^-1 v)
            v = g.view(-1, 1)
            Zv = self.Z_inv @ v
            denom = (1.0 + (v.t() @ Zv)).item()
            if denom <= 1e-12:
                denom = 1e-12
            self.Z += (v @ v.t())
            self.Z_inv = self.Z_inv - (Zv @ Zv.t()) / denom

    def train(self, contexts, rewards, local_training_iter=30, lr=0.0005):
        """Fit the model to ``(contexts, rewards)`` with Adam.

        The objective is the mean squared error plus
        ``0.5 * lambda_ * ||theta - theta0||^2``. A fresh optimiser is
        created on every call.

        Args:
            contexts: Sequence of context vectors.
            rewards: One reward per context; flat or column-shaped input is
                accepted and normalised to shape ``(n, 1)``.
            local_training_iter: Number of full-batch gradient steps.
            lr: Adam learning rate.

        Returns:
            The loss evaluated before the final optimiser step as a float,
            or ``None`` when no step was taken (no data or zero iterations).
        """
        if self.reset_theta_each_train:
            with torch.no_grad():
                for p, p0 in zip(self.model.parameters(), self.theta0):
                    p.copy_(p0)

        contexts_tensor = torch.tensor(np.array(contexts), dtype=torch.float32)
        if contexts_tensor.numel() == 0:
            # Nothing to fit; mirror the "no iterations" return value.
            return None
        # reshape (not unsqueeze) so rewards already given as a column do not
        # become (n, 1, 1) and silently broadcast against the predictions.
        rewards_tensor = torch.tensor(np.array(rewards), dtype=torch.float32).reshape(-1, 1)

        optimizer = optim.Adam(self.model.parameters(), lr=lr)

        loss = None
        for _ in range(local_training_iter):
            optimizer.zero_grad()
            predictions = self.model(contexts_tensor)
            loss = self.loss_fn(predictions, rewards_tensor)
            reg = sum(((p - p0) ** 2).sum() for p, p0 in zip(self.model.parameters(), self.theta0))
            loss = loss + 0.5 * float(self.lambda_) * reg
            loss.backward()
            optimizer.step()

        return float(loss.item()) if loss is not None else None
