# Import Python packages.
import functools
from typing import Any, Callable, List, Mapping, Sequence, TypeVar, cast

# Import external packages.
import numpy as np
import pandas as pd

# Import relatively from other modules.
from ....types import NPANYS
from ...base import ErrorTransformUnsupport, ErrorTransformUnsupportPartial
from ..count import TransformCountizePandas


# Type aliases.
# Both sides of the transformation are lists of dataframes. Within this module such a list is
# unpacked as `[categorical features, continuous features, *labels]`.
Input = List[pd.DataFrame]
Output = List[pd.DataFrame]


# Self types.
# Type variable bound to the transformation class; it annotates `self` and the methods returning
# the instance itself.
SelfTransformSDVPandas = TypeVar("SelfTransformSDVPandas", bound="TransformSDVPandas")


def _lookup(encoding: pd.DataFrame, raw: Any, /, *, ood: str = "") -> NPANYS:
    r"""
    Encode a category cell through lookup.

    Args
    ----
    - encoding
        Encoding lookup table.
    - raw
        Cell value.

    Returns
    -------
    - processed
        Encoded value.
    """
    # Robust lookup encode.
    assert isinstance(
        raw, str
    ), f'Can only encode string category cell through lookup, but get "{repr(type(raw)):s}" cell.'
    try:
        # Search for the numeric values corresponding to cell representation.
        return np.array(encoding.loc[raw].values)
    except KeyError:
        # If nothing is found, use numeric values for OOD case.
        return np.array(encoding.loc[ood].values)


