
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple, Sequence

import math
import torch
import torch.nn.functional as F

try:
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader
except ImportError:  # pragma: no cover - torchvision is optional
    datasets = None
    transforms = None
    DataLoader = None

Tensor = torch.Tensor


# -----------------------------------------------------------------------------
# Utilities
# -----------------------------------------------------------------------------

def _as_device_dtype(device: Optional[torch.device | str], dtype: Optional[torch.dtype]):
    if device is None:
        device = torch.device("cpu")
    else:
        device = torch.device(device)
    if dtype is None:
        dtype = torch.get_default_dtype()
    return device, dtype


def _randn(n: int, d: int, device, dtype):
    return torch.randn(n, d, device=device, dtype=dtype)


def _uniform(n: int, low: float, high: float, device, dtype):
    return low + (high - low) * torch.rand(n, 1, device=device, dtype=dtype)


def _rotation_matrix(theta: float, device, dtype) -> Tensor:
    c = math.cos(theta)
    s = math.sin(theta)
    R = torch.tensor([[c, -s], [s, c]], device=device, dtype=dtype)
    return R


def _log_normal_1d(x: Tensor, mean: Tensor, std: Tensor) -> Tensor:
    var = std ** 2
    return -0.5 * ((x - mean) ** 2) / var - torch.log(std) - 0.5 * math.log(2 * math.pi)


def _log_normal_2d(x: Tensor, mean: Tensor, cov: Tensor, prec: Tensor, logdet: Tensor) -> Tensor:
    # x: (N,2), mean: (2,), cov: (2,2) [unused], prec: (2,2), logdet: scalar tensor
    diff = x - mean
    quad = torch.einsum('ni,ij,nj->n', diff, prec, diff)
    return -0.5 * (2 * math.log(2 * math.pi) + logdet + quad)


def _log_cauchy_1d(x: Tensor, loc: Tensor, scale: Tensor) -> Tensor:
    return - (math.log(math.pi) + torch.log(scale)) - torch.log1p(((x - loc) / scale) ** 2)


def _logsumexp(a: Tensor, dim: int = -1) -> Tensor:
    return torch.logsumexp(a, dim=dim)


# -----------------------------------------------------------------------------
# Base class
# -----------------------------------------------------------------------------

class BaseDistribution2D:
    """Interface for 2D distributions.

    Methods to override:
      - sample(self, n: int, device=None, dtype=None) -> Tensor[(n,2)]
      - log_prob(self, x: Tensor[(n,2)]) -> Tensor[(n,)]  (if available)
    """

    has_log_prob: bool = False

    def sample(self, n: int, device: Optional[torch.device | str] = None, dtype: Optional[torch.dtype] = None) -> Tensor:
        raise NotImplementedError

    def log_prob(self, x: Tensor) -> Tensor:
        raise NotImplementedError("Analytic log-density not available for this distribution.")


# -----------------------------------------------------------------------------
# 1) Thin Annulus
# -----------------------------------------------------------------------------

@dataclass
class ThinAnnulus(BaseDistribution2D):
    """Noisy circle: radius ~ N(R, eps**2), angle ~ Uniform[0, 2*pi).

    ``log_prob`` treats the radial draw as a positive radius; negative draws
    (possible in ``sample``) are ignored by the density, which is negligible
    when ``eps`` is much smaller than ``R``.
    """

    R: float = 3.0
    eps: float = 0.15

    has_log_prob: bool = True

    def sample(self, n: int, device=None, dtype=None) -> Tensor:
        device, dtype = _as_device_dtype(device, dtype)
        opts = dict(device=device, dtype=dtype)
        # Radial draw first, angular draw second (same RNG stream order as before).
        radius = torch.tensor(self.R, **opts) + torch.tensor(self.eps, **opts) * torch.randn(n, **opts)
        angle = 2 * math.pi * torch.rand(n, **opts)
        return torch.stack((radius * torch.cos(angle), radius * torch.sin(angle)), dim=-1)

    def log_prob(self, x: Tensor) -> Tensor:
        radius = torch.linalg.norm(x, dim=-1)
        # Radial Gaussian density N(r; R, eps^2).
        log_radial = (-0.5 * ((radius - self.R) ** 2) / (self.eps ** 2)
                      - math.log(self.eps) - 0.5 * math.log(2 * math.pi))
        # Polar change of variables: uniform angle contributes 1 / (2*pi*r);
        # r is clamped to avoid log(0) at the origin.
        log_polar = -math.log(2 * math.pi) - torch.log(radius.clamp_min(1e-12))
        return log_radial + log_polar


# -----------------------------------------------------------------------------
# 2) Concentric Rings Mixture
# -----------------------------------------------------------------------------

