from typing import Dict, List, Tuple

from metatensor.torch import TensorMap
from omegaconf import DictConfig

from .dataset import Dataset, DiskDataset, MemmapDataset
from .readers import read_extra_data, read_systems, read_targets
from .target_info import TargetInfo


def get_dataset(
    options: DictConfig,
) -> Tuple[Dataset, Dict[str, TargetInfo], Dict[str, TargetInfo]]:
    """
    Build a dataset from a configuration dictionary.

    The storage backend is chosen from the value of
    ``options["systems"]["read_from"]``:

    - a path ending in ``".zip"`` is opened as a ``DiskDataset``;
    - a path ending in ``"_mm/"`` is opened as a ``MemmapDataset``;
    - anything else is read with :py:func:`read_systems`,
      :py:func:`read_targets` and (if requested) :py:func:`read_extra_data`,
      and assembled with ``Dataset.from_dict``.

    :param options: the configuration options for the dataset. It must
        contain a ``"systems"`` and a ``"targets"`` section, and may contain
        an ``"extra_data"`` section. For the ``"_mm/"`` case, the
        ``"targets"`` section is expected to contain an ``"energy"`` entry.

    :returns: A tuple with three items: the dataset itself, a
        ``Dict[str, TargetInfo]`` describing the targets, and a
        ``Dict[str, TargetInfo]`` describing the extra data. The last
        dictionary is empty when ``options`` has no ``"extra_data"`` section.

    :raises ValueError: if, when reading from files, some extra data uses
        the same key as one of the targets.
    """

    source = options["systems"]["read_from"]
    has_extra_data = "extra_data" in options

    if not source.endswith((".zip", "_mm/")):
        # Neither a disk nor a memory-mapped dataset: read everything from
        # the files named in the configuration.
        systems = read_systems(
            filename=source,
            reader=options["systems"]["reader"],
        )
        targets, target_info = read_targets(conf=options["targets"])

        extra_data: Dict[str, List[TensorMap]] = {}
        extra_info = {}
        if has_extra_data:
            extra_data, extra_info = read_extra_data(conf=options["extra_data"])
            # Targets and extra data end up in the same dictionary below, so
            # a shared key would silently overwrite one of them.
            overlap = targets.keys() & extra_data.keys()
            if overlap:
                raise ValueError(
                    f"Extra data keys {overlap} overlap with target keys. "
                    "Please use unique keys for targets and extra data."
                )

        columns = {"system": systems, **targets, **extra_data}
        return Dataset.from_dict(columns), target_info, extra_info

    target_options = options["targets"]

    if source.endswith(".zip"):  # disk dataset
        field_names = [*target_options, *options.get("extra_data", {})]
        dataset = DiskDataset(source, fields=field_names)
    else:  # mmap dataset
        energy_options = target_options["energy"]
        # Without an explicit "forces" entry, the dataset is treated as
        # conservative.
        if "forces" in energy_options:
            conservative = energy_options["forces"]
        else:
            conservative = True
        dataset = MemmapDataset(
            source,
            conservative=conservative,
            non_conservative=("non_conservative_forces" in target_options),
        )

    # Both of these dataset classes provide the target information themselves.
    target_info = dataset.get_target_info(target_options)
    if has_extra_data:
        extra_info = dataset.get_target_info(options["extra_data"])
    else:
        extra_info = {}

    return dataset, target_info, extra_info