class TransformSDVPandas(TransformCountizePandas):
    r"""
    Transformation for SDV encoding on Pandas data.

    Every consumed categorical column is replaced by a continuous column.
    Fitting assigns each category an interval of the cumulative count distribution, stored as a
    mean (interval center) and a standard deviation (one sixth of interval width).
    Transforming replaces each cell by a Gaussian sample drawn from the distribution of its
    category.
    """
    # Transformation unique identifier.
    _IDENTIFIER = "cateenc.sdv.pandas"

    def input(self: SelfTransformSDVPandas, raw: Any, /) -> Input:
        r"""
        Convert raw data into input to the transformation.

        Args
        ----
        - raw
            Raw data.
            Only null raw data is supported.

        Returns
        -------
        - process
            Processed data compatible with the transformation.
            For null raw data, it is a list of four empty dataframes: categorical features,
            continuous features, categorical labels and continuous labels.

        Raises
        ------
        - ErrorTransformUnsupportPartial
            Raw data is not null.
        """
        # Anything but null raw data is not supported.
        if raw is not None:
            raise ErrorTransformUnsupportPartial(
                f"Try to formalize incompatible raw data into input domain of"
                f' "{self._IDENTIFIER:s}".'
            )

        # Create an empty categorical container and an empty continuous container for both
        # features and labels.
        return [
            pd.DataFrame([], columns=["feature-categorical"]),
            pd.DataFrame([], columns=["feature-continuous"]),
            pd.DataFrame([], columns=["label-categorical"]),
            pd.DataFrame([], columns=["label-continuous"]),
        ]

    def output(self: SelfTransformSDVPandas, raw: Any, /) -> Output:
        r"""
        Convert raw data into output from the transformation.

        Args
        ----
        - raw
            Raw data.
            Only null raw data is supported.

        Returns
        -------
        - process
            Processed data compatible with the transformation.
            For null raw data, it is a list of four empty dataframes where the categorical feature
            column has been moved into continuous features with an encoding suffix.

        Raises
        ------
        - ErrorTransformUnsupportPartial
            Raw data is not null.
        """
        # Anything but null raw data is not supported.
        if raw is not None:
            raise ErrorTransformUnsupportPartial(
                f"Try to formalize incompatible raw data into output domain of"
                f' "{self._IDENTIFIER:s}".'
            )

        # Create an empty continuous container corresponding to null input.
        # Pay attention that we assume categorical columns are fully encoded without remainings.
        return [
            pd.DataFrame([], columns=[]),
            pd.DataFrame([], columns=["feature-categorical-sdv", "feature-continuous"]),
            pd.DataFrame([], columns=["label-categorical"]),
            pd.DataFrame([], columns=["label-continuous"]),
        ]

    def transform(
        self: SelfTransformSDVPandas, input: Input, /, *args: Any, seed: int = 42, **kwargs: Any
    ) -> Output:
        r"""
        Transform input into output without inplacement.

        Args
        ----
        - input
            Input to the transformation.
            Its first two items are categorical features and continuous features, and all the
            other items are passed through untouched.
        - seed
            Random seed for Gaussian noise.

        Returns
        -------
        - output
            Output from the transformation.
            Encoded columns are named by the original column name with suffix "-sdv", and are
            placed in front of original continuous features.
            Categorical columns that are not encoded are kept in sorted name order.

        Raises
        ------
        - ErrorTransformUnsupportPartial
            Some columns to be encoded are missing from categorical features.
        """
        # Select categorical data based on parameters.
        categorical, continuous, *_labels = input
        missing = sorted(set(self.columns_categorical_consume) - set(categorical.columns))
        if missing:
            # Raise error when application is different from fitting.
            # Only the first three missing columns are reported to keep the message short.
            shown = [*missing[:3], "..."] if len(missing) > 3 else missing
            raise ErrorTransformUnsupportPartial(
                "Encoder wants to encode {:s} columns which are missing.".format(
                    ", ".join(f'"{name:s}"' for name in shown)
                )
            )
        consuming = [
            name for name in self.columns_categorical_consume if name in categorical.columns
        ]

        # Samples are collected in a dedicated float matrix rather than being written back into a
        # slice of the string dataframe.
        # Writing back through `.loc[:, name]` can keep the object dtype of the string column, and
        # can emit chained assignment warnings.
        selected = categorical[consuming]
        values = np.empty((len(selected.index), len(consuming)), dtype=np.float64)

        # A single generator is shared by all columns, thus noises are drawn column by column in
        # the order of consumed columns.
        rng = np.random.RandomState(seed)
        for position, name in enumerate(consuming):
            # Apply robust encoder function on focusing categorical column given corresponding
            # encoding lookup table.
            lookup = cast(
                Callable[[Any], Sequence[float]],
                functools.partial(_lookup, self.encodings[name][["mean", "std"]], ood=self.ood),
            )
            distributes = np.array(
                [np.array([*row]) for row in selected[name].apply(lookup).values],
                dtype=np.float64,
            )
            # Reshape explicitly so that an empty column still yields a (0, 2) matrix.
            distributes = np.reshape(distributes, (len(distributes), 2))

            # Add Gaussian noises scaled by the standard deviation to the mean of each cell.
            noises = rng.normal(0.0, 1.0, (len(distributes),))
            values[:, position] = distributes[:, 0] + noises * distributes[:, 1]

        # Append encoding suffix to encoded data.
        columns = [f"{name:s}-sdv" for name in consuming]
        encoded = pd.DataFrame(values, columns=columns, index=selected.index)

        # Take remaining categorical data, and merge encoded categorical data into continuous data.
        remaining = sorted(set(categorical.columns) - set(consuming))
        return [categorical[remaining], pd.concat([encoded, continuous], axis=1), *_labels]

    def inverse(
        self: SelfTransformSDVPandas, output: Output, /, *args: Any, **kwargs: Any
    ) -> Input:
        r"""
        Inverse output into input without inplacement.

        Args
        ----
        - output
            Output from the transformation.

        Returns
        -------
        - input
            Input to the transformation.

        Raises
        ------
        - ErrorTransformUnsupport
            Always, since the inversion is not implemented yet.
        """
        raise ErrorTransformUnsupport("Leave for future.")

    def fit(
        self: SelfTransformSDVPandas, input: Input, output: Output, /, *args: Any, **kwargs: Any
    ) -> SelfTransformSDVPandas:
        r"""
        Fit transformation parameters by example input and output.

        Args
        ----
        - input
            Example input to the transformation.
        - output
            Example output from the transformation.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Collect count encoding parameters first.
        # It is assumed that this fills `self.encodings` with one table per column, each having
        # a "count" column.
        TransformCountizePandas.fit(self, input, output, *args, **kwargs)

        # Generate CDF parameters from count parameters.
        # Only values of existing keys are rebound, which is safe while iterating the mapping.
        for name, encoding in self.encodings.items():
            # Compute CDFs from counts.
            # Counts are read from the "count" column explicitly, so that any other column of the
            # table can not leak into the distribution.
            encoding = encoding.sort_values("count")
            counts = encoding["count"]
            pdfs = counts.to_numpy(dtype=np.float64) / float(counts.sum())
            cdfs = np.zeros((len(pdfs) + 1,), dtype=pdfs.dtype)
            cdfs[1:] = np.cumsum(pdfs)

            # Generate encoding distributions from CDFs.
            # Each category owns the interval between two consecutive CDF values: the mean is the
            # interval center, and the interval spans three standard deviations on each side.
            lowers, uppers = cdfs[:-1], cdfs[1:]
            encoding["mean"] = (lowers + uppers) / 2.0
            encoding["std"] = (uppers - lowers) / 6.0
            self.encodings[name] = encoding
        return self

    def get_metadata(self: SelfTransformSDVPandas, /) -> Mapping[str, Any]:
        r"""
        Get metadata of the transformation.

        Args
        ----

        Returns
        -------
        - metadata
            Metadata of the transformation, which is always empty.
        """
        # Do nothing.
        return {}

    def get_numeric_data(self: SelfTransformSDVPandas, /) -> Mapping[str, NPANYS]:
        r"""
        Get numeric data of the transformation.

        Args
        ----

        Returns
        -------
        - data
            Numeric data of the transformation, which maps each encoded column name to values of
            its encoding table.
        """
        # Collect numeric data of encodings.
        # Values of an empty table are casted to float explicitly, presumably because an empty
        # table does not guarantee a numeric data type.
        return {
            name: encoding.values.astype(np.float64) if len(encoding) == 0 else encoding.values
            for name, encoding in self.encodings.items()
        }

    def get_alphabetic_data(self: SelfTransformSDVPandas, /) -> Mapping[str, Any]:
        r"""
        Get alphabetic data of the transformation.

        Args
        ----

        Returns
        -------
        - data
            Alphabetic data of the transformation: encoded column names ("columns"), unknown
            category default ("unk"), OOD category symbol ("ood") and category representations of
            each encoded column ("categories").
        """
        # Collect all encoding column and category representations.
        return {
            "columns": self.columns_categorical_consume,
            "unk": self.unk,
            "ood": self.ood,
            "categories": {name: list(encoding.index) for name, encoding in self.encodings.items()},
        }

    def set_metadata(
        self: SelfTransformSDVPandas, metadata: Mapping[str, Any], /  # noqa: W504
    ) -> SelfTransformSDVPandas:
        r"""
        Set metadata of the transformation.

        Args
        ----
        - metadata
            Metadata of the transformation, which is ignored.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Do nothing.
        return self

    def set_numeric_data(
        self: SelfTransformSDVPandas, data: Mapping[str, NPANYS], /  # noqa: W504
    ) -> SelfTransformSDVPandas:
        r"""
        Set numeric data of the transformation.

        Args
        ----
        - data
            Numeric data of the transformation, which maps each encoded column name to a matrix
            whose three columns are count, mean and standard deviation.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Create anonymous encoding container from given numeric data.
        # Generate encoding dataframe for each categorical column with anonymous row indices.
        # Naming information will be loaded and overwritten in later loading.
        self.encodings = {
            name: pd.DataFrame(values, columns=["count", "mean", "std"])
            for name, values in data.items()
        }
        return self

    def set_alphabetic_data(
        self: SelfTransformSDVPandas, data: Mapping[str, Any], /  # noqa: W504
    ) -> SelfTransformSDVPandas:
        r"""
        Set alphabetic data of the transformation.

        Numeric data must be loaded beforehand, since category representations are attached to
        existing encoding tables as row indices.

        Args
        ----
        - data
            Alphabetic data of the transformation, which holds "columns", "unk", "ood" and
            "categories" items.

        Returns
        -------
        - self
            Class instance itself.

        Raises
        ------
        - AssertionError
            An item is missing, or category representations do not match loaded numeric data.
        """
        # Safety check.
        assert "columns" in data, "Categorical column namings are missing."
        assert "unk" in data, "Default value for unknown category is missing."
        assert "ood" in data, "OOD category symbol (including null) is missing."
        assert "categories" in data, "Category namings are missing."
        categories = data["categories"]
        for name, encoding in self.encodings.items():
            # Safety check for each numeric data block.
            assert name in categories, f'Category namings of column "{name:s}" is missing.'
            assert len(categories[name]) == len(
                encoding
            ), f'Sizes of category namings and numeric data of column "{name:s}" do not match.'

        # Load naming information of columns and categories.
        # Nothing is written before all checks pass.
        self.columns_categorical_consume = data["columns"]
        self.unk = data["unk"]
        self.ood = data["ood"]
        for name, encoding in self.encodings.items():
            # Overwrite row index namings.
            encoding.index = pd.Index(categories[name])
        return self
