# Import Python packages.
import functools
import hashlib
import re
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, TypeVar

# Import external packages.
import more_itertools as xitertools
import pandas as pd

# Import relatively from other modules.
from ..types import NPANYS
from .base import BaseData, ErrorDataUnsupportPartial


# Self types.
# Type variable bound to DataTabular, used to annotate `self`/`cls` and return values in the class
# methods below so that subclasses keep their own type in those signatures.
SelfDataTabular = TypeVar("SelfDataTabular", bound="DataTabular")


def sort_either_identity(dataframe: pd.DataFrame, /) -> pd.DataFrame:
    r"""
    Leave a dataframe untouched, acting as a no-op sorting algorithm for either axis.

    Args
    ----
    - dataframe
        Dataframe to be "sorted".

    Returns
    -------
    - dataframe
        The very same dataframe object, neither copied nor reordered.
    """
    # No reordering is requested, so hand the input object itself back.
    unsorted = dataframe
    return unsorted


def sort_columns_direct(
    dataframe: pd.DataFrame, /, *, order: Optional[Sequence[str]] = None
) -> pd.DataFrame:
    r"""
    Sort columns directly by given order.

    Args
    ----
    - dataframe
        Dataframe to be sorted.
    - order
        Column name orders.
        It must list every column name of the dataframe exactly once.
        If it is not given, use original order of the dataframe.

    Returns
    -------
    - dataframe
        Sorted dataframe.
    """
    # Safely index all columns following given order.
    # Comparing sets alone would accept an order repeating a name (e.g. ["a", "a", "b"]), which
    # silently duplicates columns in the output, so the length is checked as well.
    order_ = list(dataframe.columns) if order is None else list(order)
    assert len(order_) == len(dataframe.columns) and set(order_) == set(
        dataframe.columns
    ), "Column name order does not match with all column names."
    return dataframe[order_]


def sort_columns_alphabetic(
    dataframe: pd.DataFrame, /, *, groups: Sequence[Optional[Sequence[str]]] = (None,)
) -> pd.DataFrame:
    r"""
    Sort columns by alphabetic order of column titles with group listing.

    Args
    ----
    - dataframe
        Dataframe to be sorted.
    - groups
        Column groups in order.
        Columns in each group will be sorted in alphabetic order, and groups will be sorted by
        argument order.
        A special group defined by None is used to refer all column names except those collected in
        other groups.
        Each column name may appear in at most one explicit group.
        By default, there is only the special group, thus all columns are sorted alphabetically.

    Returns
    -------
    - dataframe
        Sorted dataframe.
    """
    # Safety check.
    assert (
        sum(group is None for group in groups) <= 1
    ), "Detect more than one special groups for all remaing columns."

    # A column listed twice would be emitted twice in the final order, duplicating it in the
    # output, so explicitly listed names must be unique across all groups.
    named = list(xitertools.flatten([group for group in groups if group is not None]))
    assert len(named) == len(set(named)), "Detect column names listed in more than one group."

    # Explicitly collect and sort column names of each group.
    # Columns not listed by any explicit group form the special None group.
    remaining = sorted(set(dataframe.columns) - set(named))
    names = list(
        xitertools.flatten(remaining if group is None else sorted(group) for group in groups)
    )
    return sort_columns_direct(dataframe, order=names)


def sort_rows_direct(
    dataframe: pd.DataFrame, /, *, order: Optional[Sequence[int]] = None
) -> pd.DataFrame:
    r"""
    Sort rows directly by given order.

    Args
    ----
    - dataframe
        Dataframe to be sorted.
    - order
        Row integer position orders.
        It must list every position in `range(len(dataframe))` exactly once.
        If it is not given, use original order of the dataframe.

    Returns
    -------
    - dataframe
        Sorted dataframe.
        Row indices are kept, only reordered.
    """
    # Safely index all rows following given order.
    # Comparing sets alone would accept an order repeating a position (e.g. [0, 0, 1]), which
    # silently duplicates rows in the output, so the length is checked as well.
    num_rows = len(dataframe)
    order_ = list(range(num_rows)) if order is None else list(order)
    assert len(order_) == num_rows and set(order_) == set(
        range(num_rows)
    ), "Row index order does not match with all row integer indices."
    return dataframe.iloc[order_]


