import torch
from torch.utils.data import Dataset

from src.distributions.gaussian_mixture import GaussianMixtureDistribution


class OneDimensionalGenerativeDataset(Dataset):
    """Map-style dataset of samples drawn once from a Gaussian mixture.

    All samples are generated eagerly at construction time and stored in
    ``self.x0`` as a float32 tensor; indexing reads straight from that tensor.

    When no mixture arguments are given, the built-in three-component
    mixture is used (``mu=[-1.2, 0, 1.2]``, ``sigma=[0.2, 0.2, 0.2]``,
    ``pi=[0.3, 0.4, 0.3]``).

    Args:
        num_samples: Number of samples requested from the mixture.
        mu: Optional sequence forwarded to ``GaussianMixtureDistribution``
            as ``mu`` (presumably the component means).
        sigma: Optional sequence forwarded as ``sigma`` (presumably the
            component standard deviations).
        pi: Optional sequence forwarded as ``pi`` (presumably the mixing
            weights).

    Raises:
        ValueError: If ``mu``, ``sigma`` and ``pi`` do not all have the same
            number of entries.
    """

    # Class-level defaults. Keeping them on the class (rather than only
    # setting them in ``__init__``) means ``generate_data`` still works for
    # instances or subclasses that never ran this ``__init__``.
    mu = (-1.2, 0, 1.2)
    sigma = (0.2, 0.2, 0.2)
    pi = (0.3, 0.4, 0.3)

    def __init__(
        self, num_samples: int = 10_000, *, mu=None, sigma=None, pi=None
    ) -> None:
        self.num_samples = num_samples

        # Only shadow the class defaults when the caller supplied a value.
        if mu is not None:
            self.mu = tuple(mu)
        if sigma is not None:
            self.sigma = tuple(sigma)
        if pi is not None:
            self.pi = tuple(pi)

        # Fail early with a clear message instead of deferring to whatever
        # the mixture implementation does with inconsistent parameters.
        if not len(self.mu) == len(self.sigma) == len(self.pi):
            raise ValueError(
                "mu, sigma and pi must have the same length, got "
                f"{len(self.mu)}, {len(self.sigma)} and {len(self.pi)}"
            )

        self.x0 = self.generate_data()

    def generate_data(self) -> torch.Tensor:
        """Draw ``self.num_samples`` samples and return them as float32.

        The tensor's shape is whatever ``GaussianMixtureDistribution.sample``
        returns for ``num_samples`` draws (not visible from this file).
        """
        # Lists are passed on purpose: that is the container type the
        # original hard-coded call used.
        gm = GaussianMixtureDistribution(
            mu=list(self.mu), sigma=list(self.sigma), pi=list(self.pi)
        )

        # ``sample`` is expected to return a NumPy array; ``.float()``
        # converts the resulting tensor to float32.
        return torch.from_numpy(gm.sample(self.num_samples)).float()

    def __len__(self) -> int:
        """Return the number of samples requested at construction."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the stored sample(s) at ``idx`` (any tensor index)."""
        return self.x0[idx]