@dataclass
class ConcentricRingsMixture(BaseDistribution2D):
    """Mixture of thin annuli sharing one radial noise level.

    Component ``k`` has radius ~ N(radii[k], eps**2) and a uniform angle.
    ``weights`` (uniform when omitted) are normalised to sum to one in
    ``__post_init__``.

    Raises:
        ValueError: if ``radii`` is empty, ``weights`` has the wrong length,
            or ``weights`` does not sum to a positive value.
    """

    radii: Tuple[float, ...] = (1.0, 2.0, 3.0)
    eps: float = 0.07
    weights: Optional[Tuple[float, ...]] = None

    has_log_prob: bool = True

    def __post_init__(self):
        K = len(self.radii)
        if K == 0:
            raise ValueError("radii must be non-empty")
        if self.weights is None:
            self.weights = tuple([1.0 / K] * K)
        # Explicit checks (not ``assert``) so validation survives ``python -O``.
        # A length mismatch would otherwise let ``sample`` draw component
        # indices with no radius, leaving ``torch.empty`` entries unfilled.
        if len(self.weights) != K:
            raise ValueError(f"weights must have length {K}")
        s = sum(self.weights)
        if s <= 0:
            raise ValueError("weights must sum to a positive value")
        self.weights = tuple(w / s for w in self.weights)
        # Lazily built log-weight cache; rebuilt when device/dtype change.
        self._logw = None

    def sample(self, n: int, device=None, dtype=None) -> Tensor:
        device, dtype = _as_device_dtype(device, dtype)
        K = len(self.radii)
        cat = torch.distributions.Categorical(logits=torch.tensor(self.weights, device=device, dtype=dtype).log())
        idx = cat.sample((n,))  # (n,)
        r = torch.empty(n, device=device, dtype=dtype)
        # Fill radii component by component (keeps the original RNG stream order).
        for k in range(K):
            mask = (idx == k)
            nk = int(mask.sum().item())
            if nk == 0:
                continue
            r[mask] = torch.tensor(self.radii[k], device=device, dtype=dtype) + \
                      torch.tensor(self.eps, device=device, dtype=dtype) * torch.randn(nk, device=device, dtype=dtype)
        theta = 2 * math.pi * torch.rand(n, device=device, dtype=dtype)
        x = torch.stack([r * torch.cos(theta), r * torch.sin(theta)], dim=-1)
        return x

    def log_prob(self, x: Tensor) -> Tensor:
        r = torch.linalg.norm(x, dim=-1)
        # Rebuild the cache if x lives on another device or uses another dtype;
        # a stale cache would otherwise mix precisions or fail on device mismatch.
        if self._logw is None or self._logw.device != x.device or self._logw.dtype != x.dtype:
            self._logw = torch.log(torch.tensor(self.weights, device=x.device, dtype=x.dtype))
        radii = x.new_tensor(self.radii)  # (K,)
        # Component log-densities: same form as ThinAnnulus, one column per radius.
        log_p_r = (-0.5 * ((r[..., None] - radii) ** 2) / (self.eps ** 2)
                   - math.log(self.eps) - 0.5 * math.log(2 * math.pi))
        log_jac = -math.log(2 * math.pi) - torch.log(r.clamp_min(1e-12))
        L = log_p_r + log_jac[..., None]  # (..., K)
        return _logsumexp(self._logw + L, dim=-1)


# -----------------------------------------------------------------------------
# 3) Two Moons (sample only)
# -----------------------------------------------------------------------------

@dataclass
class TwoMoons(BaseDistribution2D):
    """Two interleaving half-circles with optional isotropic Gaussian noise (sampling only)."""

    noise: float = 0.02

    has_log_prob: bool = False

    def sample(self, n: int, device=None, dtype=None) -> Tensor:
        device, dtype = _as_device_dtype(device, dtype)
        n_upper = n // 2
        n_lower = n - n_upper
        # Upper moon: unit half-circle around the origin, angle in [0, pi).
        t_upper = math.pi * torch.rand(n_upper, device=device, dtype=dtype)
        upper = torch.stack([torch.cos(t_upper), torch.sin(t_upper)], dim=-1)
        # Lower moon: the arc mirrored and shifted so the two interleave.
        t_lower = math.pi * torch.rand(n_lower, device=device, dtype=dtype)
        lower = torch.stack([1.0 - torch.cos(t_lower), 1.0 - torch.sin(t_lower) - 0.5], dim=-1)
        points = torch.cat([upper, lower], dim=0)
        if not self.noise > 0:
            return points
        return points + self.noise * torch.randn_like(points)


# -----------------------------------------------------------------------------
# 4) Spiral (sample only)
# -----------------------------------------------------------------------------

@dataclass
class Spiral2D(BaseDistribution2D):
    """Archimedean-style spiral from radius ``r0`` to ``r1`` over ``turns`` revolutions (sampling only)."""

    r0: float = 0.5
    r1: float = 3.0
    turns: float = 3.0
    noise: float = 0.03

    has_log_prob: bool = False

    def sample(self, n: int, device=None, dtype=None) -> Tensor:
        device, dtype = _as_device_dtype(device, dtype)
        t = torch.rand(n, device=device, dtype=dtype)
        # Radius grows linearly in t while the angle sweeps `turns` full circles.
        radius = self.r0 + (self.r1 - self.r0) * t
        angle = 2 * math.pi * self.turns * t
        pts = torch.stack([radius * torch.cos(angle), radius * torch.sin(angle)], dim=-1)
        if not self.noise > 0:
            return pts
        return pts + self.noise * torch.randn_like(pts)


# -----------------------------------------------------------------------------
# 5) Rotated Gaussian
# -----------------------------------------------------------------------------

@dataclass
class RotatedGaussian(BaseDistribution2D):
    """Anisotropic 2D Gaussian whose principal axes are rotated by ``theta_deg``.

    Covariance is ``R diag(var_large, var_small) R^T`` where ``R`` rotates by
    ``theta_deg`` degrees.
    """

    var_small: float = 1e-4
    var_large: float = 1.0
    theta_deg: float = 30.0
    mean: Tuple[float, float] = (0.0, 0.0)

    has_log_prob: bool = True

    def _params(self, device, dtype):
        """Return ``(mean, cov, prec, logdet)`` materialised on ``device``/``dtype``."""
        rot = _rotation_matrix(math.radians(self.theta_deg), device, dtype)
        eigvals = torch.tensor([self.var_large, self.var_small], device=device, dtype=dtype)
        cov = rot @ torch.diag(eigvals) @ rot.T
        return (
            torch.tensor(self.mean, device=device, dtype=dtype),
            cov,
            torch.linalg.inv(cov),
            torch.logdet(cov),
        )

    def sample(self, n: int, device=None, dtype=None) -> Tensor:
        device, dtype = _as_device_dtype(device, dtype)
        mean, cov, _prec, _logdet = self._params(device, dtype)
        eps = _randn(n, 2, device, dtype)
        chol = torch.linalg.cholesky(cov)  # cov = chol @ chol.T
        return mean + eps @ chol.T

    def log_prob(self, x: Tensor) -> Tensor:
        mean, cov, prec, logdet = self._params(x.device, x.dtype)
        return _log_normal_2d(x, mean, cov, prec, logdet)