def sort_rows_rankable(
    dataframe: pd.DataFrame,
    /,
    *,
    string_rank: Mapping[str, str] = {},
    agg: str = "tuple_sorted",
    ascending: bool = False,
) -> pd.DataFrame:
    r"""
    Sort columns directly by aggregated column rank.

    Args
    ----
    - dataframe
        Dataframe to be sorted.
    - string_rank
        Ranking handler for strings.
        For frequency then alphabet, use "freq_alpha" which is the default.
    - agg
        Multiple column rank aggregation strategy.
        For ranking higher values of each columns (respect column title order) as top as possible,
        use "tuple_sorted" which is the default.
    - ascending
        If True, column rank should be achieved from ascending rankable values.
        If False, it is achieved from descending rankable values.

    Returns
    -------
    - dataframe
        Sorted dataframe.
    """
    # Safety check.
    assert set(string_rank.keys()).issubset(
        set(dataframe.columns)
    ), "String rank handler keys must be a subset of column titles."
    assert all(
        name in ("freq_alpha",) for name in string_rank.values()
    ), 'String rank handler must be one of "freq_alpha".'
    assert agg in ("tuple_sorted",), 'Rank aggregator must be one of "tuple_sorted".'

    # Collect ranks of each column.
    ranks = {}
    for name, series in dataframe.items():
        # Translate column data into rankable data.
        if pd.api.types.is_numeric_dtype(series.dropna()):
            # Translate numeric data to floating numbers.
            # NaN case will be treated as negative infinite.
            rankable = series.map(float).fillna(float("-inf"))
        else:
            # Translate non-numeric data to strings and their frequencies, then combine each pair of
            # frequency and string text into rankable cell.
            # NaN case will be treated as string "nan".
            texts = series.map(str).fillna(str(float("nan")))
            freqs = texts.map(texts.value_counts())
            rankable = pd.Series(zip(freqs, texts))

        # Get the order from rankable column data.
        # If the rankable keys are the same, the rank will always be the smallest one.
        rankable_ = [(key, -i) for i, key in enumerate(rankable)]
        ranks_ = [
            (-i, r, key) for r, (key, i) in enumerate(sorted(rankable_, reverse=not ascending))
        ]
        ranks_ = [
            (i, r_, key) if r > 0 and key == key_ else (i, r_ := r, key_ := key)  # noqa: F821, F841
            for (i, r, key) in ranks_
        ]
        ranks[name] = [r for _, r, _ in sorted(ranks_)]

    # Aggregate ranks of all columns, and get the final order.
    rankable = pd.Series(
        zip(
            pd.DataFrame(ranks).agg(lambda series: tuple(sorted(series)), axis=1),
            range(len(dataframe)),
        )
    )
    rankable_ = [(key, i) for i, key in enumerate(rankable)]
    order = [i for _, i in sorted(rankable_)]
    return sort_rows_direct(dataframe, order=order)


