import torch

# Parametrized rings
class RingsDistribution:
    """
    Multimodal radial ("rings") distribution in R^d.

    The radius ||x|| has the unnormalized density

        p(r) ∝ sum_k exp( - (r - r_k)^2 / (2 σ^2) ),   r >= 0,

    and the direction x / ||x|| is uniform on the unit sphere S^{d-1}.
    Changing to Cartesian coordinates divides by the sphere surface
    element (∝ r^{d-1}), so the density on R^d is

        π(x) ∝ ||x||^{-(d-1)} · sum_k exp( - (||x|| - r_k)^2 / (2 σ^2) ).

    Args:
        d: ambient dimension (>= 1).
        radii: ring radii r_k (non-negative; a scalar is treated as one ring).
        sigma: common radial standard deviation σ (> 0).
        device: torch device holding the radii and generated samples.

    Raises:
        ValueError: if d < 1, sigma <= 0, radii is empty, or any radius < 0.
    """

    def __init__(
        self,
        d,
        radii=(2.0, 4.0, 6.0),
        sigma=0.2,
        device="cpu",
    ):
        if d < 1:
            raise ValueError(f"d must be >= 1, got {d}")
        if sigma <= 0:
            raise ValueError(f"sigma must be > 0, got {sigma}")

        radii_t = torch.atleast_1d(torch.as_tensor(radii, device=device))
        if not radii_t.is_floating_point():
            # Integer radii would otherwise produce an int64 tensor.
            radii_t = radii_t.to(torch.get_default_dtype())
        if radii_t.ndim != 1 or radii_t.numel() == 0:
            raise ValueError("radii must be a non-empty 1-D sequence")
        if bool((radii_t < 0).any()):
            # Negative radii make the r >= 0 rejection step in `sample`
            # arbitrarily slow and are meaningless for rings.
            raise ValueError("radii must be non-negative")

        self.d = d
        self.radii = radii_t
        self.sigma = sigma
        self.device = device

    # log-density
    def logpi(self, x):
        """
        Unnormalized log-density log π(x).

        Args:
            x: tensor of shape (..., d); the last axis holds coordinates.

        Returns:
            Tensor of shape (...) with one log-density value per point.
        """
        r = torch.linalg.vector_norm(x, dim=-1, keepdim=True)  # (..., 1)
        # Broadcast against the K radii -> (..., K), combine modes stably.
        log_terms = -0.5 * ((r - self.radii) / self.sigma) ** 2
        logp_radial = torch.logsumexp(log_terms, dim=-1)

        # Jacobian of the polar change of variables (dx = r^{d-1} dr dΩ);
        # the epsilon keeps the log finite at the origin.
        log_jac = -(self.d - 1) * torch.log(r.squeeze(-1) + 1e-8)

        return logp_radial + log_jac

    # gradient via autograd
    def grad_logpi(self, x):
        """
        Compute ∇ log π(x) using autograd.

        Works on a detached copy of ``x``, so the caller's tensor is left
        untouched (its ``requires_grad`` flag is not flipped, and later
        updates of ``x`` do not accumulate an autograd graph). Runs under
        ``torch.enable_grad()``, so it also works inside ``torch.no_grad()``.

        Args:
            x: tensor of shape (..., d).

        Returns:
            Tensor with the same shape as ``x``; it does not require grad.
        """
        with torch.enable_grad():
            x_leaf = x.detach().requires_grad_(True)
            logp = self.logpi(x_leaf).sum()
            (grad,) = torch.autograd.grad(logp, x_leaf)
        return grad

    # sampling
    def sample(self, N):
        """
        Exact sampling of N points from π.

        1. Sample a ring index k uniformly and a radius r ~ N(r_k, σ^2).
        2. Reject pairs with r < 0 and redraw BOTH k and r. This makes the
           radius follow the r >= 0 truncated mixture that `logpi` encodes,
           including the correct per-ring weights. Without the rejection,
           x = r·z with r < 0 would fold the Gaussian onto r > 0.
        3. Sample a direction uniformly on S^{d-1} by normalizing a
           standard Gaussian vector.

        Args:
            N: number of samples.

        Returns:
            Tensor of shape (N, d).
        """
        n_rings = self.radii.numel()
        dtype = self.radii.dtype

        # choose ring and sample radius
        idx = torch.randint(0, n_rings, (N,), device=self.device)
        r = self.radii[idx] + self.sigma * torch.randn(
            N, device=self.device, dtype=dtype
        )

        # reject negative radii (redraw ring and radius jointly)
        neg = r < 0
        while bool(neg.any()):
            m = int(neg.sum())
            idx_new = torch.randint(0, n_rings, (m,), device=self.device)
            r[neg] = self.radii[idx_new] + self.sigma * torch.randn(
                m, device=self.device, dtype=dtype
            )
            neg = r < 0

        # sample direction uniformly on sphere
        z = torch.randn(N, self.d, device=self.device, dtype=dtype)
        z = z / torch.linalg.vector_norm(z, dim=-1, keepdim=True)

        return r[:, None] * z