# Import Python packages.
import functools
import math
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, TypeVar, cast

# Import external packages.
import more_itertools as xitertools
import numpy as np
import pandas as pd

# Import relatively from other modules.
from ....types import NPANYS
from ...base import BaseTransformPandas, ErrorTransformUnsupportPartial
from ..base import TransformCategoryEncode


# Type aliases.
Input = List[pd.DataFrame]
Output = List[pd.DataFrame]


# Self types.
SelfTransformCountizePandas = TypeVar(
    "SelfTransformCountizePandas", bound="TransformCountizePandas"
)


def _lookup(encoding: pd.DataFrame, raw: Any, /, *, ood: Optional[str] = "") -> NPANYS:
    r"""
    Encode a category cell through lookup.

    Args
    ----
    - encoding
        Encoding lookup table, whose row index holds category representations.
    - raw
        Cell value.
        It must be a string.
    - ood
        Row label of the lookup table used as fallback when the cell value is not found.
        If it is null, there is no fallback and an unknown cell value is an error.

    Returns
    -------
    - processed
        Encoded value.

    Raises
    ------
    - AssertionError
        The cell value is not a string.
    - KeyError
        The cell value is not in the lookup table, and the fallback is either disabled or itself
        not in the lookup table.
    """
    # Robust lookup encode.
    assert isinstance(
        raw, str
    ), f'Can only encode string category cell through lookup, but get "{repr(type(raw)):s}" cell.'
    try:
        # Search for the numeric values corresponding to cell representation.
        return np.array(encoding.loc[raw].values)
    except KeyError:
        if ood is None:
            # Without a fallback symbol, report the offending category explicitly instead of
            # failing on a meaningless lookup of null.
            raise KeyError(
                f'Category "{raw:s}" is not found in encoding lookup table, and'
                f" out-of-distribution fallback is not allowed."
            ) from None

        # If nothing is found, use numeric values for OOD case.
        return np.array(encoding.loc[ood].values)


