import sys

from dataset.morphomnist_dataset import MorphoMNISTDataset

from .gene_dataset import GeneDataset
from .celebA_dataset import CelebADataset
from .openBHB_dataset import OpenBHBDataset
from .pendulum_dataset import PendulumDataset

import warnings
# Importing this module installs a blanket "ignore" filter, silencing every
# warning category for the whole interpreter, not just for this package.
warnings.filterwarnings("ignore")

# Install the blanket "ignore" filter again when no warning options were
# passed to the interpreter (sys.warnoptions is empty in that case).
if not sys.warnoptions:
    warnings.simplefilter("ignore")
# Explicitly ignore FutureWarning too. The blanket filters above already
# cover this category, so this line is redundant but harmless.
warnings.simplefilter(action="ignore", category=FutureWarning)

def load_dataset_splits(
    data_name: str,
    data_path: str,
    sample_cf: bool = False
) -> dict:
    """Build the per-split dataset objects for the named dataset.

    Args:
        data_name: Case-sensitive dataset identifier. One of ``"gene"``,
            ``"celebA"``, ``"openBHB"``, ``"pendulum"`` or ``"morphomnist"``.
        data_path: Location of the data, passed unchanged to the dataset
            constructor.
        sample_cf: Forwarded to ``GeneDataset`` only; it is not passed to
            any of the other dataset classes.

    Returns:
        A dict mapping split name to dataset object. ``"gene"`` yields the
        keys ``"train"``, ``"test"`` and ``"ood"``, each obtained via
        ``subset(...)`` from one shared ``GeneDataset``. Every other dataset
        yields ``"train"``, ``"valid"`` and ``"test"``, each a separately
        constructed instance.

    Raises:
        ValueError: If ``data_name`` is not one of the supported names.
    """
    if data_name == "gene":
        # NOTE(review): the positional strings look like the field/column
        # names GeneDataset expects -- confirm against its signature.
        dataset = GeneDataset(
            data_path, "perturbation", "control", "dose", "covariates", "split",
            sample_cf=sample_cf
        )
        # Gene data is loaded once and then sliced; note it exposes an
        # "ood" split and no "valid" split, unlike the other datasets.
        return {
            split: dataset.subset(split, "all")
            for split in ("train", "test", "ood")
        }

    # The remaining datasets share one construction pattern and differ only
    # in the class. The class is looked up inside each branch so that an
    # unknown name reaches the ValueError without touching any dataset class.
    if data_name == "celebA":
        dataset_cls = CelebADataset
    elif data_name == "openBHB":
        dataset_cls = OpenBHBDataset
    elif data_name == "pendulum":
        dataset_cls = PendulumDataset
    elif data_name == "morphomnist":
        dataset_cls = MorphoMNISTDataset
    else:
        supported = ("gene", "celebA", "openBHB", "pendulum", "morphomnist")
        # Keep the original message prefix, but report the offending value
        # and the accepted names so a typo or case mismatch is obvious.
        raise ValueError(
            f"data_name not recognized: {data_name!r} "
            f"(expected one of: {', '.join(supported)})"
        )

    return {
        split: dataset_cls(data_path, split=split)
        for split in ("train", "valid", "test")
    }