# -----------------------------------------------------------------------------
# 6) Rosenbrock (unnormalized)
# -----------------------------------------------------------------------------

@dataclass
class Rosenbrock2D(BaseDistribution2D):
    """Banana-shaped target with potential ``U(x, y) = (x^2 + b (y - x^2)^2) / 20``.

    Only an unnormalised log-density (``log_prob_unnorm = -U``) is exposed.

    NOTE(review): the field ``a`` is not read by ``potential``; the classic
    Rosenbrock term would be ``(a - x)^2`` rather than ``x^2`` -- confirm
    which is intended before relying on ``a``.
    """

    a: float = 1.0
    b: float = 10.0  # controls curvature (classic Rosenbrock uses b=100; we scale in exponent)

    has_log_prob: bool = False  # normalization constant unknown

    def sample(self, n: int, device=None, dtype=None) -> Tensor:
        """Approximate samples: Gaussian init refined by a few noisy gradient steps.

        This is not an exact sampler. The injected noise is scaled by
        ``noise`` (1e-2) on top of the Langevin ``sqrt(2 * lr)`` factor, i.e.
        an effective temperature of ``noise**2``, so points are mostly pushed
        toward the valley of ``potential``.
        """
        device, dtype = _as_device_dtype(device, dtype)
        x = torch.randn(n, 2, device=device, dtype=dtype)
        steps, lr, noise = 40, 1e-2, 1e-2
        # Gradients of the potential are needed even when the caller is inside
        # ``torch.no_grad()`` (e.g. an evaluation loop); without this context
        # ``U`` would not require grad and ``torch.autograd.grad`` would raise.
        with torch.enable_grad():
            for _ in range(steps):
                x_req = x.detach().requires_grad_(True)
                (grad,) = torch.autograd.grad(self.potential(x_req).sum(), x_req, create_graph=False)
                x = x_req.detach() - lr * grad + math.sqrt(2 * lr) * noise * torch.randn_like(x_req)
        return x.detach()

    def potential(self, x: Tensor) -> Tensor:
        # U(x,y) = (x^2 + b*(y - x^2)^2) / 20  (scaled to avoid extreme values)
        x1, x2 = x[..., 0], x[..., 1]
        return (x1**2 + self.b * (x2 - x1**2) ** 2) / 20.0

    def log_prob_unnorm(self, x: Tensor) -> Tensor:
        """Unnormalised log-density ``-potential(x)``."""
        return -self.potential(x)


# -----------------------------------------------------------------------------
# 7) Neal's Funnel (2D version)
# -----------------------------------------------------------------------------

@dataclass
class NealFunnel2D(BaseDistribution2D):
    """Neal's funnel in 2D: x1 ~ N(0, sigma1^2), x2 | x1 ~ N(0, exp(alpha * x1))."""

    sigma1: float = 3.0  # std of x1
    alpha: float = 1.0   # variance of x2 is exp(alpha * x1)

    has_log_prob: bool = True

    def sample(self, n: int, device='cpu', dtype=torch.float32) -> Tensor:
        device, dtype = _as_device_dtype(device, dtype)
        neck = self.sigma1 * torch.randn(n, 1, device=device, dtype=dtype)
        # Conditional std of x2 is exp(alpha * x1 / 2).
        width = torch.exp(0.5 * self.alpha * neck)
        spread = width * torch.randn(n, 1, device=device, dtype=dtype)
        return torch.cat((neck, spread), dim=-1)

    def log_prob(self, x: Tensor) -> Tensor:
        x1 = x[..., 0]
        x2 = x[..., 1]
        log_p_x1 = _log_normal_1d(x1, x1.new_tensor(0.0), x1.new_tensor(self.sigma1))
        # log N(x2; 0, exp(alpha * x1)); the log-variance is simply alpha * x1.
        log_var = self.alpha * x1
        log_p_x2 = -0.5 * (x2**2) / torch.exp(log_var) - 0.5 * (math.log(2 * math.pi) + log_var)
        return log_p_x1 + log_p_x2


# -----------------------------------------------------------------------------
# 8) Independent Cauchy (product of 1D Cauchy); the ZScoreWrapper helper comes first
# -----------------------------------------------------------------------------


