import torch

import torch.nn as nn

class OFUL:
    """Ridge-regression linear bandit estimator with an optimism bonus.

    The instance keeps the regularised design matrix
    ``A = lambd * I + sum(x x^T)`` and the response vector
    ``b = sum(reward * x)``. The parameter estimate is ``theta = A^{-1} b``
    and an action is scored as ``theta . x + alpha * sqrt(x^T A^{-1} x)``.

    Args:
        d: Dimension of the feature vectors.
        alpha: Multiplier applied to the exploration bonus.
        delta: Stored on the instance; no method of this class reads it.
        lambd: Ridge regularisation strength; ``A`` starts as ``lambd * I``.
    """

    def __init__(self, d, alpha, delta, lambd):
        self.d = d
        self.alpha = alpha
        self.delta = delta
        self.lambd = lambd
        self.A = lambd * torch.eye(d)
        self.b = torch.zeros(d)

    def update(self, M_x, reward):
        """Fold one observation, or a batch of observations, into ``A`` and ``b``.

        Args:
            M_x: Either a single feature vector of shape ``(d,)`` or a batch
                of shape ``(n, d)`` with one observation per row.
            reward: The matching reward(s): a Python number, a 0-d tensor, or
                a tensor holding ``n`` values.

        Raises:
            ValueError: If the feature width is not ``d`` or the number of
                rewards does not match the number of observations.
        """
        feats = torch.as_tensor(M_x, dtype=self.A.dtype, device=self.A.device)
        if feats.dim() < 2:
            # A lone feature vector is a batch of one row. Without this,
            # ``x.T @ x`` would be a scalar dot product that broadcasts over
            # every entry of A instead of adding the outer product x x^T.
            feats = feats.reshape(1, -1)
        if feats.dim() != 2 or feats.shape[1] != self.d:
            raise ValueError(
                f"expected features of shape ({self.d},) or (n, {self.d}), "
                f"got {tuple(feats.shape)}"
            )

        # Flatten so scalars, 0-d tensors and (n, 1) columns are all accepted.
        rewards = torch.as_tensor(
            reward, dtype=self.b.dtype, device=self.b.device
        ).reshape(-1)
        if rewards.numel() != feats.shape[0]:
            raise ValueError(
                f"got {rewards.numel()} reward(s) for "
                f"{feats.shape[0]} observation(s)"
            )

        self.A += feats.T @ feats
        self.b += feats.T @ rewards

    def compute_theta(self):
        """Return the ridge estimate ``theta`` solving ``A @ theta = b``."""
        # Solving the linear system avoids forming an explicit inverse.
        return torch.linalg.solve(self.A, self.b)

    def select_action(self, M_x):
        """Return the optimistic reward estimate for one feature vector.

        Args:
            M_x: Feature vector with ``d`` elements; it is flattened, so any
                shape with ``d`` elements is accepted.

        Returns:
            A Python float: ``theta . x + alpha * sqrt(x^T A^{-1} x)``.
        """
        x = torch.as_tensor(
            M_x, dtype=self.A.dtype, device=self.A.device
        ).reshape(-1)
        theta = self.compute_theta()
        # One solve gives A^{-1} x; no need to invert A a second time.
        a_inv_x = torch.linalg.solve(self.A, x)
        # Clamp guards against a tiny negative value from round-off, which
        # would otherwise turn the square root into NaN.
        bonus = self.alpha * torch.sqrt(torch.clamp(x @ a_inv_x, min=0.0))
        return (theta @ x).item() + bonus.item()
    
class DummyModel(nn.Module):
    """Minimal module that embeds a pair of inputs with one linear layer.

    Args:
        input_dim: Number of input features seen by the linear layer, i.e.
            the size of the last dimension after the two inputs are joined.
        embedding_dim: Number of output features of the linear layer.
    """

    def __init__(self, input_dim, embedding_dim):
        super().__init__()
        self.linear = nn.Linear(
            in_features=input_dim, out_features=embedding_dim
        )

    def map_to_sphere(self, A, B):
        """Concatenate ``A`` and ``B`` along dim 1 and apply the linear layer.

        Despite the method name, the result is returned as produced by the
        linear layer; no normalisation is applied here.

        Args:
            A: First input tensor.
            B: Second input tensor; must be concatenable with ``A`` on dim 1.

        Returns:
            The output of ``self.linear`` on the concatenated tensor.
        """
        joined = torch.cat((A, B), dim=1)
        return self.linear(joined)
