# Import Python packages.
import functools
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Mapping,
    Optional,
    Sequence,
    Type,
    TypeVar,
    Union,
    cast,
)

# Import external packages.
import more_itertools as xitertools
import numpy as np
import pandas as pd

# Import relatively from other modules.
from ....types import NPANYS
from ...base import BaseTransformPandas, ErrorTransformUnsupportPartial
from ..base import TransformCategoryEncode
from .utils import (
    nandeg,
    nanmax,
    nanmean,
    nanmin,
    nanratio,
    nanstd,
    normdeg_identity,
    normdeg_log,
    normdeg_log_normal,
)


# Type aliases.
# Both input and output are lists of dataframes; `input(None)` / `output(None)` below build them
# as [categorical features, continuous features, categorical labels, continuous labels].
Input = List[pd.DataFrame]
Output = List[pd.DataFrame]


# Self types.
# Bound type variable so that methods annotated as returning it keep the concrete subclass type.
SelfTransformCCAPandas = TypeVar("SelfTransformCCAPandas", bound="TransformCCAPandas")


def _lookup(encoding: pd.DataFrame, raw: Any, /, *, ood: Optional[str] = "") -> NPANYS:
    r"""
    Encode a category cell through lookup.

    Args
    ----
    - encoding
        Encoding lookup table, whose rows are indexed by category representation.
    - raw
        Cell value. It must be a string.
    - ood
        Category representation whose row is used when `raw` is not found in `encoding`.
        If it is null, out-of-distribution categories are not allowed.

    Returns
    -------
    - processed
        Encoded value, i.e., the values of the matched row as a numpy array.

    Raises
    ------
    - AssertionError
        If `raw` is not a string.
    - KeyError
        If `raw` is not found in `encoding`, and `ood` is null or is not found either.
    """
    # Robust lookup encode.
    assert isinstance(
        raw, str
    ), f'Can only encode string category cell through lookup, but get "{repr(type(raw)):s}" cell.'
    try:
        # Search for the numeric values corresponding to cell representation.
        return np.array(encoding.loc[raw].values)
    except KeyError as error:
        # If nothing is found, fall back to the numeric values reserved for the OOD category.
        # When OOD is disallowed (null) or has no row, report the offending category explicitly
        # instead of surfacing a bare `KeyError(None)` from the fallback lookup.
        if ood is None or ood not in encoding.index:
            raise KeyError(
                f'Category "{raw:s}" is not found in encoding lookup table, and no'
                f" out-of-distribution encoding is available to fall back on."
            ) from error
        return np.array(encoding.loc[ood].values)


