from typing import Callable, Dict

from .dataloader import (
    DataLoaders,
    get_psys_and_dataloaders,
)
from .datasets.base import (
    BaseDataset,
    PartiallySplitDataset,
    PresplitDataset,
    RawSample,
    SupportsIndex,
    Targets,
    UnsplitDataset,
)
from .datasets.ensemble import DatasetEnsemble
from .datasets.md17 import MD17
from .datasets.oqdc_sets import DES370K, SPICE, QMugs
from .datasets.qm9 import QM9
from .datasets.qm40 import QM40
from .datasets.threebpa import ThreeBPA
from .transform import (
    DensityMatrices,
    InitialDensityMatrixFn,
    ToJaxTransform,
    get_initial_density_matrix_fn,
    get_jax_transform,
    get_preload_transform,
)
from .utils import IndexWrapper

# Registry of dataset constructors, looked up by a short lowercase string key.
# Each (key, class) pair below becomes one entry. Insertion order follows the
# order of the pairs. NOTE(review): DatasetEnsemble is imported by this module
# but is not registered here; confirm that this is intentional.
key_to_dataset: Dict[str, Callable] = dict(
    [
        ('md17', MD17),
        ('qm9', QM9),
        ('qm40', QM40),
        ('3bpa', ThreeBPA),
        ('des370k', DES370K),
        ('qmugs', QMugs),
        ('spice', SPICE),
    ]
)
