# Import Python packages.
import abc
import itertools
from typing import Any, Dict, List, Mapping, Sequence, Tuple, Type, TypeVar

# Import external packages.
import more_itertools as xitertools

# Import relatively from other modules.
from ...datasets import BaseDataset
from ...transforms import BaseTransformList
from ...types import NPANYS


# Type variables.
# Dataset type flowing through the split transformation (both input and output element type).
Dataset = TypeVar("Dataset", bound="BaseDataset[Any]")


# Self types.
# Self type for `TransdatasetSplit`, letting methods such as `set_metadata` return the subclass.
SelfTransdatasetSplit = TypeVar("SelfTransdatasetSplit", bound="TransdatasetSplit[Any]")


class TransdatasetSplit(BaseTransformList[Dataset, Dataset]):
    r"""
    Transformation for splitting a dataset into multiple datasets.

    `transform` is declared abstract here; concrete subclasses implement it and may use the
    `_collect_props` and `_collect_split_indices_distribute` helpers to interpret the split
    proportions.
    """
    # Transformation unique identifier.
    _IDENTIFIER = "split.dataset"

    def __init__(
        self: SelfTransdatasetSplit, /, *args: Any, allow_alias: bool = True, **kwargs: Any
    ) -> None:
        r"""
        Initialize the class.

        Args
        ----
        - allow_alias
            If True, allow the content to be an alias of other objects.
            Otherwise, the content is expected to be completely copied; this base class only
            stores the flag, and the copying itself is left to concrete `transform`
            implementations.

        Returns
        -------
        """
        # Super call.
        BaseTransformList.__init__(self, *args, **kwargs)

        # Save essential arguments.
        self.allow_alias = allow_alias

    @abc.abstractmethod
    def transform(
        self: SelfTransdatasetSplit,
        input: List[Dataset],
        /,
        *args: Any,
        # Immutable default: avoids the shared-mutable-default-argument pitfall.
        props: Sequence[Tuple[Sequence[str], Tuple[int, int]]] = (),
        **kwargs: Any,
    ) -> List[Dataset]:
        r"""
        Transform input into output without in-place modification.

        Args
        ----
        - input
            Input to the transformation.
        - props
            Proportions of major and minor split halves at each memory slot, given as a
            sequence of `(names, (major, minor))` pairs (the format `_collect_props` accepts).

        Returns
        -------
        - output
            Output from the transformation.
        """

    @classmethod
    def _collect_props(
        cls: Type[SelfTransdatasetSplit],
        raw: Sequence[Tuple[Sequence[str], Tuple[int, int]]],
        /,  # noqa: W504
    ) -> Mapping[str, Tuple[int, int]]:
        r"""
        Collect formalized proportions of both halves.

        Args
        ----
        - raw
            Raw proportions of both halves, as a sequence of `(names, (major, minor))` pairs;
            every name in `names` shares the same proportion pair.

        Returns
        -------
        - processed
            Processed proportions of both halves, mapping each individual name to its
            `(major, minor)` pair.

        Raises
        ------
        - AssertionError
            If a name appears more than once across `raw` (only when assertions are enabled).
        """
        # Collect proportions of both halves.
        processed: Dict[str, Tuple[int, int]]
        processed = {}
        for names, prop in raw:
            # Save items corresponding to each name independently.
            for name in names:
                # Ensure items are defined without conflicts.
                assert (
                    name not in processed
                ), f'Split proportions are duplicated for memory slot "{name:s}".'
                processed[name] = prop
        return processed

    @classmethod
    def _collect_split_indices_distribute(
        cls: Type[SelfTransdatasetSplit],
        total: int,
        props: Tuple[int, int],
        /,
        *,
        reverse: bool = False,
    ) -> Tuple[Sequence[int], Sequence[int]]:
        r"""
        Collect indices of both halves in "distribute" schema following the proportions.

        Element indices are dealt round robin into `prop_major + prop_minor` splits (split `i`
        receives the `i`-th, `(i + n)`-th, `(i + 2n)`-th, ... index of the iteration order),
        then whole splits are assigned to the major or minor half according to the proportions.

        Args
        ----
        - total
            Total number of elements to be collected.
        - props
            Proportions of major and minor split halves over all collecting elements.
            Their sum must be at least 1, otherwise `more_itertools.distribute` raises
            `ValueError`.
        - reverse
            If True, collect from last to first element.
            If False, collect from first to last element which is the default.

        Returns
        -------
        - indices_major
            Indices of major half elements, sorted ascending.
        - indices_minor
            Indices of minor half elements, sorted ascending.
        """
        # Distribute total indices of given direction uniformly into total number of proportions.
        prop_major, prop_minor = props
        splits = xitertools.distribute(
            prop_major + prop_minor, range(total - 1, -1, -1) if reverse else range(total)
        )

        # Take splits in round robin schema among major and minor halves, respecting their
        # proportions.
        # For example, if we have major and minor proportions, 5 and 2, we will take by round robin
        # for the first 4 splits, where splits 0 and 2 are for major, and splits 1 and 3 are for
        # minor, then for the rest splits 4, 5, and 6, they are for major.
        num_splits = len(splits)
        num_splits_round_robin = min(prop_major, prop_minor) * 2
        num_splits_rest_major = max(prop_major - prop_minor, 0)
        num_splits_rest_minor = max(prop_minor - prop_major, 0)

        # Split ids owned by each half. At most one of the two "rest" counts is non-zero, so no
        # split id is claimed by both halves (each split is a lazy, single-pass iterator).
        split_ids_major = [
            *itertools.islice(range(num_splits), 0, num_splits_round_robin, 2),
            *range(num_splits_round_robin, num_splits_round_robin + num_splits_rest_major),
        ]
        split_ids_minor = [
            *itertools.islice(range(num_splits), 1, num_splits_round_robin, 2),
            *range(num_splits_round_robin, num_splits_round_robin + num_splits_rest_minor),
        ]

        # Flatten indices from corresponding index splits; `sorted` already returns a list.
        indices_major = sorted(xitertools.flatten(splits[i] for i in split_ids_major))
        indices_minor = sorted(xitertools.flatten(splits[i] for i in split_ids_minor))
        return indices_major, indices_minor

    def get_metadata(self: SelfTransdatasetSplit, /) -> Mapping[str, Any]:
        r"""
        Get metadata of the transformation.

        Args
        ----

        Returns
        -------
        - metadata
            Metadata of the transformation, holding the `allow_alias` flag.
        """
        # Collect essential attributes.
        return {"allow_alias": self.allow_alias}

    def get_numeric_data(self: SelfTransdatasetSplit, /) -> Mapping[str, NPANYS]:
        r"""
        Get numeric data of the transformation.

        Args
        ----

        Returns
        -------
        - data
            Numeric data of the transformation; always empty for this class.
        """
        # Do nothing.
        return {}

    def set_metadata(
        self: SelfTransdatasetSplit, metadata: Mapping[str, Any], /  # noqa: W504
    ) -> SelfTransdatasetSplit:
        r"""
        Set metadata of the transformation.

        Args
        ----
        - metadata
            Metadata of the transformation; must contain the `allow_alias` flag.

        Returns
        -------
        - self
            Class instance itself.

        Raises
        ------
        - AssertionError
            If `allow_alias` is missing from `metadata` (only when assertions are enabled).
        """
        # Safety check.
        assert "allow_alias" in metadata, "Data container alias flag is missing"

        # Get data container alias flag.
        self.allow_alias = metadata["allow_alias"]
        return self

    def set_numeric_data(
        self: SelfTransdatasetSplit, data: Mapping[str, NPANYS], /  # noqa: W504
    ) -> SelfTransdatasetSplit:
        r"""
        Set numeric data of the transformation.

        Args
        ----
        - data
            Numeric data of the transformation; ignored by this class.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Do nothing.
        return self
