from margflow.datasets.conditional_datasets import MixtureOfGaussianTime
from margflow.datasets.gas import GasDataset
from margflow.datasets.hepmass import HepmassDataset
from margflow.datasets.jaffe import ConditionalJaffe
from margflow.datasets.miniboone import MinibooneDataset
from margflow.datasets.mnist import MNIST, ConditionalMNIST
from margflow.datasets.power import PowerDataset
from margflow.datasets.synthetic_datasets import (
    MixtureOfGaussian,
    MixtureOfGaussianManifold,
    UniformSphere,
    TwoMoons,
    TwoCircles,
    SwissRoll,
    Checkerboard,
    Pinwheel,
)


def _split_dataset_spec(spec, n_fields):
    """Split an underscore-separated dataset spec into exactly ``n_fields`` parts.

    Args:
        spec: Dataset identifier such as ``"jaffe_none_True"``.
        n_fields: Expected number of ``"_"``-separated fields, prefix included.

    Returns:
        The list produced by ``spec.split("_")``.

    Raises:
        ValueError: If the number of fields differs from ``n_fields``. (Tuple
            unpacking a wrong-length split would also raise ValueError, but
            with an uninformative "values to unpack" message.)
    """
    fields = spec.split("_")
    if len(fields) != n_fields:
        raise ValueError(
            "Dataset {} must have {} '_'-separated fields, got {}".format(
                spec, n_fields, len(fields)
            )
        )
    return fields


def _parse_encoder(name):
    """Return ``None`` if ``name`` is ``"none"`` (case-insensitive), else ``name``."""
    return None if name.lower() == "none" else name


def _parse_bool_flag(value, spec):
    """Parse a literal ``"True"``/``"False"`` spec field into a bool.

    ``bool("False")`` is ``True`` (every non-empty string is truthy), so the
    field must be compared against the literal rather than passed to ``bool``.

    Args:
        value: The spec field to parse.
        spec: The full dataset spec, used only in the error message.

    Raises:
        ValueError: If ``value`` is neither ``"True"`` nor ``"False"``.
    """
    if value not in ("True", "False"):
        raise ValueError(
            "Dataset {}: expected 'True' or 'False', got {!r}".format(spec, value)
        )
    return value == "True"


def create_dataset(args):
    """Instantiate the dataset selected by ``args.dataset``.

    ``args`` is forwarded unchanged to the chosen dataset constructor.
    ``args.dataset`` is either one of the fixed names below or a
    ``"_"``-separated spec:

    * ``uniform_sphere``, ``mog``, ``mog_manifold``, ``two_circles``,
      ``swiss_roll``, ``checkerboard``, ``pinwheel``, ``two_moons``,
      ``power``, ``gas``, ``hepmass``, ``miniboone``, ``mog_time``
    * ``mnist_<encoder>_<sample_from_means>_<filter_class>``: ``<encoder>``
      and ``<filter_class>`` may be ``none`` (any case); otherwise
      ``<filter_class>`` must be an int.
    * ``condmnist_<encoder>_<sample_from_means>_<one_hot>``: ``<one_hot>``
      must be exactly ``True`` or ``False``.
    * ``jaffe_<encoder>_<sample_from_means>``
    * any name starting with ``sbi``

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: If the name is not recognized or a spec is malformed.
    """
    # Fixed names whose constructor takes only ``args``.
    simple_datasets = {
        "uniform_sphere": UniformSphere,
        "mog": MixtureOfGaussian,
        "mog_manifold": MixtureOfGaussianManifold,
        "two_circles": TwoCircles,
        "swiss_roll": SwissRoll,
        "checkerboard": Checkerboard,
        "pinwheel": Pinwheel,
        "two_moons": TwoMoons,
        "power": PowerDataset,
        "gas": GasDataset,
        "hepmass": HepmassDataset,
        "miniboone": MinibooneDataset,
        "mog_time": MixtureOfGaussianTime,
    }
    name = args.dataset
    if name in simple_datasets:
        return simple_datasets[name](args)

    # NOTE(review): ``sample_from_means`` is forwarded below as the raw spec
    # string (e.g. "False"), not parsed to a bool. Confirm the dataset
    # classes expect a string, since "False" is truthy.
    if name.startswith("mnist_"):
        _, encoder, sample_from_means, filter_class = _split_dataset_spec(name, 4)
        filter_class = None if filter_class.lower() == "none" else int(filter_class)
        return MNIST(
            args,
            encoder_model=_parse_encoder(encoder),
            sample_from_means=sample_from_means,
            filter_class=filter_class,
        )

    if name.startswith("condmnist_"):
        _, encoder, sample_from_means, one_hot = _split_dataset_spec(name, 4)
        return ConditionalMNIST(
            args,
            encoder_model=_parse_encoder(encoder),
            sample_from_means=sample_from_means,
            one_hot_=_parse_bool_flag(one_hot, name),
        )

    if name.startswith("jaffe_"):
        _, encoder, sample_from_means = _split_dataset_spec(name, 3)
        return ConditionalJaffe(
            args,
            encoder_model=_parse_encoder(encoder),
            sample_from_means=sample_from_means,
        )

    if name.startswith("sbi"):
        # Imported lazily, as in the original, so the SBI module is only
        # loaded when an SBI dataset is actually requested.
        from margflow.datasets.sbi_datasets import SimulationBasedInference

        return SimulationBasedInference(args)

    raise ValueError("Dataset {} not recognized".format(args.dataset))