class ZScoreWrapper(BaseDistribution2D):
    """Present a base distribution in standardised coordinates ``z = (x - mean) / std``.

    ``mean`` and ``std`` are moved to the device/dtype of each input. Attribute
    lookups that fail on the wrapper are forwarded to ``base``.
    """

    def __init__(self, base: BaseDistribution2D, mean: torch.Tensor, std: torch.Tensor) -> None:
        self.base = base
        self.mean = mean
        self.std = std
        self.has_log_prob = getattr(base, "has_log_prob", False)

    def _stats_like(self, ref: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mean, std)`` cast to ``ref``'s device and dtype."""
        return self.mean.to(ref.device, ref.dtype), self.std.to(ref.device, ref.dtype)

    def sample(self, n: int, device=None, dtype=None) -> torch.Tensor:
        """Sample from ``base`` and standardise.

        ``device``/``dtype`` may be given positionally or by keyword
        (previously keyword-only), matching ``BaseDistribution2D.sample``.
        """
        raw = self.base.sample(n, device=device, dtype=dtype)
        mean, std = self._stats_like(raw)
        return (raw - mean) / std

    def log_prob(self, x: torch.Tensor) -> torch.Tensor:
        """Log-density of the standardised variable at ``x`` (z-space).

        With ``raw = z * std + mean`` the Jacobian ``d raw / d z`` is
        ``diag(std)``, so ``log p_z(z) = log p_raw(raw) + sum(log|std|)``.
        """
        if not hasattr(self.base, "log_prob"):
            raise AttributeError("Wrapped sampler does not implement log_prob")
        mean, std = self._stats_like(x)
        raw = x * std + mean
        log_det = torch.log(std.abs()).sum()
        # Change of variables adds (not subtracts) the log-Jacobian of z -> raw.
        return self.base.log_prob(raw) + log_det

    def to_raw(self, x: torch.Tensor) -> torch.Tensor:
        """Map standardised points back to the base distribution's coordinates."""
        mean, std = self._stats_like(x)
        return x * std + mean

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. Read ``base`` via ``__dict__``
        # so a partially constructed instance (e.g. during copy/unpickling,
        # before ``__init__`` state exists) raises AttributeError instead of
        # recursing forever through ``self.base``.
        try:
            base = self.__dict__["base"]
        except KeyError:
            raise AttributeError(attr) from None
        return getattr(base, attr)


@dataclass
class IndependentCauchy2D(BaseDistribution2D):
    """Product of two independent 1D Cauchy distributions (heavy-tailed)."""

    loc: Tuple[float, float] = (0.0, 0.0)
    scale: Tuple[float, float] = (1.0, 1.0)

    has_log_prob: bool = True

    def sample(self, n: int, device=None, dtype=None) -> Tensor:
        device, dtype = _as_device_dtype(device, dtype)
        u = torch.rand(n, 2, device=device, dtype=dtype)
        # Inverse CDF of the standard Cauchy: tan(pi * (u - 1/2)).
        standard = torch.tan(math.pi * (u - 0.5))
        center = torch.tensor(self.loc, device=device, dtype=dtype)
        spread = torch.tensor(self.scale, device=device, dtype=dtype)
        return center + spread * standard

    def log_prob(self, x: Tensor) -> Tensor:
        center = x.new_tensor(self.loc)
        spread = x.new_tensor(self.scale)
        lp_first = _log_cauchy_1d(x[..., 0], center[0], spread[0])
        lp_second = _log_cauchy_1d(x[..., 1], center[1], spread[1])
        return lp_first + lp_second


def _as_device_dtype(device=None, dtype=None):
    device = torch.device(device) if device is not None else torch.device("cpu")
    dtype = dtype if dtype is not None else torch.float32
    return device, dtype

# -----------------------------------------------------------------------------
# 9) Gaussian + Cauchy mixture
# -----------------------------------------------------------------------------

@dataclass
class GaussianCauchyMixture2D(BaseDistribution2D):
    """Two-component mixture of an isotropic Gaussian and an independent Cauchy.

    With probability ``weight_cauchy`` a point comes from a product of two 1D
    Cauchy distributions (location ``cauchy_loc``, common scale
    ``cauchy_scale``); otherwise from ``N(gauss_mean, gauss_var * I)``.

    Raises:
        ValueError: if ``weight_cauchy`` is outside ``[0, 1]`` or
            ``gauss_var`` / ``cauchy_scale`` is not positive.
    """

    weight_cauchy: float = 0.2
    # Gaussian component
    gauss_mean: Tuple[float, float] = (0.0, 0.0)
    gauss_var: float = 1.0
    # Cauchy component (independent 1D)
    cauchy_loc: Tuple[float, float] = (0.0, 0.0)
    cauchy_scale: float = 1.0

    has_log_prob: bool = True

    def __post_init__(self):
        if not 0.0 <= self.weight_cauchy <= 1.0:
            raise ValueError("weight_cauchy must lie in [0, 1]")
        if not self.gauss_var > 0:
            raise ValueError("gauss_var must be positive")
        if not self.cauchy_scale > 0:
            raise ValueError("cauchy_scale must be positive")

    def sample(self, n: int, device=None, dtype=None) -> Tensor:
        device, dtype = _as_device_dtype(device, dtype)
        opts = dict(device=device, dtype=dtype)
        # Component indicator: True -> Cauchy, False -> Gaussian.
        use_cauchy = torch.rand(n, **opts) < self.weight_cauchy
        gauss = torch.tensor(self.gauss_mean, **opts) + math.sqrt(self.gauss_var) * _randn(n, 2, device, dtype)
        # Inverse-CDF sampling for each independent Cauchy coordinate.
        u = torch.rand(n, 2, **opts)
        cauchy = torch.tensor(self.cauchy_loc, **opts) + self.cauchy_scale * torch.tan(math.pi * (u - 0.5))
        return torch.where(use_cauchy[:, None], cauchy, gauss)

    def log_prob(self, x: Tensor) -> Tensor:
        # Gaussian component: isotropic 2D normal with variance gauss_var.
        mean = x.new_tensor(self.gauss_mean)
        sq_dist = ((x - mean) ** 2).sum(dim=-1)
        log_gauss = -0.5 * (sq_dist / self.gauss_var + 2 * math.log(2 * math.pi * self.gauss_var))
        # Cauchy component: product of two 1D Cauchy densities.
        loc = x.new_tensor(self.cauchy_loc)
        scale = x.new_tensor(self.cauchy_scale)
        log_cauchy = _log_cauchy_1d(x[..., 0], loc[0], scale) + _log_cauchy_1d(x[..., 1], loc[1], scale)
        # torch.log maps a zero weight to -inf, which logsumexp handles
        # (unlike math.log, which would raise).
        log_w = torch.log(x.new_tensor([1.0 - self.weight_cauchy, self.weight_cauchy]))
        terms = torch.stack([log_w[0] + log_gauss, log_w[1] + log_cauchy], dim=-1)
        return _logsumexp(terms, dim=-1)



