import warnings

from .base_dataset import DiffusionDataset
from ..singular_2d import Singular2D

# Default component means used by CircleMultimodal when ``means`` is None:
# two points in the plane, one per component.
MEANS = [[0.0, 1.0], [-1.0, -1.0]]

# Default covariances used by CircleMultimodal when ``covs`` is None: one
# 2x2 diagonal matrix per component (presumably paired index-wise with
# MEANS — confirm against Singular2D).
COVS = [
    [[0.1, 0], [0, 0.1]],
    [[0.5, 0], [0, 0.5]],
]

class CircleMultimodal(DiffusionDataset):
    """In-memory 2-D dataset sampled once from a ``Singular2D`` generator.

    All samples are drawn eagerly in ``__init__`` and stored as the result of
    ``.cpu().numpy()``; item access simply indexes that stored array.

    Args:
        n_samples: Number of points to draw, forwarded to
            ``Singular2D.sample_t``.
        eps: Forwarded as the first positional argument of
            ``Singular2D.sample_t`` (presumably the diffusion time at which
            samples are taken — TODO confirm).
        means: Component means forwarded to ``Singular2D``; the module-level
            ``MEANS`` is used when None.
        covs: Component covariances forwarded to ``Singular2D``; the
            module-level ``COVS`` is used when None.
        ws: Forwarded unchanged to ``Singular2D`` (how None is interpreted is
            up to ``Singular2D`` — verify there).
        coeffs: Currently unused. Kept for backward compatibility; passing a
            non-None value emits a ``UserWarning`` because it has no effect.
    """

    def __init__(self,
                 n_samples=1000,
                 eps=0.001,
                 means=None,
                 covs=None,
                 ws=None,
                 coeffs=None):
        super().__init__()

        if coeffs is not None:
            # ``coeffs`` is not consumed anywhere below; surface that to the
            # caller instead of silently discarding their value.
            warnings.warn(
                "CircleMultimodal ignores the `coeffs` argument; it has no "
                "effect on the generated samples.",
                UserWarning,
                stacklevel=2,
            )

        if means is None:
            means = MEANS
        if covs is None:
            covs = COVS

        # NOTE(review): the meaning of the positional arguments (1., 1) is
        # defined by Singular2D and is not visible here — confirm.
        generator = Singular2D(1., 1, device='cpu',
                               means=means,
                               covs=covs,
                               ws=ws,
                               N_Integral=100,
                               )

        # Draw every sample once up front and move it to the CPU as NumPy.
        self.data = generator.sample_t(eps, n_samples=n_samples).cpu().numpy()

    def __getitem__(self, idx):
        """Return the entry of ``self.data`` at ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitems__(self, idxs):
        """Return several samples at once by indexing ``self.data`` with ``idxs``.

        Presumably serves as PyTorch's batched-fetch hook. Assuming
        ``self.data`` is a NumPy array, a list of ints triggers fancy indexing
        and yields one stacked array rather than a list of rows.
        """
        return self.data[idxs]