class TransformCountizePandas(BaseTransformPandas, TransformCategoryEncode[pd.DataFrame]):
    r"""
    Transformation for count encoding on Pandas data.
    """
    # Transformation unique identifier.
    _IDENTIFIER = "cateenc.count.pandas"

    def input(self: SelfTransformCountizePandas, raw: Any, /) -> Input:
        r"""
        Formalize raw data as an input to the transformation.

        Args
        ----
        - raw
            Raw data.
            Only null is supported.

        Returns
        -------
        - process
            Processed data compatible with the transformation.
            For null raw data, it is four empty single-column containers.
        """
        # Reject everything except null first.
        if raw is not None:
            raise ErrorTransformUnsupportPartial(
                f"Try to formalize incompatible raw data into input domain of"
                f' "{self._IDENTIFIER:s}".'
            )

        # Null raw data becomes one empty container per placeholder column.
        titles = [
            "feature-categorical",
            "feature-continuous",
            "label-categorical",
            "label-continuous",
        ]
        return [pd.DataFrame([], columns=[title]) for title in titles]

    def output(self: SelfTransformCountizePandas, raw: Any, /) -> Output:
        r"""
        Formalize raw data as an output from the transformation.

        Args
        ----
        - raw
            Raw data.
            Only null is supported.

        Returns
        -------
        - process
            Processed data compatible with the transformation.
            For null raw data, it is four empty containers, where the categorical feature
            container has no column and its count encoding sits in the continuous feature
            container.
        """
        # Reject everything except null first.
        if raw is not None:
            raise ErrorTransformUnsupportPartial(
                f"Try to formalize incompatible raw data into output domain of"
                f' "{self._IDENTIFIER:s}".'
            )

        # Column layout of each container corresponding to null input.
        # Pay attention that categorical columns are assumed to be fully encoded, thus nothing
        # remains in the first container.
        layout: List[List[str]]
        layout = [
            [],
            ["feature-categorical-count", "feature-continuous"],
            ["label-categorical"],
            ["label-continuous"],
        ]
        return [pd.DataFrame([], columns=titles) for titles in layout]

    def transform(
        self: SelfTransformCountizePandas, input: Input, /, *args: Any, **kwargs: Any
    ) -> Output:
        r"""
        Transform input into output without inplacement.

        Args
        ----
        - input
            Input to the transformation.
            The first two containers are categorical and continuous features, and all the others
            are passed through untouched.

        Returns
        -------
        - output
            Output from the transformation.
            Consumed categorical columns are removed from the categorical container, and their
            encodings are prepended to the continuous container with "-count" suffix.

        Raises
        ------
        - ErrorTransformUnsupportPartial
            Some fitted categorical columns are missing in the input.
        """
        # Ensure all columns seen in fitting are available.
        categorical, continuous, *_labels = input
        missing = sorted(set(self.columns_categorical_consume) - set(categorical.columns))
        if missing:
            # Raise error when application is different from fitting.
            # Only the first few missing columns are reported to keep the message short.
            missing_ = [*missing[:3], "..."] if len(missing) > 3 else missing
            raise ErrorTransformUnsupportPartial(
                "Encoder wants to encode {:s} columns which are missing.".format(
                    ", ".join(f'"{name:s}"' for name in missing_)
                )
            )
        consuming = list(self.columns_categorical_consume)

        # Encode each consumed column into its own numeric vector.
        # Results are collected in fresh arrays rather than written back into a selection of the
        # input, so that the input is never touched and no chained assignment is involved.
        size = len(categorical.index)
        blocks = []
        for name in consuming:
            # Apply robust encoder function on focusing categorical column given corresponding
            # encoding lookup table.
            lookup = functools.partial(_lookup, self.encodings[name], ood=self.ood)
            cells = np.array([lookup(cell) for cell in categorical[name]])
            blocks.append(np.reshape(cells, (size,)))
        values = np.stack(blocks, axis=1) if blocks else np.empty((size, 0))

        # Append encoding suffix to encoded data.
        columns = [f"{name:s}-count" for name in consuming]
        encoded = pd.DataFrame(values, columns=columns, index=categorical.index)

        # Take remaining categorical data, and merge encoded categorical data into continuous data.
        categorical = categorical[sorted(set(categorical.columns) - set(consuming))]
        continuous = pd.concat([encoded, continuous], axis=1)
        return [categorical, continuous, *_labels]

    def inverse(
        self: SelfTransformCountizePandas, output: Output, /, *args: Any, **kwargs: Any
    ) -> Input:
        r"""
        Inverse output into input without inplacement.

        Args
        ----
        - output
            Output from the transformation.
            The first two containers are categorical and continuous features, and all the others
            are passed through untouched.

        Returns
        -------
        - input
            Input to the transformation.
            Recovered category columns are appended to the categorical container, and their
            "-count" columns are removed from the continuous container.

        Raises
        ------
        - ErrorTransformUnsupportPartial
            An encoded value matches no category, matches more than one category, or matches the
            out-of-distribution symbol.
        """
        # Select essential continuous data for inversion.
        # A fitted category column is invertible only if its encoded column is present.
        categorical, continuous, *_labels = output
        invertible = [
            name
            for name in self.columns_categorical_consume
            if f"{name:s}-count" in continuous.columns
        ]

        # Traverse potentially invertible category columns and append their inversion to
        # categorical dataframe.
        categories: Dict[str, List[str]]
        categories = {}
        for name in invertible:
            # Select continuous data only for focusing category inversion.
            encoded = continuous[[f"{name:s}-count"]].values
            labels = self.encodings[name].index
            encoding = self.encodings[name].values

            # Self comparison separates regular entries from NaN entries of the lookup table.
            catching = encoding == encoding
            ignoring = encoding != encoding
            categories[name] = []
            for row in encoded:
                # Find the exact matching and NaN matching.
                row = np.reshape(row, (1, len(row)))
                matching = np.logical_or(
                    np.logical_and(catching, row == encoding), np.logical_and(ignoring, row != row)
                )
                buf = np.flatnonzero(np.all(matching, axis=1))

                # Matching index should be unique.
                if len(buf) == 0:
                    # No exact matching is found.
                    raise ErrorTransformUnsupportPartial(
                        f'No exact matching category is found in inversion for column "{name:s}",'
                        f" and you need to define approximate encoding inversion as an independent"
                        f" transformation."
                    )
                if len(buf) > 1:
                    # More than one matching is found.
                    raise ErrorTransformUnsupportPartial(
                        f"More than one exact matching categories are found in inversion for column"
                        f' "{name:s}", and you need to define ambiguous encoding inversion as an'
                        f" independent transformation."
                    )
                category = labels[int(buf.item())]
                if category == self.ood:
                    # Out-of-distribution symbol is ambiguous, thus it is not a valid inversion.
                    raise ErrorTransformUnsupportPartial(
                        f"Exact matching category is ambiguous out-of-distribution symbol in"
                        f' inversion for column "{name:s}", and you need to define ambiguous'
                        f" encoding inversion as an independent transformation."
                    )
                categories[name].append(category)

        # Merge inversed categories with existing categorical dataframe, and delete continuous
        # columns consumed in inversion.
        # Inversed categories are generated row by row from continuous data, thus they must carry
        # its row index; otherwise column-wise concatenation aligns them against a fresh default
        # index and scatters rows whenever data is not indexed by 0, 1, 2, ...
        consuming = [f"{name:s}-count" for name in invertible]
        restored = pd.DataFrame(categories, index=continuous.index if invertible else None)
        categorical = pd.concat([categorical, restored], axis=1)
        continuous = continuous[sorted(set(continuous.columns) - set(consuming))]
        return [categorical, continuous, *_labels]

    def fit(
        self: SelfTransformCountizePandas,
        input: Input,
        output: Output,
        /,
        *args: Any,
        columns_categorical: Optional[Sequence[str]] = None,
        unk: int = 0,
        ood: Optional[str] = None,
        threshold_quantile: float = -1.0,
        threshold_percent: float = -1.0,
        avoid_collide: bool = True,
        **kwargs: Any,
    ) -> SelfTransformCountizePandas:
        r"""
        Fit transformation parameters by example input and output.

        Args
        ----
        - input
            Example input to the transformation.
            Only its first container (categorical features) is used.
        - output
            Example output from the transformation.
            It is not used by the fitting.
        - columns_categorical
            Categorical columns to apply transformation.
            If it is null, all categorical columns of the input are used.
        - unk
            Default value for unknown or rare categories.
            It is the encoding assigned to the out-of-distribution symbol.
        - ood
            Category representation reserved for out-of-distribution.
            If it is null, out-of-distribution is not allowed.
        - threshold_quantile
            Drop category counts (inclusively) below given quantile into unknown category.
            Negative value disables this threshold.
        - threshold_percent
            Drop category counts (inclusively) below given percent w.r.t. total training data into
            unknown category.
            The value is multiplied with the total count directly, and negative value disables
            this threshold.
        - avoid_collide
            If True, automatically avoid encoding collision for different categories.
            To avoid, it will automatically add values after decimal points of counts to force the
            difference.

        Returns
        -------
        - self
            Class instance itself.

        Raises
        ------
        - AssertionError
            Requested categorical columns do not exist in the input, or the out-of-distribution
            symbol is already a kept category of a column.
        """
        # Get columns to be handled.
        # Only the leading container is required, the same as transformation and inversion which
        # accept arbitrary trailing containers.
        categorical, *_ = input
        columns_categorical_full = set(categorical.columns)
        columns_categorical_consume = (
            columns_categorical_full if columns_categorical is None else set(columns_categorical)
        )
        assert columns_categorical_consume.issubset(
            columns_categorical_full
        ), "Consuming categorical columns: {:s} do not exist in input.".format(
            ", ".join(
                f'"{name!s}"'
                for name in sorted(columns_categorical_consume - columns_categorical_full)
            )
        )

        # Save essential parameters directly from input.
        self.columns_categorical_consume = list(sorted(columns_categorical_consume))
        self.unk = unk
        self.ood = ood
        self.avoid_collide = avoid_collide

        # Generate runtime parameters for counting.
        self._threshold_quantile = threshold_quantile
        self._threshold_percent = threshold_percent
        self._ood = [self.unk]

        # Traverse each categorical column.
        encodings_ = {}
        for name in self.columns_categorical_consume:
            # Collect counts of all categories for focusing column as a dataframe.
            # Cells are counted by their string representations.
            counts = categorical[name].map(str).value_counts().map(float)
            encoding = pd.DataFrame(counts.values, columns=["count"], index=counts.index.map(str))

            # Collect category filtering count threshold.
            # A disabled (negative) threshold is evaluated at zero and then lowered by one, so
            # that it can never drop a category.
            # The implementation is robust to NaN quantile and percent.
            bound_quantile = float(
                encoding["count"].quantile(max(self._threshold_quantile, 0.0))
            )
            bound_quantile -= float(self._threshold_quantile < 0.0)
            bound_quantile = 0.0 if math.isnan(bound_quantile) else math.ceil(bound_quantile)
            bound_percent = float(encoding["count"].sum()) * max(self._threshold_percent, 0.0)
            bound_percent -= float(self._threshold_percent < 0.0)
            bound_percent = 0.0 if math.isnan(bound_percent) else math.ceil(bound_percent)
            threshold = int(max(bound_quantile, bound_percent))

            # Keep an independent copy of surviving categories, since the table is modified below.
            encoding = encoding[encoding["count"] > threshold].copy()

            # If OOD is allowed, we need to generate an extra encoding for it specifically.
            if self.ood is not None:
                # Automatically generate encoding for OOD.
                assert (
                    self.ood not in encoding.index
                ), f'OOD category "{self.ood:s}" already exists in "{name:s}" column.'
                encoding.loc[self.ood] = self._ood

            # If collision avoidance is required, add floating bias based on alphabetic order for
            # colliding categories.
            # All biases are collected before any of them is applied, so that grouping is always
            # done on untouched counts.
            if self.avoid_collide:
                biases = {
                    str(index): float(rank) / float(len(group))
                    for _, group in encoding.groupby("count")
                    for rank, index in enumerate(group.sort_index().index)
                }
                for index in encoding.index:
                    # Add bias to corresponding category.
                    encoding.loc[index] += biases[index]

            # Save final encoding as parameter for focusing categorical column.
            encodings_[name] = encoding
        self.encodings = encodings_
        return self

    def get_metadata(self: SelfTransformCountizePandas, /) -> Mapping[str, Any]:
        r"""
        Collect metadata of the transformation.

        Args
        ----

        Returns
        -------
        - metadata
            Metadata of the transformation.
            It is always empty for this transformation.
        """
        # This transformation carries no metadata.
        metadata: Dict[str, Any]
        metadata = {}
        return metadata

    def get_numeric_data(self: SelfTransformCountizePandas, /) -> Mapping[str, NPANYS]:
        r"""
        Collect numeric data of the transformation.

        Args
        ----

        Returns
        -------
        - data
            Numeric data of the transformation.
            It maps each encoded column to the raw values of its encoding lookup table.
        """
        # Collect numeric data of encodings column by column.
        numeric = {}
        for name, encoding in self.encodings.items():
            # An empty lookup table is explicitly exported as 64-bit floats.
            values = encoding.values
            if len(encoding) == 0:
                values = values.astype(np.float64)
            numeric[name] = values
        return numeric

    def get_alphabetic_data(self: SelfTransformCountizePandas, /) -> Mapping[str, Any]:
        r"""
        Collect alphabetic data of the transformation.

        Args
        ----

        Returns
        -------
        - data
            Alphabetic data of the transformation.
            It holds fitted column namings, plain fitting options, and category namings of each
            encoding lookup table.
        """
        # Plain options are exported as they are.
        alphabetic: Dict[str, Any]
        alphabetic = {
            "columns": self.columns_categorical_consume,
            "unk": self.unk,
            "ood": self.ood,
            "avoid_collide": self.avoid_collide,
        }

        # Category representations are row indices of encoding lookup tables.
        categories = {}
        for name, encoding in self.encodings.items():
            categories[name] = list(encoding.index)
        alphabetic["categories"] = categories
        return alphabetic

    def set_metadata(
        self: SelfTransformCountizePandas, metadata: Mapping[str, Any], /  # noqa: W504
    ) -> SelfTransformCountizePandas:
        r"""
        Load metadata of the transformation.

        Args
        ----
        - metadata
            Metadata of the transformation.
            It is ignored by this transformation.

        Returns
        -------
        - self
            Class instance itself.
        """
        # This transformation carries no metadata, thus there is nothing to load.
        return self

    def set_numeric_data(
        self: SelfTransformCountizePandas, data: Mapping[str, NPANYS], /  # noqa: W504
    ) -> SelfTransformCountizePandas:
        r"""
        Load numeric data of the transformation.

        Args
        ----
        - data
            Numeric data of the transformation.
            It maps each encoded column to the raw values of its encoding lookup table.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Start from an empty encoding container, and fill it column by column.
        container: Dict[str, pd.DataFrame]
        container = {}
        self.encodings = container
        for name in data:
            # Lookup table of focusing categorical column gets its real column naming, but only
            # anonymous row indices.
            # Category namings are attached later when alphabetic data is loaded.
            container[name] = pd.DataFrame(data[name], columns=["count"])
        return self

    def set_alphabetic_data(
        self: SelfTransformCountizePandas, data: Mapping[str, Any], /  # noqa: W504
    ) -> SelfTransformCountizePandas:
        r"""
        Load alphabetic data of the transformation.

        Encoding lookup tables must already be available on the instance, since their row indices
        are overwritten by loaded category namings.

        Args
        ----
        - data
            Alphabetic data of the transformation.

        Returns
        -------
        - self
            Class instance itself.
        """
        # All essential entries must be provided.
        essentials = [
            ("columns", "Categorical column namings are missing."),
            ("unk", "Default value for unknown category is missing."),
            ("ood", "OOD category symbol (including null) is missing."),
            ("avoid_collide", "Encoding collision avoidance flag is missing."),
            ("categories", "Category namings are missing."),
        ]
        for key, message in essentials:
            # Safety check for each essential entry.
            assert key in data, message

        # Category namings must cover every numeric data block with matching size.
        namings = data["categories"]
        for name, encoding in self.encodings.items():
            # Safety check for each numeric data block.
            assert name in namings, f'Category namings of column "{name:s}" is missing.'
            assert len(namings[name]) == len(
                encoding
            ), f'Sizes of category namings and numeric data of column "{name:s}" do not match.'

        # Load naming information of columns and plain options.
        self.columns_categorical_consume = data["columns"]
        self.unk = data["unk"]
        self.ood = data["ood"]
        self.avoid_collide = data["avoid_collide"]

        # Replace anonymous row indices of lookup tables by category namings.
        for name, encoding in self.encodings.items():
            # Overwrite row index namings.
            encoding.index = pd.Index(namings[name])
        return self