class TransformCCAPandas(BaseTransformPandas, TransformCategoryEncode[pd.DataFrame]):
    r"""
    Transformation for CCA on Pandas data.

    Every consumed categorical column is replaced by continuous columns named
    "{categorical}-cca-{continuous}-{aggregate}".
    They hold aggregation statistics of every consumed continuous column, collected per category
    at fitting time.
    """
    # Transformation unique identifier.
    _IDENTIFIER = "cateenc.cca.pandas"

    # Registered degree column normalization algorithms indexed by name.
    _NORMDEGS: Dict[str, Callable[["pd.Series[Any]"], "pd.Series[Any]"]]
    _NORMDEGS = {
        "identity": cast(Callable[["pd.Series[Any]"], "pd.Series[Any]"], normdeg_identity),
        "log": cast(Callable[["pd.Series[Any]"], "pd.Series[Any]"], normdeg_log),
        "log_normal": cast(Callable[["pd.Series[Any]"], "pd.Series[Any]"], normdeg_log_normal),
    }

    @classmethod
    def register_normdeg(
        cls: Type[SelfTransformCCAPandas],
        f: Callable[["pd.Series[Any]"], "pd.Series[Any]"],
        name: str,
        /,
    ) -> None:
        r"""
        Register a degree column normalization algorithm.

        Args
        ----
        - f
            Degree column normalization algorithm.
        - name
            Degree column normalization algorithm name for indexing.

        Returns
        -------
        """
        # Register a degree column normalization without duplication.
        # The registry is the class-level `_NORMDEGS` dict, so a subclass that does not override
        # it shares (and mutates) the same registry.
        assert (
            name not in cls._NORMDEGS
        ), f'Degree column normalization algorithm "{name:s}" has been registered.'
        cls._NORMDEGS[name] = f

    @classmethod
    def get_normdeg(
        cls: Type[SelfTransformCCAPandas], name: str, /  # noqa: W504
    ) -> Callable[["pd.Series[Any]"], "pd.Series[Any]"]:
        r"""
        Get a degree column normalization algorithm.

        Args
        ----
        - name
            Degree column normalization algorithm name for indexing.

        Returns
        -------
        - f
            Degree column normalization algorithm.

        Raises
        ------
        - KeyError
            If no algorithm is registered under `name`.
        """
        # Get the degree column normalization algorithm from the class registration.
        try:
            return cls._NORMDEGS[name]
        except KeyError as error:
            raise KeyError(
                f'Degree column normalization algorithm "{name}" is not registered.'
            ) from error

    def input(self: SelfTransformCCAPandas, raw: Any, /) -> Input:
        r"""
        Convert raw data into input to the transformation.

        Args
        ----
        - raw
            Raw data.

        Returns
        -------
        - process
            Processed data compatible with the transformation.
        """
        # Conversion will vary according to raw data.
        if raw is None:
            # Create empty containers for categorical/continuous features and labels.
            return [
                pd.DataFrame([], columns=["feature-categorical"]),
                pd.DataFrame([], columns=["feature-continuous"]),
                pd.DataFrame([], columns=["label-categorical"]),
                pd.DataFrame([], columns=["label-continuous"]),
            ]
        else:
            # All the other cases are not supported.
            raise ErrorTransformUnsupportPartial(
                f"Try to formalize incompatible raw data into input domain of"
                f' "{self._IDENTIFIER:s}".'
            )

    def output(self: SelfTransformCCAPandas, raw: Any, /) -> Output:
        r"""
        Convert raw data into output from the transformation.

        Args
        ----
        - raw
            Raw data.

        Returns
        -------
        - process
            Processed data compatible with the transformation.
        """
        # Conversion will vary according to raw data.
        if raw is None:
            # Create an empty continuous container corresponding to null input.
            # Pay attention that we assume categorical columns are fully encoded without remainings.
            return [
                pd.DataFrame([], columns=[]),
                pd.DataFrame(
                    [],
                    columns=[
                        "feature-categorical-cca-feature-continuous-null",
                        "feature-categorical-cca-feature-continuous-deg",
                        "feature-categorical-cca-feature-continuous-min",
                        "feature-categorical-cca-feature-continuous-mean",
                        "feature-categorical-cca-feature-continuous-max",
                        "feature-categorical-cca-feature-continuous-std",
                        "feature-continuous",
                    ],
                ),
                pd.DataFrame([], columns=["label-categorical"]),
                pd.DataFrame([], columns=["label-continuous"]),
            ]
        else:
            # All the other cases are not supported.
            raise ErrorTransformUnsupportPartial(
                f"Try to formalize incompatible raw data into output domain of"
                f' "{self._IDENTIFIER:s}".'
            )

    def transform(
        self: SelfTransformCCAPandas, input: Input, /, *args: Any, **kwargs: Any
    ) -> Output:
        r"""
        Transform input into output without inplacement.

        Args
        ----
        - input
            Input to the transformation.

        Returns
        -------
        - output
            Output from the transformation.

        Raises
        ------
        - ErrorTransformUnsupportPartial
            If some fitted categorical columns are missing from input.
        """
        # Select and encode categorical data based on parameters.
        categorical, continuous, *_labels = input
        missing = list(sorted(set(self.columns_categorical_consume) - set(categorical.columns)))
        if missing:
            # Raise error when application is different from fitting.
            missing_ = [*missing[:3], "..."] if len(missing) > 3 else missing
            raise ErrorTransformUnsupportPartial(
                "Encoder wants to encode {:s} columns which are missing.".format(
                    ", ".join(f'"{name:s}"' for name in missing_)
                )
            )
        consuming = [
            name for name in self.columns_categorical_consume if name in categorical.columns
        ]

        # Encode every consuming categorical column into a numeric block of shape
        # (#rows, #encoding columns per category column).
        # The block is built directly instead of writing encoded arrays back into (a slice of) the
        # input, which would trigger chained-assignment warnings and requires the column dtype to
        # hold array objects.
        num_rows = len(categorical.index)
        num_produce = len(self._columns_encoding_produce)
        blocks = []
        for name in consuming:
            # Apply robust encoder function on focusing categorical column given corresponding
            # encoding lookup table.
            lookup = functools.partial(_lookup, self.encodings[name], ood=self.ood)
            cells = [lookup(raw) for raw in categorical[name]]
            blocks.append(np.reshape(np.array(cells), (num_rows, num_produce)))

        # Merge all encoded column data together as a single numeric block, ordered by consuming
        # column first and then by encoding column.
        values = (
            np.concatenate(blocks, axis=1)
            if blocks
            else np.zeros((num_rows, 0), dtype=np.float64)
        )
        columns = list(
            xitertools.flatten(
                [
                    [f"{name:s}-cca-{suffix:s}" for suffix in self._columns_encoding_produce]
                    for name in consuming
                ]
            )
        )
        encoded = pd.DataFrame(values, columns=columns, index=categorical.index)

        # Take remaining categorical data, and merge encoded categorical data into continuous data.
        categorical = categorical[list(sorted(set(categorical.columns) - set(consuming)))]
        continuous = pd.concat([encoded, continuous], axis=1)
        return [categorical, continuous, *_labels]

    def inverse(
        self: SelfTransformCCAPandas, output: Output, /, *args: Any, **kwargs: Any
    ) -> Input:
        r"""
        Inverse output into input without inplacement.

        Args
        ----
        - output
            Output from the transformation.

        Returns
        -------
        - input
            Input to the transformation.

        Raises
        ------
        - ErrorTransformUnsupportPartial
            If encoded columns of a category column are only partially present, or if a row has
            no, more than one, or only the out-of-distribution exact matching category.
        """
        # Select essential continuous data for inversion.
        categorical, continuous, *_labels = output
        invertible = []
        incomplete = []
        for name in self.columns_categorical_consume:
            # For the inversion, potential category column needs to query all produced continuous
            # columns.
            # If those continuous columns are partially detected, there will be confusion and an
            # error must be raised.
            queries = [f"{name:s}-cca-{suffix:s}" for suffix in self._columns_encoding_produce]
            condition = sum([query in continuous.columns for query in queries])
            if condition == len(queries):
                # For a potential category column, all of its encoded columns are detected, thus it
                # is necessary to inverse.
                invertible.append(name)
                continue
            incomplete.append(f'"{name:s}"')
        incomplete_ = [*incomplete[:3], "..."] if len(incomplete) > 3 else incomplete
        if incomplete:
            # If encoded columns of a potential category column are incomplete, simple lookup can
            # not inverse encoding to category safely.
            raise ErrorTransformUnsupportPartial(
                "Fail to inverse for potential category columns: {:s}, due to incomplete encoded"
                " columns.".format(", ".join(incomplete_))
            )

        # Traverse potentially invertible category columns and append their inversion to
        # categorical dataframe.
        categories: Dict[str, List[str]]
        categories = {}
        for name in invertible:
            # Select continuous data only for focusing category inversion.
            queries = [f"{name:s}-cca-{suffix:s}" for suffix in self._columns_encoding_produce]
            encoded = continuous[queries].values
            encoding = self.encodings[name].values

            # Make inversion for every row in the selection.
            # `x == x` is False exactly for NaN, so these masks split the lookup table into valued
            # entries (matched by equality) and NaN entries (matched only by NaN).
            catching = encoding == encoding
            ignoring = encoding != encoding
            categories[name] = []
            for row in encoded:
                # Find the exact matching and NaN matching.
                row = np.reshape(row, (1, len(row)))
                matching = np.logical_or(
                    np.logical_and(catching, row == encoding), np.logical_and(ignoring, row != row)
                )
                matching = np.all(matching, axis=1)

                # Matching index should be unique.
                buf = np.flatnonzero(matching)
                if len(buf) == 0:
                    # No exact matching is found.
                    raise ErrorTransformUnsupportPartial(
                        f'No exact matching category is found in inversion for column "{name:s}",'
                        f" and you need to define approximate encoding inversion as an independent"
                        f" transformation."
                    )
                if len(buf) > 1:
                    # More than one matching is found.
                    raise ErrorTransformUnsupportPartial(
                        f"More than one exact matching categories are found in inversion for column"
                        f' "{name:s}", and you need to define ambiguous encoding inversion as an'
                        f" independent transformation."
                    )
                category = self.encodings[name].index[int(buf.item())]
                if category == self.ood:
                    # Out-of-distribution symbol is ambiguous, thus it is not a valid inversion.
                    raise ErrorTransformUnsupportPartial(
                        f"Exact matching category is ambiguous out-of-distribution symbol in"
                        f' inversion for column "{name:s}", and you need to define ambiguous'
                        f" encoding inversion as an independent transformation."
                    )
                categories[name].append(category)

        # Merge inversed categories with existing categorical dataframe, and delete continuous
        # columns consumed in inversion.
        consuming = list(
            xitertools.flatten(
                [
                    [f"{name:s}-cca-{suffix:s}" for suffix in self._columns_encoding_produce]
                    for name in invertible
                ]
            )
        )
        # Inversed categories are row-aligned with `continuous`, so they must carry its index;
        # a default RangeIndex would misalign `concat` whenever rows are not labeled 0..n-1.
        categorical = pd.concat(
            [categorical, pd.DataFrame(categories, index=continuous.index)], axis=1
        )
        continuous = continuous[list(sorted(set(continuous.columns) - set(consuming)))]
        return [categorical, continuous, *_labels]

    def fit(
        self: SelfTransformCCAPandas,
        input: Input,
        output: Output,
        /,
        *args: Any,
        columns_categorical: Optional[Sequence[str]] = None,
        columns_continuous: Optional[Sequence[str]] = None,
        nan: float = float("nan"),
        ddof: int = 0,
        ood: Optional[str] = None,
        normalize_degree: str = "log_normal",
        aggregates: Sequence[str] = ("deg", "null", "mean", "std", "min", "max"),
        **kwargs: Any,
    ) -> SelfTransformCCAPandas:
        r"""
        Fit transformation parameters by example input and output.

        Args
        ----
        - input
            Example input to the transformation.
        - output
            Example output from the transformation.
        - columns_categorical
            Categorical columns to apply transformation.
            If it is null, all categorical columns are used.
        - columns_continuous
            Continuous columns to support transformation.
            If it is null, all continuous columns are used.
        - nan
            Default value when no valued element is presented in an aggregation.
        - ddof
            Means Delta Degrees of Freedom.
        - ood
            Category representation reserved for out-of-distribution.
            If it is null, out-of-distribution is not allowed.
        - normalize_degree
            Registered name of the normalization schema used on encoded degree columns.
        - aggregates
            Aggregations to be collected from every continuous feature, among "deg", "null",
            "mean", "std", "min" and "max".
            Duplicates are removed and the names are sorted.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Get columns to be handled.
        categorical, continuous, _, _ = input
        columns_categorical_full = set(categorical.columns)
        columns_categorical_consume = (
            columns_categorical_full if columns_categorical is None else set(columns_categorical)
        )
        columns_continuous_full = set(continuous.columns)
        columns_continuous_consume = (
            columns_continuous_full if columns_continuous is None else set(columns_continuous)
        )
        assert columns_categorical_consume.issubset(
            columns_categorical_full
        ), "Consuming categorical columns: {:s} do not exist in input.".format(
            ", ".join(
                f'"{name:s}"'
                for name in sorted(columns_categorical_consume - columns_categorical_full)
            )
        )
        assert columns_continuous_consume.issubset(
            columns_continuous_full
        ), "Consuming continuous columns: {:s} do not exist in input.".format(
            ", ".join(
                f'"{name:s}"'
                for name in sorted(columns_continuous_consume - columns_continuous_full)
            )
        )

        # Resolve the degree normalization before any state is modified, so that an unknown name
        # fails without leaving the transformation partially fitted.
        normdeg = self.get_normdeg(normalize_degree)

        # Save essential parameters directly from input.
        self.columns_categorical_consume = list(sorted(columns_categorical_consume))
        self.columns_continuous_consume = list(sorted(columns_continuous_consume))
        self.nan = nan
        self.ood = ood
        self.normalize_degree = normalize_degree
        self.aggregates = list(sorted(set(aggregates)))

        # Generate runtime parameters for aggregation.
        self._ddof = ddof
        self._aggs: Sequence[Union[Callable[[Sequence[Any]], Union[int, float]], str, np.ufunc]]
        self._aggs = [
            cast(
                Callable[[Sequence[Any]], Union[int, float]],
                {
                    "deg": nandeg,
                    "null": nanratio,
                    "mean": functools.partial(nanmean, null=self.nan),
                    "std": functools.partial(nanstd, null=self.nan, ddof=self._ddof),
                    "min": functools.partial(nanmin, null=self.nan),
                    "max": functools.partial(nanmax, null=self.nan),
                }[name],
            )
            for name in self.aggregates
        ]
        self._columns_encoding_produce = list(
            xitertools.flatten(
                [
                    [f"{prefix:s}-{suffix:s}" for suffix in self.aggregates]
                    for prefix in self.columns_continuous_consume
                ]
            )
        )
        # Encoding values of the OOD category for one continuous column, in `self.aggregates`
        # order; it is repeated for every consumed continuous column.
        self._ood = [
            {
                "deg": 0,
                "null": 0.0,
                "mean": self.nan,
                "std": self.nan,
                "min": self.nan,
                "max": self.nan,
            }[name]
            for name in self.aggregates
        ]
        self._normdeg = normdeg

        # Fetch the data for the transformation.
        # Categories are compared by their string representation.
        categorical = categorical[self.columns_categorical_consume].applymap(str)
        continuous = continuous[self.columns_continuous_consume]
        data = pd.concat([categorical, continuous], axis=1)

        # Traverse each categorical column.
        encodings_ = {}
        for name in self.columns_categorical_consume:
            # Collect aggregation statistics from all continuous columns for all categories in
            # the focusing categorical column.
            groupby = data[[*self.columns_continuous_consume, name]].groupby(name)
            encoding = groupby[self.columns_continuous_consume].agg(self._aggs)
            encoding.columns = pd.Index(self._columns_encoding_produce)
            encoding.index = encoding.index.map(str)

            # If OOD is allowed, we need to generate an extra encoding for it specifically.
            if self.ood is not None:
                # Automatically generate encoding for OOD.
                assert (
                    self.ood not in encoding.index
                ), f'OOD category "{self.ood:s}" already exists in "{name:s}" column.'
                encoding.loc[self.ood] = self._ood * len(self.columns_continuous_consume)

            # Degree columns need additional normalization.
            for column, series in encoding.items():
                # Only degree columns are normalized; all other columns are kept as they are.
                if str(column).endswith("-deg"):
                    encoding[column] = self._normdeg(series)

            # Save final encoding as parameter for focusing categorical column.
            encodings_[name] = encoding
        self.encodings = encodings_
        return self

    def get_metadata(self: SelfTransformCCAPandas, /) -> Mapping[str, Any]:
        r"""
        Get metadata of the transformation.

        Args
        ----

        Returns
        -------
        - metadata
            Metadata of the transformation.
        """
        # Do nothing.
        return {}

    def get_numeric_data(self: SelfTransformCCAPandas, /) -> Mapping[str, NPANYS]:
        r"""
        Get numeric data of the transformation.

        Args
        ----

        Returns
        -------
        - data
            Numeric data of the transformation, mapping each consumed categorical column to an
            array of shape (#categories, #continuous columns, #aggregates).
        """
        # Collect numeric data of encodings.
        # Empty encodings are cast to float64 explicitly so that the saved dtype stays numeric.
        return {
            name: np.reshape(
                encoding.values.astype(np.float64) if len(encoding) == 0 else encoding.values,
                (len(encoding), len(self.columns_continuous_consume), len(self.aggregates)),
            )
            for name, encoding in self.encodings.items()
        }

    def get_alphabetic_data(self: SelfTransformCCAPandas, /) -> Mapping[str, Any]:
        r"""
        Get alphabetic data of the transformation.

        Args
        ----

        Returns
        -------
        - data
            Alphabetic data of the transformation.
        """
        # Collect all encoding column and category representations.
        return {
            "columns": {
                "categorical": self.columns_categorical_consume,
                "continuous": self.columns_continuous_consume,
            },
            "nan": self.nan,
            "ood": self.ood,
            "normalize_degree": self.normalize_degree,
            "aggregates": self.aggregates,
            "categories": {name: list(encoding.index) for name, encoding in self.encodings.items()},
        }

    def set_metadata(
        self: SelfTransformCCAPandas, metadata: Mapping[str, Any], /  # noqa: W504
    ) -> SelfTransformCCAPandas:
        r"""
        Set metadata of the transformation.

        Args
        ----
        - metadata
            Metadata of the transformation.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Do nothing.
        return self

    def set_numeric_data(
        self: SelfTransformCCAPandas, data: Mapping[str, NPANYS], /  # noqa: W504
    ) -> SelfTransformCCAPandas:
        r"""
        Set numeric data of the transformation.

        Args
        ----
        - data
            Numeric data of the transformation, mapping each categorical column to an array of
            shape (#categories, #continuous columns, #aggregates).

        Returns
        -------
        - self
            Class instance itself.
        """
        # Create anonymous encoding container from given numeric data.
        self.encodings = {}
        for name, values in data.items():
            # Generate encoding dataframe for focusing categorical column with anonymous row and
            # column indices.
            # Naming information will be loaded and overwritten in later loading.
            num_categories, num_columns_continuous, num_aggs = values.shape
            num_encodings = num_columns_continuous * num_aggs
            self.encodings[name] = pd.DataFrame(np.reshape(values, (num_categories, num_encodings)))
        return self

    def set_alphabetic_data(
        self: SelfTransformCCAPandas, data: Mapping[str, Any], /  # noqa: W504
    ) -> SelfTransformCCAPandas:
        r"""
        Set alphabetic data of the transformation.

        The numeric encodings must already be loaded, since their row and column labels are
        overwritten here.

        Args
        ----
        - data
            Alphabetic data of the transformation.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Safety check.
        assert "columns" in data, "Column namings are missing."
        assert "categorical" in data["columns"], "Categorical column namings are missing."
        assert "continuous" in data["columns"], "Continuous column namings are missing."
        assert "nan" in data, "Default value for NaN is missing."
        assert "ood" in data, "OOD category symbol (including null) is missing."
        assert "normalize_degree" in data, "Normalization schema for degree columns is missing."
        assert "aggregates" in data, "Collected aggregation statistics are missing."
        assert "categories" in data, "Category namings are missing."
        for name, encoding in self.encodings.items():
            # Safety check for each numeric data block.
            assert name in data["categories"], f'Category namings of column "{name:s}" is missing.'
            assert len(data["categories"][name]) == len(
                encoding
            ), f'Sizes of category namings and numeric data of column "{name:s}" do not match.'

        # Load naming information of columns and categories.
        self.columns_categorical_consume = data["columns"]["categorical"]
        self.columns_continuous_consume = data["columns"]["continuous"]
        self.nan = data["nan"]
        self.ood = data["ood"]
        self.normalize_degree = data["normalize_degree"]
        self.aggregates = data["aggregates"]
        self._columns_encoding_produce = list(
            xitertools.flatten(
                [
                    [f"{prefix:s}-{suffix:s}" for suffix in self.aggregates]
                    for prefix in self.columns_continuous_consume
                ]
            )
        )
        for name, encoding in self.encodings.items():
            # Overwrite row and column index namings.
            encoding.index = pd.Index(data["categories"][name])
            encoding.columns = pd.Index(self._columns_encoding_produce)
        return self