# -----------------------------------------------------------------------------
# 10) Near-manifold sine curve with small normal noise
# -----------------------------------------------------------------------------

@dataclass
class NearManifoldSine(BaseDistribution2D):
    """Points near the curve ``y = sin(freq * x)``, with ``x`` uniform on ``[x_low, x_high]``.

    The vertical offset from the curve is ``N(0, sigma_perp**2)``; the joint
    density is ``-inf`` outside the x-domain.
    """

    freq: float = 5.0
    x_low: float = -math.pi
    x_high: float = math.pi
    sigma_perp: float = 0.02

    has_log_prob: bool = True  # analytic joint density p(x,y) on domain

    def sample(self, n: int, device=None, dtype=None) -> Tensor:
        device, dtype = _as_device_dtype(device, dtype)
        t = _uniform(n, self.x_low, self.x_high, device, dtype)
        jitter = self.sigma_perp * torch.randn_like(t)
        return torch.cat([t, torch.sin(self.freq * t) + jitter], dim=-1)

    def log_prob(self, x: Tensor) -> Tensor:
        u, v = x[..., 0], x[..., 1]
        # p(x, y) = 1/(x_high - x_low) * N(y; sin(freq x), sigma_perp^2) on the domain, else 0.
        in_domain = (u >= self.x_low) & (u <= self.x_high)
        log_px = -math.log(self.x_high - self.x_low)
        log_py = _log_normal_1d(v, torch.sin(self.freq * u), u.new_tensor(self.sigma_perp))
        outside = u.new_full(u.shape, float('-inf'))
        return torch.where(in_domain, log_px + log_py, outside)


# -----------------------------------------------------------------------------
# 11) Atom + Gaussian mixture (mixed measure)
# -----------------------------------------------------------------------------

@dataclass
class AtomPlusGaussian2D(BaseDistribution2D):
    """Mixed measure: a point mass at ``atom_loc`` plus an isotropic Gaussian.

    The atom carries mass ``weight_atom``; the Gaussian ``N(gauss_mean,
    gauss_var * I)`` carries ``1 - weight_atom``. There is no joint density,
    so the two parts are exposed separately.
    """

    weight_atom: float = 0.3
    atom_loc: Tuple[float, float] = (0.0, 0.0)
    gauss_mean: Tuple[float, float] = (1.0, 1.0)
    gauss_var: float = 0.5

    has_log_prob: bool = False  # not a pure density

    def sample(self, n: int, device=None, dtype=None) -> Tensor:
        device, dtype = _as_device_dtype(device, dtype)
        probs = torch.tensor([self.weight_atom, 1 - self.weight_atom], device=device, dtype=dtype)
        component = torch.distributions.Categorical(probs=probs).sample((n,))
        # Component 0 is the point mass, component 1 the Gaussian.
        atom = torch.tensor(self.atom_loc, device=device, dtype=dtype).expand(n, 2)
        center = torch.tensor(self.gauss_mean, device=device, dtype=dtype)
        chol = torch.linalg.cholesky(torch.eye(2, device=device, dtype=dtype) * self.gauss_var)
        gauss = center + _randn(n, 2, device, dtype) @ chol.T
        return torch.where(component.unsqueeze(-1) == 0, atom, gauss)

    def log_prob_continuous(self, x: Tensor) -> Tensor:
        """Log-density of the Gaussian part, including its weight ``1 - weight_atom``."""
        cov = self.gauss_var * torch.eye(2, device=x.device, dtype=x.dtype)
        log_gauss = _log_normal_2d(x, x.new_tensor(self.gauss_mean), cov,
                                   torch.linalg.inv(cov), torch.logdet(cov))
        return math.log(1 - self.weight_atom) + log_gauss

    def atom_log_mass(self) -> float:
        """Log of the probability mass placed on the atom."""
        mass = self.weight_atom
        return math.log(mass)


# -----------------------------------------------------------------------------
# 12) Checkerboard stripes in [-4,4]^2
# -----------------------------------------------------------------------------

