import torch
from torch.utils.data import Dataset

from src.distributions.gaussian_mixture import GaussianMixtureDistribution


class OneDimensionalBridgeDataset(Dataset):
    """Dataset of (x0, y) pairs drawn from two fixed Gaussian mixtures.

    ``x0`` is sampled from a three-component mixture and ``y`` from a
    two-component mixture; both sets are generated once at construction
    time and kept as float tensors on the instance.

    Args:
        num_samples: Number of samples requested from each mixture; also
            the value reported by ``len()``.
        paired: When true, each sample set is sorted independently in
            ascending order, so index ``i`` holds the i-th smallest ``x0``
            together with the i-th smallest ``y``. When false, the samples
            keep the order in which they were drawn.
    """

    def __init__(self, num_samples: int = 10_000, paired: bool = False):
        self.num_samples = num_samples
        self.paired = paired
        # Materialise everything up front; __getitem__ only indexes.
        self.x0, self.y = self.generate_data()

    def generate_data(self):
        """Sample both mixtures and return them as ``(x0, y)`` float tensors.

        NOTE(review): ``sample`` is assumed to return NumPy arrays (the
        result is passed to ``argsort`` and ``torch.from_numpy``), and the
        sort-based pairing presumably expects them to be 1-D -- confirm
        against ``GaussianMixtureDistribution``.
        """
        source = GaussianMixtureDistribution(
            mu=[-1.2, 0, 1.2], sigma=[0.2, 0.2, 0.2], pi=[0.3, 0.4, 0.3]
        )
        target = GaussianMixtureDistribution(
            mu=[-0.8, 0.8], sigma=[0.5, 0.2], pi=[0.8, 0.2]
        )

        # Draw from the source first, then the target, so the order of
        # sampling calls stays fixed.
        draws = [
            source.sample(self.num_samples),
            target.sample(self.num_samples),
        ]

        if self.paired:
            # Sort each set on its own; equal ranks then share an index.
            draws = [values[values.argsort()] for values in draws]

        x0, y = (torch.from_numpy(values).float() for values in draws)
        return x0, y

    def __len__(self):
        """Return the number of samples requested at construction."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the ``(x0, y)`` entries stored at position ``idx``."""
        start_point = self.x0[idx]
        end_point = self.y[idx]
        return start_point, end_point
