# Import Python packages.
import math
import warnings
from typing import Any, List, Mapping, Optional, Sequence, Tuple, TypeVar

# Import external packages.
import numpy as np
import pandas as pd

# Import relatively from other modules.
from ...data import DataTabular
from ...datasets import DatasetTabular, DatasetTabularSimple
from ...transforms import ErrorTransformUnsupportPartial
from .base import TransdatasetSplit


# Type aliases.
# The input domain is a list holding exactly one tabular dataset (the dataset to be split), and the
# output domain is a list holding exactly two tabular datasets (major half first, minor half
# second).
Input = List[DatasetTabular]
Output = List[DatasetTabular]


# Self types.
# It is used to annotate "self" so that methods returning the instance itself keep the type of the
# calling subclass.
SelfTransdatasetSplitTabular = TypeVar(
    "SelfTransdatasetSplitTabular", bound="TransdatasetSplitTabular"
)


class TransdatasetSplitTabular(TransdatasetSplit[DatasetTabular]):
    r"""
    Transformation for splitting a tabular dataset into multiple tabular datasets.

    The forward direction cuts a single tabular dataset into a major and a minor dataset, and the
    inverse direction concatenates both halves back into a single dataset.
    Split sorting algorithms and sharing columns are stored on the instance by fitting or by
    loading alphabetic data.
    """
    # Transformation unique identifier.
    # It is quoted in error messages raised when raw data can not be formalized.
    _IDENTIFIER = "split.dataset.tabular"

    def input(self: SelfTransdatasetSplitTabular, raw: Any, /) -> Input:
        r"""
        Formalize raw data as input to the transformation.

        Only null raw data is supported, and it is formalized as a single placeholder dataset.

        Args
        ----
        - raw
            Raw data.
            It must be None.

        Returns
        -------
        - process
            A list of exactly one dataset, built from one tabular data registered under "full"
            whose only column "generic" is an empty 64-bit integer array.

        Raises
        ------
        - ErrorTransformUnsupportPartial
            If raw data is anything but None.
        """
        # Reject everything except null raw data up front.
        if raw is not None:
            raise ErrorTransformUnsupportPartial(
                "Try to formalize incompatible raw data into input domain of"
                f' "{self._IDENTIFIER:s}".'
            )

        # Null raw data has nothing to be unraveled, thus a placeholder is created instead.
        placeholder = DataTabular.from_numeric(
            {"generic": np.array([], dtype=np.int64)},
            sort_columns="alphabetic",
            sort_rows="rankable",
        )
        dataset = DatasetTabularSimple.from_memalias(
            [placeholder], ["full"], sorts=("alphabetic", "rankable")
        )
        return [dataset]

    def output(self: SelfTransdatasetSplitTabular, raw: Any, /) -> Output:
        r"""
        Formalize raw data as output from the transformation.

        Only null raw data is supported, and it is formalized as two independent placeholder
        datasets, one for each split half.

        Args
        ----
        - raw
            Raw data.
            It must be None.

        Returns
        -------
        - process
            A list of exactly two datasets, each built from one tabular data registered under
            "full" whose only column "generic" is an empty 64-bit integer array.

        Raises
        ------
        - ErrorTransformUnsupportPartial
            If raw data is anything but None.
        """
        # Reject everything except null raw data up front.
        if raw is not None:
            raise ErrorTransformUnsupportPartial(
                "Try to formalize incompatible raw data into output domain of"
                f' "{self._IDENTIFIER:s}".'
            )

        # Null raw data has nothing to be unraveled, thus one placeholder is created for each of
        # both halves.
        # Placeholders are constructed separately so that they never share the same object.
        halves: Output = []
        for _ in range(2):
            placeholder = DataTabular.from_numeric(
                {"generic": np.array([], dtype=np.int64)},
                sort_columns="alphabetic",
                sort_rows="rankable",
            )
            halves.append(
                DatasetTabularSimple.from_memalias(
                    [placeholder], ["full"], sorts=("alphabetic", "rankable")
                )
            )
        return halves

    def transform(
        self: SelfTransdatasetSplitTabular,
        input: Input,
        /,
        *args: Any,
        props: Sequence[Tuple[Sequence[str], Tuple[int, int]]] = [],
        sort_columns: Optional[str] = None,
        sort_rows: Optional[str] = None,
        **kwargs: Any,
    ) -> Output:
        r"""
        Transform input into output without inplacement.

        The only input dataset is split, memory slot by memory slot, into a major and a minor
        dataset.
        Within each memory slot, rows are optionally grouped by the fitted sharing columns, and
        every group is cut into a heading half, an optional middle row and a tailing half whose
        rows are distributed into both split halves.

        Args
        ----
        - input
            Input to the transformation.
            It must hold exactly one dataset.
        - props
            Proportions of major and minor split halves at each memory slot.
            It is only read and forwarded to the proportion collector, never mutated here.
        - sort_columns
            Column sorting algorithm name in output domain.
            If not given, it will inherit from input domain.
        - sort_rows
            Row sorting algorithm name in output domain.
            If not given, it will inherit from input domain.

        Returns
        -------
        - output
            Major and minor datasets, in this order, sharing the same memory slot names.
        """
        # Parse input memory.
        (dataset,) = input

        # Collect proportions of both halves, which are looked up by memory slot name below.
        props_ = self._collect_props(props)

        # Perform split sorting.
        memory: Sequence[DataTabular]
        if self.sorts:
            # Create a sorted copy of original memory.
            assert (
                not self.allow_alias
            ), "Memory aliasing is not compatible with tabular dataset split sorting."
            sort_split_columns, sort_split_rows = self.sorts
            memory = [
                DataTabular(
                    data._content, sort_columns=sort_split_columns, sort_rows=sort_split_rows
                )
                for data in dataset.memory
            ]
        else:
            # If split sorting is null, we use original dataset order directly.
            memory = dataset.memory

        # Sharing columns do not change between memory slots, thus they are resolved only once.
        groupbys = list(self.groupbys)

        # Collect content of major and minor halves.
        slots_major = []
        slots_minor = []
        names = []
        for data, name in zip(memory, dataset.memory_names):
            # Group tabular data by combination tuples of sharing column values to ensure both
            # split halves share those values.
            # Rows whose sharing columns hold missing values are kept as groups of their own
            # (pandas drops them by default, which would silently lose rows from both halves).
            content_full = data._content
            groups = (
                [content for _, content in content_full.groupby(groupbys, dropna=False)]
                if groupbys
                else [content_full]
            )
            (prop_major, prop_minor) = props_[name]

            # Perform splitting for each group separately, and merge all groups independently for
            # each half.
            buf_major = []
            buf_minor = []
            for content in groups:
                # Before uniform split, break focusing group into equally heading and tailing
                # halves, and if there is a remaining middle row, prioritize it into major half.
                size = len(content)
                threshold = size // 2
                threshold_head = threshold
                threshold_tail = threshold + size % 2

                # Take indices of heading half in "distribute" schema from first to last.
                content_head = content.iloc[:threshold_head]
                indices_major_, indices_minor_ = self._collect_split_indices_distribute(
                    len(content_head), (prop_major, prop_minor), reverse=False
                )
                buf_major.append(content_head.iloc[list(indices_major_)])
                buf_minor.append(content_head.iloc[list(indices_minor_)])

                # Take the potentially middle row left from equally splitting heading and tailing
                # halves.
                if size % 2 == 1:
                    # Prioritize it into major half.
                    # Pay attention to a special case where major proportion is zero.
                    (buf_minor if prop_major == 0 else buf_major).append(content.iloc[[threshold]])

                # Take indices of tailing half in "distribute" schema from last to first.
                content_tail = content.iloc[threshold_tail:]
                indices_major_, indices_minor_ = self._collect_split_indices_distribute(
                    len(content_tail), (prop_major, prop_minor), reverse=True
                )
                buf_major.append(content_tail.iloc[list(indices_major_)])
                buf_minor.append(content_tail.iloc[list(indices_minor_)])

            # Grouping a memory slot without any row yields no group at all, and concatenating
            # nothing is an error, thus an empty frame of the same columns is used instead.
            empty = content_full.iloc[:0]
            slots_major.append(pd.concat(buf_major) if buf_major else empty)
            slots_minor.append(pd.concat(buf_minor) if buf_minor else empty)
            names.append(name)

        # Create output datasets.
        sort_output_columns = dataset._sorts[0] if sort_columns is None else sort_columns
        sort_output_rows = dataset._sorts[1] if sort_rows is None else sort_rows
        datasets: List[DatasetTabular]
        datasets = [
            DatasetTabularSimple.from_memalias(
                [
                    DataTabular(
                        content,
                        sort_columns=sort_output_columns,
                        sort_rows=sort_output_rows,
                        allow_alias_disambiguition=self.allow_alias,
                    )
                    for content in slots
                ],
                names,
                sorts=(sort_output_columns, sort_output_rows),
            )
            for slots in [slots_major, slots_minor]
        ]
        return datasets

    def inverse(
        self: SelfTransdatasetSplitTabular,
        output: Output,
        /,
        *args: Any,
        sort_columns: Optional[str] = None,
        sort_rows: Optional[str] = None,
        **kwargs: Any,
    ) -> Input:
        r"""
        Inverse output into input without inplacement.

        Major and minor datasets are merged back, memory slot by memory slot, by concatenating
        major rows followed by minor rows.

        Args
        ----
        - output
            Output from the transformation.
            It must hold exactly the major and the minor datasets, in this order.
        - sort_columns
            Column sorting algorithm name in input domain.
            If not given, it will inherit from output major domain.
        - sort_rows
            Row sorting algorithm name in input domain.
            If not given, it will inherit from output major domain.

        Returns
        -------
        - input
            Input to the transformation, which holds only the merged dataset.
        """
        # Parse output memory, and ensure that both halves own the same memory slots.
        dataset_major, dataset_minor = output
        slots_major = set(dataset_major.memory_names)
        slots_minor = set(dataset_minor.memory_names)
        assert slots_major == slots_minor, (
            "Memory slots do not match between major and minor datasets to be inversed from"
            " splitting."
        )

        # Resolve sorting algorithms of input domain, where explicit arguments take priority over
        # those of major dataset.
        sort_input_columns = sort_columns if sort_columns is not None else dataset_major._sorts[0]
        sort_input_rows = sort_rows if sort_rows is not None else dataset_major._sorts[1]

        # For some disambiguition sorting algorithms, we can not guarantee a perfect inverse, thus
        # a warning will be raised.
        # Identity sorting is such an ambiguous case.
        if any(algorithm == "identity" for algorithm in (sort_input_columns, sort_input_rows)):
            described = str((sort_input_columns, sort_input_rows))
            message = (
                f'Input tabular domain uses "{described:s}"'
                " disambiguition, which may result in ambiguition in inverson."
            )
            warnings.warn(message, RuntimeWarning)

        # Merge splitted datasets back together in the basis of major dataset.
        merged = []
        for name in dataset_major.memory_names:
            # Both halves may order their memory slots differently, thus each half is located by
            # slot name.
            content_major = dataset_major.memory[dataset_major._memory_indices[name]]._content
            content_minor = dataset_minor.memory[dataset_minor._memory_indices[name]]._content
            merged.append(
                DataTabular(
                    pd.concat([content_major, content_minor]),
                    sort_columns=sort_input_columns,
                    sort_rows=sort_input_rows,
                    allow_alias_disambiguition=self.allow_alias,
                )
            )
        dataset = DatasetTabularSimple.from_memalias(
            merged, dataset_major.memory_names, sorts=(sort_input_columns, sort_input_rows)
        )

        # Construct inversed input memory by the only merged dataset.
        return [dataset]

    def fit(
        self: SelfTransdatasetSplitTabular,
        input: Input,
        output: Output,
        /,
        sorts: Optional[Tuple[str, str]] = None,
        groupbys: Sequence[str] = [],
        *args: Any,
        **kwargs: Any,
    ) -> SelfTransdatasetSplitTabular:
        r"""
        Fit transformation parameters by example input and output.

        Args
        ----
        - input
            Example input to the transformation.
            It must hold exactly one dataset.
        - output
            Example output from the transformation.
            It is not used by this fitting.
        - sorts
            Sorting algorithms (columns first, rows second) to be applied before cutting original
            dataset into major and minor halves.
            If it is None, original dataset order is used.
        - groupbys
            Tabular data columns whose unique values must be shared between both halves except that
            rare values (which only exists in the major half).
            If empty, no sharing value restriction is required.

        Returns
        -------
        - self
            Class instance itself.

        Raises
        ------
        - AssertionError
            If any sharing column is missing from any memory slot of the example input dataset.
        """
        # Parse input memory.
        (dataset,) = input

        # Safety check.
        # Splitting groups every memory slot by sharing columns, thus every slot must own all of
        # them (a dataset without any memory slot trivially passes).
        required = set(groupbys)
        absent = set()
        for data in dataset.memory:
            absent |= required - set(data._content.columns)
        missing = sorted(absent)

        # Report at most three missing columns, and mark truncation by an unquoted ellipsis so
        # that it can not be mistaken for a column name.
        report = ", ".join(f'"{name:s}"' for name in missing[:3])
        if len(missing) > 3:
            report = report + ", ..."
        assert (
            not missing
        ), "Columns required value sharing between both halves are missing: {:s}".format(report)

        # Simply save essential attributes.
        # Sharing columns are copied so that the instance never aliases the shared default argument
        # nor a list that the caller may mutate later.
        self.sorts = sorts
        self.groupbys = list(groupbys)
        return self

    def get_alphabetic_data(self: SelfTransdatasetSplitTabular, /) -> Mapping[str, Any]:
        r"""
        Export alphabetic data of the transformation.

        Args
        ----

        Returns
        -------
        - data
            Mapping of "sorts" to split sorting algorithms and "groupbys" to sharing columns, both
            as currently stored on the instance.
        """
        # Both attributes are all that is needed to operate datasets.
        snapshot: Mapping[str, Any] = dict(sorts=self.sorts, groupbys=self.groupbys)
        return snapshot

    def set_alphabetic_data(
        self: SelfTransdatasetSplitTabular, data: Mapping[str, Any], /  # noqa: W504
    ) -> SelfTransdatasetSplitTabular:
        r"""
        Load alphabetic data of the transformation.

        Args
        ----
        - data
            Alphabetic data of the transformation.
            It must contain both "sorts" and "groupbys".

        Returns
        -------
        - self
            Class instance itself.
        """
        # Safety check on all essential entries, in the order they are loaded.
        essentials = (
            (
                "sorts",
                "Split sorting algorithms of data containers for the operating dataset are"
                " missing.",
            ),
            (
                "groupbys",
                "Sharing columns for the split of the operating dataset are missing.",
            ),
        )
        for key, complaint in essentials:
            assert key in data, complaint

        # Load split sorting algorithms, where any empty value is normalized into null.
        loaded = data["sorts"]
        if loaded:
            self.sorts = tuple(loaded)
        else:
            self.sorts = None

        # Load sharing columns as they are.
        self.groupbys = data["groupbys"]
        return self
