
import torch

class GaussianProposal:
    """
    Multivariate Gaussian proposal with diagonal covariance.

    μ(x) = N(mean, diag(std^2))

    Each coordinate is an independent univariate Normal. Sampling and
    log-density evaluation are delegated to ``torch.distributions.Normal``.
    """

    def __init__(
        self,
        mean,            # Tensor [d] (or anything torch.as_tensor accepts)
        std,             # Tensor [d] or scalar
        device="cpu",
    ):
        """
        Args:
            mean   : Tensor [d] (mean vector). Lists/arrays are accepted and
                     converted with ``torch.as_tensor``. Integer means are
                     promoted to torch's default floating dtype, so that a
                     fractional ``std`` is not truncated.
            std    : Tensor [d] or scalar (diagonal standard deviation). It is
                     cast to the mean's dtype/device. A scalar (Python number
                     or 0-dim tensor) is broadcast to shape [d].
            device : device on which mean and std are stored.

        Note:
            Non-positive std values are expected to be rejected by
            ``torch.distributions.Normal`` when its argument validation is
            enabled.
        """
        mean = torch.as_tensor(mean, device=device)
        if not mean.is_floating_point():
            # Normal requires floating parameters. Also, full_like on an int
            # mean would silently truncate e.g. std=0.5 to 0.
            mean = mean.to(torch.get_default_dtype())
        self.mean = mean

        # Match the mean's dtype so sample() never mixes float32/float64.
        std = torch.as_tensor(std, dtype=self.mean.dtype, device=self.mean.device)
        if std.dim() == 0:
            # Scalar std -> one identical std per dimension.
            std = std.expand_as(self.mean).clone()
        self.std = std

        self.d = self.mean.shape[0]

        self.dist = torch.distributions.Normal(
            self.mean,
            self.std
        )

    def sample(self, N):
        """
        Draw N independent samples.

        Args:
            N : number of samples.

        Returns:
            Tensor of shape [N, d].
        """
        return self.dist.sample((N,))

    def log_prob(self, x):
        """
        Compute log μ(x) by summing per-coordinate Normal log-densities.

        Args:
            x : Tensor [..., d], e.g. [N, d] for a batch or [d] for one point.

        Returns:
            Tensor [...], e.g. [N] for a batch or a 0-dim tensor for one point.
        """
        # Reduce over the last (coordinate) axis, not a fixed dim=1. This
        # keeps [N, d] behaviour and also supports single points [d].
        return self.dist.log_prob(x).sum(dim=-1)
    



