# Import dataset classes
# Import other datasets as they are added
from training.datasets.DiffusionPDE.dataset_darcy import DiffusionPDEDarcyDataset
from training.datasets.DiffusionPDE.dataset_poisson import DiffusionPDEPoissonDataset
from training.datasets.DiffusionPDE.dataset_helmholtz import DiffusionPDEHelmholtzDataset
from training.datasets.DiffusionPDE.dataset_ns_nonbounded import DiffusionPDENavierStokesNBDataset
from training.datasets.DiffusionPDE.dataset_ns_bounded import DiffusionPDENavierStokesBDataset
from training.datasets.FNO.dataset_darcy import FNODarcyDataset

# Registry of dataset classes, keyed by the name accepted by
# get_dataset_class(). Values must be the dataset classes themselves (not
# instances and not class-name strings), since get_dataset_class() returns
# them as-is. Entry order is the order in which names are listed in
# get_dataset_class()'s "Available datasets" error message.
# To register a new dataset: import its class at the top of this module and
# add an entry below.
DATASET_CLASSES = {
    'diffusion_pde_darcy': DiffusionPDEDarcyDataset,
    'diffusion_pde_poisson': DiffusionPDEPoissonDataset,
    'diffusion_pde_helmholtz': DiffusionPDEHelmholtzDataset,
    'fno_darcy': FNODarcyDataset,
    'diffusion_pde_ns_nonbounded': DiffusionPDENavierStokesNBDataset,
    'diffusion_pde_ns_bounded': DiffusionPDENavierStokesBDataset,
}

# Function to get the dataset class based on the dataset name
def get_dataset_class(dataset_name):
    """Return the dataset class registered in ``DATASET_CLASSES`` under a name.

    Args:
        dataset_name: Registry key, e.g. ``'fno_darcy'``.

    Returns:
        The class object stored in ``DATASET_CLASSES`` for ``dataset_name``.

    Raises:
        ValueError: If ``dataset_name`` is not a registered key. The message
            lists every registered name, in registry order.
    """
    if dataset_name not in DATASET_CLASSES:
        # Iterating a dict yields its keys in insertion order.
        available_datasets = ", ".join(DATASET_CLASSES)
        raise ValueError(f"Unknown dataset: '{dataset_name}'. Available datasets are: {available_datasets}")
    return DATASET_CLASSES[dataset_name]
