import torch


class GaussianMixtureGenerator:
    """
    Generates points from a Mixture of Gaussians arranged on a hypersphere.

    Mode centres all lie at distance ``radius`` from the origin in ``n_dims``
    dimensions. The first ``min(n_modes, n_dims)`` modes sit on the positive
    coordinate axes. The next ones, up to ``n_dims`` more, sit on the negative
    axes. Any remaining modes use random unit directions drawn from torch's
    global RNG. Each generated point is its mode's centre plus isotropic
    Gaussian noise with standard deviation ``std``.

    Args:
        n_modes: Number of mixture components (must be >= 1).
        radius: Distance of every centre from the origin.
        std: Standard deviation of the per-point Gaussian noise.
        n_dims: Dimensionality of the generated points.
        seed: If given, passed to ``torch.manual_seed``. This seeds torch's
            *global* RNG, which both centre placement and ``generate`` use.
        device: Device for the centres and generated points (default CPU).
        probs: Optional per-mode weights, either a tensor or a sequence of
            length ``n_modes``. They are passed to ``torch.multinomial``, so
            they need not sum to 1. Defaults to uniform weights.

    Raises:
        ValueError: If ``n_modes < 1`` or ``probs`` does not have shape
            ``(n_modes,)``.
    """

    def __init__(
        self,
        n_modes=8,
        radius=5.0,
        std=0.5,
        n_dims=2,
        seed=None,
        device=None,
        probs=None,
    ):
        if n_modes < 1:
            raise ValueError(f"n_modes must be >= 1, got {n_modes}")
        self.n_modes = n_modes
        self.radius = radius
        self.std = std
        self.n_dims = n_dims
        self.device = device or torch.device("cpu")
        if probs is not None:
            # Accept lists/tuples as well as tensors.
            probs = torch.as_tensor(probs)
            # torch.multinomial only accepts floating-point weights.
            if not probs.is_floating_point():
                probs = probs.float()
            # A mismatch would silently mis-sample or index past `centers`.
            if probs.shape != (n_modes,):
                raise ValueError(
                    f"probs must have shape ({n_modes},), got {tuple(probs.shape)}"
                )
            self.probs = probs
        else:
            self.probs = torch.ones(n_modes) / n_modes

        if seed is not None:
            torch.manual_seed(seed)

        self.centers = self._generate_hypersphere_centers()

    def _generate_hypersphere_centers(self):
        """Return an ``(n_modes, n_dims)`` tensor of centres scaled to ``radius``."""
        centers = []

        # Positive axes
        for i in range(min(self.n_modes, self.n_dims)):
            vec = torch.zeros(self.n_dims, device=self.device)
            vec[i] = 1.0
            centers.append(vec)

        # Negative axes
        remaining = self.n_modes - len(centers)
        for i in range(min(remaining, self.n_dims)):
            vec = torch.zeros(self.n_dims, device=self.device)
            vec[i] = -1.0
            centers.append(vec)

        # Random directions: normalise Gaussian draws onto the unit sphere.
        remaining = self.n_modes - len(centers)
        if remaining > 0:
            raw = torch.randn(remaining, self.n_dims, device=self.device)
            unit_vectors = raw / torch.norm(raw, dim=1, keepdim=True)
            centers.extend(unit_vectors)

        return torch.stack(centers) * self.radius

    def generate(self, num_points):
        """
        Sample ``num_points`` points from the mixture.

        Args:
            num_points: Number of points to draw (>= 0).

        Returns:
            A float32 tensor of shape ``(num_points, n_dims)`` on ``self.device``.
        """
        if num_points == 0:
            # torch.multinomial rejects zero samples, so return an empty batch directly.
            return torch.empty(0, self.n_dims, dtype=torch.float32, device=self.device)
        mode_indices = torch.multinomial(self.probs, num_points, replacement=True).to(
            self.device
        )
        point_centers = self.centers[mode_indices]
        noise = torch.randn(num_points, self.n_dims, device=self.device) * self.std
        return (point_centers + noise).float()