class DataTabular(BaseData[pd.DataFrame]):
    r"""
    Data of tabular format.

    The content is a dataframe whose column indices are strings, whose row indices are integers,
    and whose cells are None, booleans, integers, floats or strings.
    On initialization, column names are normalized and both columns and rows are sorted by the
    named sorting algorithms to remove ambiguity between different layouts of the same datum.
    """
    # Transformation unique identifier.
    _IDENTIFIER = "tabular"

    # Tabular disambiguition sorting algorithms, indexed by axis and then by algorithm name.
    # Registration mutates this class-level dictionary in place, so it is shared by subclasses that
    # do not override it.
    _SORTS: Dict[str, Dict[str, Callable[[pd.DataFrame], pd.DataFrame]]]
    _SORTS = {
        "columns": {
            "identity": sort_either_identity,
            "direct": sort_columns_direct,
            "alphabetic": sort_columns_alphabetic,
        },
        "rows": {
            "identity": sort_either_identity,
            "direct": sort_rows_direct,
            "rankable": sort_rows_rankable,
        },
    }

    @classmethod
    def register_sort(
        cls: Type[SelfDataTabular],
        f: Callable[[pd.DataFrame], pd.DataFrame],
        axis: str,
        name: str,
        /,
    ) -> None:
        r"""
        Register a disambiguition sorting algorithm.

        Args
        ----
        - f
            Sorting algorithm.
        - axis
            Sorting tabular dimension.
            It should be either "columns" or "rows", but extra dimension of customized usage is
            allowed.
        - name
            Sorting algorithm name for indexing.
            It must not be registered under the same axis yet.

        Returns
        -------
        """
        if axis in ("columns", "rows"):
            # Default axes always exist in the registry.
            assert (
                name not in cls._SORTS[axis]
            ), f'Sorting {axis:s} algorithm "{name:s}" has been registered.'
            cls._SORTS[axis][name] = f
        else:
            # Customized axes are created on their first registration.
            registry = cls._SORTS.setdefault(axis, {})
            assert (
                name not in registry
            ), f'Sorting "{axis:s}" algorithm "{name:s}" has been registered.'
            registry[name] = f

    @classmethod
    def get_sort(
        cls: Type[SelfDataTabular], axis: str, name: str, /  # noqa: W504
    ) -> Callable[[pd.DataFrame], pd.DataFrame]:
        r"""
        Get a disambiguition sorting algorithm.

        Args
        ----
        - axis
            Sorting tabular dimension.
            It should be either "columns" or "rows", but extra dimension of customized usage is
            allowed.
        - name
            Sorting algorithm name for indexing.

        Returns
        -------
        - f
            Sorting algorithm.
            A missing axis or name raises KeyError.
        """
        # Get the sorting algorithm from the class registration.
        return cls._SORTS[axis][name]

    def __init__(
        self: SelfDataTabular,
        content: pd.DataFrame,
        /,
        *args: Any,
        sort_columns: Optional[str] = None,
        sort_rows: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        r"""
        Initialize the class.

        Args
        ----
        - content
            Content in the data.
            Column indices must be strings, row indices must be integers, and cells must be None,
            booleans, integers, floats or strings.
        - sort_columns
            Column sorting algorithm name registered under "columns".
        - sort_rows
            Row sorting algorithm name registered under "rows".

        Returns
        -------
        """
        # Super call.
        BaseData.__init__(self, content, *args, **kwargs)

        # Ensure all column indices are regular.
        is_regular = functools.partial(self.is_regular, [str])
        irregular = [
            f'"{str(index_column):s}"'
            for index_column in content.columns
            if not is_regular(index_column)
        ]
        assert not irregular, "Detect irregular column indices: {:s}.".format(
            self._summarize_irregular(irregular)
        )

        # Ensure all row indices are regular.
        is_regular = functools.partial(self.is_regular, [int])
        irregular = [
            f'"{str(index_row):s}"' for index_row in content.index if not is_regular(index_row)
        ]
        assert not irregular, "Detect irregular row indices: {:s}.".format(
            self._summarize_irregular(irregular)
        )

        # Ensure all content values are regular.
        is_regular = functools.partial(self.is_regular, [type(None), bool, int, float, str])
        irregular = [
            f'"{name:s}"'
            for name, series in content.items()
            if not bool(series.map(is_regular).all())
        ]
        assert (
            not irregular
        ), "Detect irregular values in regular tabular data columns: {:s}.".format(
            self._summarize_irregular(irregular)
        )

        # Save content only after regularition check.
        self._content = content

        # Replace every run of non-alphanumeric characters in column names by an underscore, and
        # strip leading and trailing underscores.
        # Duplication is checked on the full list of renamed columns rather than on a name mapping,
        # so column names already duplicated in the content are detected as well.
        names = [
            re.sub(r"[^a-zA-Z0-9]+", "_", name).strip("_") for name in self._content.columns
        ]
        assert len(set(names)) == len(
            names
        ), "Detect duplication column names after removing white spaces."
        self._content = self._content.rename(columns=dict(zip(self._content.columns, names)))

        # Tabular data will be disambiguated by sorting columns and rows.
        self.sort_columns = sort_columns
        self.sort_rows = sort_rows
        self._disambiguate(deep=not self.allow_alias_disambiguition)

    @classmethod
    def is_regular(cls: Type[SelfDataTabular], types: Sequence[type], cell: Any, /) -> bool:
        r"""
        Check if the cell value is a regular value.

        Args
        ----
        - types
            Regular value types.
        - cell
            A cell value.

        Returns
        -------
        - flag
            Flag indicating if the cell value is regular or not.
        """
        # The cell is regular if it is an instance of any given type.
        # Subclasses count as well, e.g., booleans are accepted when integers are regular.
        return isinstance(cell, tuple(types))

    @staticmethod
    def _summarize_irregular(irregular: Sequence[str], /) -> str:
        r"""
        Join irregular item descriptions for an error message.

        Args
        ----
        - irregular
            Descriptions of irregular items.

        Returns
        -------
        - text
            Comma-separated descriptions, truncated to the first three followed by an ellipsis.
        """
        # Long listings are truncated to keep error messages short.
        shown = [*irregular[:3], "..."] if len(irregular) > 3 else list(irregular)
        return ", ".join(shown)

    def identity(self: SelfDataTabular, /) -> SelfDataTabular:
        r"""
        Generate identity (empty) element of the datum.

        Args
        ----

        Returns
        -------
        - obj
            Identity (empty) instance of the datum.
            It reuses the sorting algorithm names of this datum, while other initialization
            arguments are not forwarded.
        """
        # Create an empty dataframe of the same columns.
        content = pd.DataFrame([], columns=self._content.columns)
        return self.__class__(content, sort_columns=self.sort_columns, sort_rows=self.sort_rows)

    def _disambiguate(self: SelfDataTabular, /, *, deep: bool = True) -> SelfDataTabular:
        r"""
        Remove ambiguation caused by representing or storage differences of the same datum.

        Args
        ----
        - deep
            Make a completed copy after disambiguition.
            Only copy mode is supported.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Safety check.
        if not deep:
            # Sorting algorithms may return a new dataframe rather than sorting in place, thus
            # alias mode cannot be supported.
            raise ErrorDataUnsupportPartial(
                f"Tabular data disambiguition only supports copy mode, but get alias mode for"
                f' "{self._IDENTIFIER:s}" data.'
            )

        # Ensure sortings are provided.
        if not self.sort_columns:
            # Missing column sorting means the datum can be ambiguous.
            raise ErrorDataUnsupportPartial(
                f'Column sorting is not provided for "{self._IDENTIFIER:s}" data.'
            )
        if not self.sort_rows:
            # Missing row sorting means the datum can be ambiguous.
            raise ErrorDataUnsupportPartial(
                f'Row sorting is not provided for "{self._IDENTIFIER:s}" data.'
            )

        # Sort columns and rows, then renumber rows from zero in the sorted order.
        self._content = self.get_sort("columns", self.sort_columns)(self._content)
        self._content = self.get_sort("rows", self.sort_rows)(self._content)
        self._content = self._content.reset_index(drop=True).copy(deep=deep)
        return self

    def save(self: SelfDataTabular, path: str, /) -> SelfDataTabular:
        r"""
        Save the content to file system.

        Args
        ----
        - path
            Path to save the content.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Save tabular content in CSV format without row indices.
        self._content.to_csv(path, index=False)
        return self

    def load(self: SelfDataTabular, path: str, /) -> SelfDataTabular:
        r"""
        Load the content from file system.

        Args
        ----
        - path
            Path to load the content.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Load tabular content from a CSV file.
        # NOTE(review): loaded content is neither validated, renamed nor re-sorted here; it is
        # presumably expected to come from `save` — confirm.
        self._content = pd.read_csv(path).reset_index(drop=True)
        return self

    @property
    def hashcode(self: SelfDataTabular, /) -> str:
        r"""
        Collect hash code of data content.

        Args
        ----

        Returns
        -------
        - code
            Hash code.
        """
        # Get hash code of recoverable content bytes iteratively.
        # The hash code includes following content separated by white space byte:
        # 1. Number of columns;
        # 2. Column names concatenated by white space byte;
        # 3. Table value type;
        # 4. Number of rows;
        # 5. Full table data.
        values = self._content.values
        hashcoder = hashlib.sha256()
        hashcoder.update(bytes(str(len(self._content.columns)), "utf-8"))
        hashcoder.update(b" ")
        hashcoder.update(bytes(" ".join(self._content.columns), "utf-8"))
        hashcoder.update(b" ")
        hashcoder.update(bytes(str(values.dtype), "utf-8"))
        hashcoder.update(b" ")
        hashcoder.update(bytes(str(len(self._content.index)), "utf-8"))
        hashcoder.update(b" ")
        if values.dtype.kind == "O":
            # Raw bytes of an object array are addresses of the referenced Python objects, which
            # differ between equal tables and between processes.
            # Hash cell representations in row-major order instead; `repr` escapes control
            # characters, so the NUL separator cannot be produced by cell content.
            for cell in values.ravel():
                hashcoder.update(bytes(repr(cell), "utf-8"))
                hashcoder.update(b"\x00")
        else:
            hashcoder.update(values.tobytes())
        return hashcoder.hexdigest()

    def to_numeric(self: SelfDataTabular, /, *args: Any, **kwargs: Any) -> Mapping[str, NPANYS]:
        r"""
        Translate into numeric format.

        Args
        ----

        Returns
        -------
        - data
            Numeric data of the content, as one array per column keyed by column name.
        """
        # Get numeric array of each column directly.
        return {str(name): series.to_numpy() for name, series in self._content.items()}

    @classmethod
    def from_numeric(
        cls: Type[SelfDataTabular],
        data: Mapping[str, NPANYS],
        /,
        *args: Any,
        sort_columns: Optional[str] = None,
        sort_rows: Optional[str] = None,
        **kwargs: Any,
    ) -> SelfDataTabular:
        r"""
        Translate from numeric format.

        Args
        ----
        - data
            Numeric data of the content, as one array per column keyed by column name.
        - sort_columns
            Column sorting algorithm name.
        - sort_rows
            Row sorting algorithm name.

        Returns
        -------
        - obj
            A class instance of the class.
        """
        # Merge numeric dictionary into the content, and create an instance.
        return cls(
            pd.DataFrame(data), *args, sort_columns=sort_columns, sort_rows=sort_rows, **kwargs
        )

    def to_csv(self: SelfDataTabular, path: str, /, *args: Any, **kwargs: Any) -> str:
        r"""
        Save the content into a CSV file.

        Args
        ----
        - path
            Path of the CSV file.
            Extra positional and keyword arguments are accepted but ignored.

        Returns
        -------
        - path
            Path of the CSV file.
        """
        # Save the content into a CSV file without row indices.
        self._content.to_csv(path, index=False)
        return path

    @classmethod
    def from_csv(
        cls: Type[SelfDataTabular],
        path: str,
        /,
        *args: Any,
        sort_columns: Optional[str] = None,
        sort_rows: Optional[str] = None,
        **kwargs: Any,
    ) -> SelfDataTabular:
        r"""
        Initialize the class from CSV file.

        Args
        ----
        - path
            Path of the CSV file.
            Extra positional and keyword arguments are forwarded to `pd.read_csv`.
        - sort_columns
            Column sorting algorithm name.
        - sort_rows
            Row sorting algorithm name.

        Returns
        -------
        - obj
            A class instance of the class.
        """
        # Load content from CSV file, and create an instance.
        return cls(
            pd.read_csv(path, *args, **kwargs), sort_columns=sort_columns, sort_rows=sort_rows
        )
