import math
import torch
import torch.nn as nn

class ScoreModel(nn.Module):
    '''
    Gaussian log-density model for the joint vector z = [x, y] given theta.

    The mean of z is zero on the x-coordinates and <mu, theta> on the last
    (y) coordinate; the covariance is the learnable matrix Sigma_z, which is
    symmetrised before use.

    Learnable parameters:
        mu      : Tensor of shape (dx,), coefficients of the y-mean in theta.
        Sigma_z : Tensor of shape (dx + dy, dx + dy), joint covariance of z.

    forward input: x: [batch_size, dx], y: [batch_size, dy], theta: [batch_size, dtheta]

    Because the y-mean is a single scalar per sample (mu * theta summed over
    the last axis), the model requires dy == 1 and dtheta == dx.
    '''

    def __init__(self, dx, dy, mu):
        super().__init__()
        self.dx = dx
        self.dy = dy
        # NOTE: the ``mu`` argument is accepted for interface compatibility
        # but is not used; self.mu is randomly initialised below.
        self.mu = nn.Parameter(0.1 * torch.randn(self.dx))
        self.Sigma_z = nn.Parameter(torch.eye(self.dx + self.dy))

    def forward(self, x, y, theta):
        '''
        Split log N(z; mu_z, Sigma_z) into a data-dependent and a constant part.

        Returns:
            quad:  Tensor of shape (B,),
                   -0.5 * (z - mu_z)^T Sigma_z^{-1} (z - mu_z) per sample.
            const: 0-dim tensor, -0.5 * D * log(2*pi) - 0.5 * log det Sigma_z
                   with D = dx + dy.
            quad + const is the Gaussian log-density of each sample.

        Raises:
            ValueError: if the symmetrised Sigma_z is not positive definite.
        '''
        B, d = x.shape
        z = torch.cat([x, y], dim=1)
        D = z.shape[1]

        # Mean: zeros for the x block, <mu, theta> for the single y coordinate.
        mu_theta = torch.sum(self.mu * theta, dim=-1, keepdim=True)
        mu_z = torch.cat([x.new_zeros(B, d), mu_theta], dim=1)

        Sigma_z = (self.Sigma_z + self.Sigma_z.T) / 2
        # The Cholesky factorisation succeeds iff Sigma_z is positive definite.
        # A positive determinant sign alone is not sufficient
        # (e.g. diag(-1, -1, 1) has det = +1 but is not PD).
        L, info = torch.linalg.cholesky_ex(Sigma_z)
        if int(info) != 0:
            raise ValueError("Sigma_z is not positive definite")

        # log det Sigma_z = 2 * sum(log diag(L)) for Sigma_z = L L^T.
        logdet = 2 * torch.log(torch.diagonal(L)).sum()
        const = -0.5 * D * math.log(2 * math.pi) - 0.5 * logdet

        # Mahalanobis term via a triangular solve rather than an explicit
        # inverse: ||L^{-1}(z - mu_z)||^2 = (z - mu_z)^T Sigma_z^{-1} (z - mu_z).
        diff = z - mu_z
        w = torch.linalg.solve_triangular(L, diff.T, upper=False)
        quad = -0.5 * torch.sum(w * w, dim=0)

        return quad, const


class DNNScoreModel(nn.Module):
    """
    Deep neural network to approximate log q(x,y|theta).

    The network is an MLP over the concatenation [x, y, theta] with a Softplus
    activation after every hidden layer and a scalar output. The output is the
    raw network value; nothing in this class normalises it.

    Inputs:
      - x: Tensor of shape (B, dx)
      - y: Tensor of shape (B, dy)
      - theta: Tensor of shape (B, dtheta)
    Output:
      - logM: Tensor of shape (B,) giving log q for each sample

    Args:
      dx, dy, dtheta: feature sizes of x, y and theta.
      hidden_dims: widths of the hidden layers (any iterable of ints).
        Defaults to an immutable tuple to avoid a shared mutable default.
    """

    def __init__(self, dx, dy, dtheta, hidden_dims=(128, 128)):
        super().__init__()
        input_dim = dx + dy + dtheta
        layers = []
        prev_dim = input_dim
        for h in hidden_dims:
            layers.append(nn.Linear(prev_dim, h))
            layers.append(nn.Softplus())
            prev_dim = h
        # Final projection to one scalar per sample.
        layers.append(nn.Linear(prev_dim, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x, y, theta):
        """Return the network output for each sample as a tensor of shape (B,)."""
        # Concatenate all inputs along the feature dimension.
        z = torch.cat([x, y, theta], dim=1)
        # (B, 1) -> (B,)
        logM = self.net(z).squeeze(-1)
        return logM