@dataclass
class CheckerboardStripes(BaseDistribution2D):
    """Uniform density on the "even" unit squares of an integer grid.

    The grid consists of unit squares ``[i, i+1) x [j, j+1)`` with ``i`` and
    ``j`` in ``range(floor(low), floor(high))``. A square is in the support
    when ``i + j`` is even. ``sample`` and ``log_prob`` share this definition,
    so the density is normalised over exactly the squares the sampler emits
    (for odd grid sides this differs from "half the domain area").

    Raises:
        ValueError: from ``sample``/``log_prob`` when the grid has no squares.
    """

    low: float = -4.0
    high: float = 4.0

    has_log_prob: bool = True

    def _grid_bounds(self) -> Tuple[int, int]:
        """Integer square indices ``[low_i, high_i)`` spanned by the grid."""
        return int(math.floor(self.low)), int(math.floor(self.high))

    def _valid_squares(self, device) -> Tensor:
        """Lower-left corners ``(M, 2)`` (integer dtype) of all support squares."""
        low_i, high_i = self._grid_bounds()
        ticks = torch.arange(low_i, high_i, device=device)
        ii, jj = torch.meshgrid(ticks, ticks, indexing='ij')
        mask = ((ii + jj) % 2 == 0)
        return torch.stack([ii[mask], jj[mask]], dim=-1)

    def _pick_square(self, n: int, device, dtype) -> Tensor:
        """Draw ``n`` support squares uniformly; returns their corners as ``dtype``."""
        valid = self._valid_squares(device)
        if valid.shape[0] == 0:
            raise ValueError("checkerboard grid is empty; need floor(high) > floor(low)")
        idx = torch.randint(0, valid.shape[0], (n,), device=device)
        return valid[idx].to(dtype)

    def sample(self, n: int, device=None, dtype=None) -> Tensor:
        device, dtype = _as_device_dtype(device, dtype)
        squares = self._pick_square(n, device, dtype)  # (n,2)
        # Sample uniformly within the chosen unit squares.
        offs = torch.rand(n, 2, device=device, dtype=dtype)
        return squares + offs

    def log_prob(self, x: Tensor) -> Tensor:
        low_i, high_i = self._grid_bounds()
        n_squares = int(self._valid_squares(x.device).shape[0])
        if n_squares == 0:
            raise ValueError("checkerboard grid is empty; need floor(high) > floor(low)")
        # Each support square has unit area, so the constant density is 1 / n_squares.
        log_const = -math.log(n_squares)
        i = torch.floor(x[..., 0])
        j = torch.floor(x[..., 1])
        inside = ((i >= low_i) & (i < high_i) & (j >= low_i) & (j < high_i)
                  & (((i + j) % 2) == 0))
        out = x.new_full(x.shape[:-1], float('-inf'))
        out[inside] = log_const
        return out


# -----------------------------------------------------------------------------
# 13) Spiky GMM (one tiny, low-weight mode)
# -----------------------------------------------------------------------------

@dataclass
class SpikyGMM2D(BaseDistribution2D):
    """Two isotropic Gaussians: a broad main mode plus a narrow, low-weight spike."""

    mean_big: Tuple[float, float] = (0.0, 0.0)
    var_big: float = 1.0
    mean_spike: Tuple[float, float] = (3.0, 0.0)
    var_spike: float = 1e-3
    weight_spike: float = 0.03

    has_log_prob: bool = True

    def sample(self, n: int, device=None, dtype=None) -> Tensor:
        device, dtype = _as_device_dtype(device, dtype)
        opts = dict(device=device, dtype=dtype)
        mix = torch.tensor([1 - self.weight_spike, self.weight_spike], **opts)
        which = torch.distributions.Categorical(probs=mix).sample((n,))
        eye = torch.eye(2, **opts)
        # Draw both components for every point, then select per point.
        broad = torch.tensor(self.mean_big, **opts) + _randn(n, 2, device, dtype) @ (math.sqrt(self.var_big) * eye).T
        spike = torch.tensor(self.mean_spike, **opts) + _randn(n, 2, device, dtype) @ (math.sqrt(self.var_spike) * eye).T
        return torch.where(which[:, None] == 0, broad, spike)

    def log_prob(self, x: Tensor) -> Tensor:
        components = (
            (self.mean_big, self.var_big, 1 - self.weight_spike),
            (self.mean_spike, self.var_spike, self.weight_spike),
        )
        weighted = []
        for center, var, weight in components:
            cov = torch.eye(2, device=x.device, dtype=x.dtype) * var
            log_dens = _log_normal_2d(x, x.new_tensor(center), cov,
                                      torch.linalg.inv(cov), torch.logdet(cov))
            weighted.append(log_dens + math.log(weight))
        return _logsumexp(torch.stack(weighted, dim=-1), dim=-1)


# -----------------------------------------------------------------------------
# 14) 3x3 Grid GMM (non-uniform default weights, tiny variance)
# -----------------------------------------------------------------------------

@dataclass
class GridGMM9(BaseDistribution2D):
    """Nine isotropic Gaussians centred on a 3x3 grid with step ``spacing``.

    Every component has variance ``var``. ``weights`` (length 9) are
    normalised to sum to one; when omitted, a fixed non-uniform default is used.

    Raises:
        ValueError: for wrong-length, negative, or non-positive-sum weights,
            or a non-positive ``var``.
    """

    spacing: float = 1.0
    var: float = 0.0025
    weights: Optional[Sequence[float]] = None

    has_log_prob: bool = True

    def __post_init__(self):
        if not self.var > 0:
            raise ValueError("var must be positive")
        coords = (-float(self.spacing), 0.0, float(self.spacing))
        # Means enumerate the grid with x varying slowest, then y.
        self._means = tuple((x, y) for x in coords for y in coords)
        if self.weights is None:
            # Non-uniform default mixture weights.
            w = [0.01, 0.1, 0.3, 0.2, 0.02, 0.15, 0.02, 0.15, 0.05]
        else:
            if len(self.weights) != len(self._means):
                raise ValueError(f"weights must have length {len(self._means)}")
            w = list(self.weights)
        if any(ww < 0 for ww in w):
            raise ValueError("weights must be non-negative")
        total = sum(w)
        if total <= 0:
            raise ValueError("weights must sum to a positive value")
        self.weights = tuple(ww / total for ww in w)
        # Lazily built log-weight cache; rebuilt when device/dtype change.
        self._logw = None

    def sample(self, n: int, device=None, dtype=None) -> Tensor:
        device, dtype = _as_device_dtype(device, dtype)
        weights = torch.tensor(self.weights, device=device, dtype=dtype)
        cat = torch.distributions.Categorical(probs=weights)
        idx = cat.sample((n,))
        means = torch.tensor(self._means, device=device, dtype=dtype)
        std = math.sqrt(self.var)
        noise = std * torch.randn(n, 2, device=device, dtype=dtype)
        return means[idx] + noise

    def log_prob(self, x: Tensor) -> Tensor:
        """Mixture log-density; ``x`` may have any leading batch dims ``(..., 2)``."""
        if self._logw is None or self._logw.device != x.device or self._logw.dtype != x.dtype:
            self._logw = torch.log(torch.tensor(self.weights, device=x.device, dtype=x.dtype))
        means = x.new_tensor(self._means)  # (9, 2)
        # Broadcast (..., 1, 2) against (9, 2) -> (..., 9, 2).
        diff = x[..., None, :] - means
        quad = (diff ** 2).sum(dim=-1) / self.var
        # Isotropic 2D normal: log N = -0.5 * (quad + 2 * log(2 * pi * var)).
        log_comp = -0.5 * (quad + 2 * math.log(2 * math.pi * self.var))
        return _logsumexp(self._logw + log_comp, dim=-1)

