import torch
from margflow.datasets.synthetic_datasets import MixtureOfGaussian, TwoMoons
from torch.utils.data import TensorDataset


class KeysToAttribute:
    """Expose the entries of a mapping as instance attributes.

    Every key of ``dictionary`` is copied into the instance namespace, so
    ``KeysToAttribute({"a": 1}).a == 1``. This lets a plain dict stand in for
    an argparse-style namespace object.
    """

    def __init__(self, dictionary):
        # vars(self) is the instance __dict__; update it in one shot.
        namespace = vars(self)
        namespace.update(dictionary)


def load_margflow_dataset(name, **kwargs):
    """Load a margflow synthetic dataset as train/val/test ``TensorDataset``s.

    Args:
        name: Dataset identifier; one of ``"mog"`` (``MixtureOfGaussian``) or
            ``"two_moons"`` (``TwoMoons``).
        **kwargs: Must contain ``"args"``, a mapping whose entries are exposed
            as attributes (via ``KeysToAttribute``) and passed to the dataset
            constructor as ``args=``.

    Returns:
        A ``(train_dataset, val_dataset, test_dataset)`` tuple. Each element
        is a ``TensorDataset`` wrapping one split, converted to
        ``torch.get_default_dtype()``.

    Raises:
        ValueError: If ``name`` is not a supported dataset identifier.
        KeyError: If ``kwargs`` has no ``"args"`` entry.
    """
    supported = ("mog", "two_moons")
    # Validate first: previously an unknown name left `dataset` unbound and
    # surfaced as an opaque UnboundLocalError further down.
    if name not in supported:
        raise ValueError(
            f"Unknown margflow dataset {name!r}; expected one of {supported}"
        )

    args = KeysToAttribute(kwargs["args"])
    if name == "mog":
        dataset = MixtureOfGaussian(args=args)
    else:
        dataset = TwoMoons(args=args)

    dtype = torch.get_default_dtype()

    def _to_dataset(split):
        # torch.tensor always copies its input and casts it to the default
        # floating dtype.
        return TensorDataset(torch.tensor(split, dtype=dtype))

    train_tensor, val_tensor, test_tensor = dataset.load_dataset()
    train_dataset = _to_dataset(train_tensor)
    val_dataset = _to_dataset(val_tensor)
    test_dataset = _to_dataset(test_tensor)

    return train_dataset, val_dataset, test_dataset