# -----------------------------------------------------------------------------
# MNIST sampler
# -----------------------------------------------------------------------------


class MNISTSampler:
    """Uniformly resample MNIST images held entirely in memory.

    The whole split is materialised once at construction (after applying
    ``transform``), optionally flattened to one vector per image.
    ``sample`` draws indices with replacement.
    """

    def __init__(
        self,
        *,
        root: str = "./data",
        train: bool = True,
        download: bool = True,
        transform=None,
        flatten: bool = True,
        preload_batch_size: int = 1024,
    ) -> None:
        torchvision_missing = datasets is None or transforms is None or DataLoader is None
        if torchvision_missing:
            raise ImportError(
                "torchvision is required for the MNIST sampler but is not available"
            )

        if transform is None:
            # ToTensor yields [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1].
            transform = transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
            )

        dataset = datasets.MNIST(root=root, train=train, download=download, transform=transform)
        loader = DataLoader(dataset, batch_size=preload_batch_size, shuffle=False)
        batches = [
            images.view(images.shape[0], -1) if flatten else images
            for images, _ in loader
        ]

        self.data = torch.cat(batches, dim=0).contiguous()
        self.flatten = flatten
        self.image_shape = (1, 28, 28)
        self.dim = self.data.shape[1] if flatten else self.data.shape[1:]

    def sample(
        self,
        n: int,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        """Return ``n`` images drawn uniformly with replacement, on ``device``/``dtype``."""
        device, dtype = _as_device_dtype(device, dtype)
        picks = torch.randint(0, self.data.shape[0], (n,))
        # Tensor.to is a no-op for arguments that already match.
        return self.data[picks].to(device=device, dtype=dtype)



# -----------------------------------------------------------------------------
# CIFAR-10 sampler
# -----------------------------------------------------------------------------


class CIFAR10Sampler:
    """Random batches from CIFAR-10 with optional flattening and class access.

    The whole split is loaded into memory once at construction. ``transform``
    runs only during that preload, so any random augmentation it contains
    (the default includes ``RandomHorizontalFlip``) is applied once per image
    and frozen into ``self.data``. It is not re-drawn on each ``sample`` call.
    """

    def __init__(
        self,
        *,
        root: str = "./data",
        train: bool = True,
        download: bool = True,
        transform=None,
        flatten: bool = True,
        preload_batch_size: int = 1024,
    ) -> None:
        if datasets is None or transforms is None or DataLoader is None:
            raise ImportError(
                "torchvision is required for the CIFAR-10 sampler but is not available"
            )

        if transform is None:
            transform = transforms.Compose(
                [
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]
            )

        dataset = datasets.CIFAR10(
            root=root,
            train=train,
            download=download,
            transform=transform,
        )

        loader = DataLoader(dataset, batch_size=preload_batch_size, shuffle=False)

        image_storage = []
        label_storage = []
        for images, labels in loader:
            if flatten:
                images = images.view(images.shape[0], -1)
            image_storage.append(images)
            label_storage.append(labels)

        self.data = torch.cat(image_storage, dim=0).contiguous()
        self.labels = torch.cat(label_storage, dim=0)
        self.flatten = flatten
        self.image_shape = (3, 32, 32)
        self.dim = self.data.shape[1] if flatten else self.data.shape[1:]

        # Cache per-class index tensors so class subsets need no rescanning.
        num_classes = 10
        self.class_indices = [
            (self.labels == cls).nonzero(as_tuple=False).view(-1) for cls in range(num_classes)
        ]
        self.class_counts = tuple(int(idx.shape[0]) for idx in self.class_indices)
        self.num_classes = num_classes
        self.num_samples = int(self.data.shape[0])

    def _gather(
        self,
        indices: torch.Tensor,
        *,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(images, labels)`` at ``indices``; images cast to ``dtype``, both moved to ``device``."""
        device, dtype = _as_device_dtype(device, dtype)
        batch = self.data[indices]
        labels = self.labels[indices]
        if dtype is not None and batch.dtype != dtype:
            batch = batch.to(dtype)
        return batch.to(device), labels.to(device)

    def sample(
        self,
        n: int,
        *,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        """Return ``n`` images drawn uniformly with replacement."""
        idx = torch.randint(0, self.data.shape[0], (n,), dtype=torch.long)
        batch, _ = self._gather(idx, device=device, dtype=dtype)
        return batch

    def sample_with_labels(
        self,
        n: int,
        *,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``n`` ``(images, labels)`` pairs drawn uniformly with replacement."""
        idx = torch.randint(0, self.data.shape[0], (n,), dtype=torch.long)
        return self._gather(idx, device=device, dtype=dtype)

    def sample_class_subset(
        self,
        cls: int,
        count: int,
        *,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``count`` distinct examples of class ``cls`` (without replacement).

        Raises:
            ValueError: if ``cls`` is out of range, ``count`` is negative, or
                more examples are requested than the class contains.
        """
        if not (0 <= cls < self.num_classes):
            raise ValueError(f"Class index out of range: {cls}")
        pool = self.class_indices[cls]
        available = int(pool.shape[0])
        # A negative count would otherwise slice as ``[:-k]`` and silently
        # return almost the whole class.
        if count < 0:
            raise ValueError(f"Requested a negative number of samples: {count}")
        if count > available:
            raise ValueError(
                f"Requested {count} samples for class {cls}, but only {available} available"
            )
        choice = torch.randperm(available)[:count]
        return self._gather(pool[choice], device=device, dtype=dtype)



# -----------------------------------------------------------------------------
# Factory helper
# -----------------------------------------------------------------------------

def get_distribution(name: str, **kwargs) -> BaseDistribution2D:
    """Build a distribution or dataset sampler from a case-insensitive alias.

    Keyword arguments are forwarded to the selected constructor. The
    ``"funnel"``/``"nealfunnel"`` aliases return a ``NealFunnel2D`` wrapped in
    a ``ZScoreWrapper`` that standardises both coordinates. The MNIST/CIFAR
    aliases return dataset samplers, which do not subclass
    ``BaseDistribution2D``.

    Raises:
        ValueError: if ``name`` matches no known alias.
    """
    key = name.lower()

    def _standardised_funnel(**kw):
        base = NealFunnel2D(**kw)
        # Both marginals have zero mean. x1 has std sigma1. For x2,
        # Var = E[exp(alpha x1)] = exp(alpha^2 sigma1^2 / 2), so its std is
        # exp(alpha^2 sigma1^2 / 4).
        std = torch.tensor([
            base.sigma1,
            math.exp(0.25 * (base.alpha ** 2) * (base.sigma1 ** 2)),
        ])
        return ZScoreWrapper(base, torch.zeros(2), std)

    registry = (
        ({"mnist"}, MNISTSampler),
        ({"cifar", "cifar10", "cifar-10"}, CIFAR10Sampler),
        ({"annulus", "ring", "thinannulus"}, ThinAnnulus),
        ({"rings", "concentricrings"}, ConcentricRingsMixture),
        ({"twomoons", "moons"}, TwoMoons),
        ({"spiral", "spiral2d"}, Spiral2D),
        ({"rotgauss", "rotatedgaussian"}, RotatedGaussian),
        ({"rosenbrock"}, Rosenbrock2D),
        ({"funnel", "nealfunnel"}, _standardised_funnel),
        ({"cauchy", "indcauchy"}, IndependentCauchy2D),
        ({"gausscauchymixture", "gcmix"}, GaussianCauchyMixture2D),
        ({"sine", "nearmanifold", "nearmanifoldsine"}, NearManifoldSine),
        ({"atomgauss", "atomplusgaussian"}, AtomPlusGaussian2D),
        ({"checker", "checkerboard"}, CheckerboardStripes),
        ({"spikygmm", "gmmspiky"}, SpikyGMM2D),
        ({"gridgmm", "gridgmm9", "gmmgrid"}, GridGMM9),
    )
    for aliases, build in registry:
        if key in aliases:
            return build(**kwargs)
    raise ValueError(f"Unknown distribution name: {key}")


# -----------------------------------------------------------------------------
# Quick demo (run as script)
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke test: draw a few samples from each toy distribution and, where an
    # analytic density is advertised, evaluate it on those samples.
    torch.manual_seed(0)
    dists = [
        ThinAnnulus(R=3.0, eps=0.1),
        ConcentricRingsMixture(radii=(1.0, 2.0, 3.2), eps=0.07),
        TwoMoons(noise=0.03),
        Spiral2D(turns=3.0, noise=0.03),
        RotatedGaussian(var_small=1e-4, var_large=1.0, theta_deg=30.0),
        Rosenbrock2D(b=100.0),
        NealFunnel2D(sigma1=3.0, alpha=1.0),
        IndependentCauchy2D(scale=(1.0, 0.5)),
        GaussianCauchyMixture2D(weight_cauchy=0.2),
        NearManifoldSine(freq=5.0, sigma_perp=0.03),
        AtomPlusGaussian2D(weight_atom=0.3),
        CheckerboardStripes(low=-4.0, high=4.0),
        SpikyGMM2D(weight_spike=0.03),
        GridGMM9(var=0.01)
    ]
    for d in dists:
        name = d.__class__.__name__
        try:
            x = d.sample(8)
        except NotImplementedError:
            # A distribution without a sampler should not abort the whole demo.
            print(f"{name}: sample not implemented")
            continue
        msg = f"{name}: sample shape={tuple(x.shape)}"
        if getattr(d, 'has_log_prob', False):
            try:
                lp = d.log_prob(x)
            except Exception as e:  # demo boundary: report the failure and keep going
                msg += f", log_prob: {type(e).__name__}"
            else:
                msg += f", log_prob shape={tuple(lp.shape)}"
        print(msg)
