# Import Python packages.
from typing import Any, List, Mapping, Optional, Sequence, Tuple, TypeVar

# Import external packages.
import numpy as np
import pandas as pd
from category_encoders import CatBoostEncoder  # type: ignore[import-untyped]

# Import relatively from other modules.
from ....types import NPANYS
from ...base import BaseTransformPandas, ErrorTransformUnsupportPartial
from ..base import TransformCategoryEncode


# Type aliases.
# Both input and output are lists of data frames. The methods of the transformation in this module
# use the order: categorical features, continuous features, categorical labels, continuous labels.
Input = List[pd.DataFrame]
Output = List[pd.DataFrame]


# Self types.
# Bound type variable so that methods returning `self` are typed as the concrete (sub)class.
SelfTransformCatBoostEncodePandas = TypeVar(
    "SelfTransformCatBoostEncodePandas", bound="TransformCatBoostEncodePandas"
)


class TransformCatBoostEncodePandas(BaseTransformPandas, TransformCategoryEncode[pd.DataFrame]):
    r"""
    Transformation for CatBoost encoding on Pandas data.

    Data passing through the transformation is a list of data frames ordered as categorical
    features, continuous features, categorical labels and continuous labels.
    Encoded categorical columns are renamed with a "-catboost" suffix and are placed in front of
    the continuous feature columns.
    """
    # Transformation unique identifier.
    _IDENTIFIER = "cateenc.catboost.pandas"

    def input(self: SelfTransformCatBoostEncodePandas, raw: Any, /) -> Input:
        r"""
        Convert raw data into input to the transformation.

        Args
        ----
        - raw
            Raw data.
            Only null (`None`) raw data is supported.

        Returns
        -------
        - process
            Processed data compatible with the transformation.

        Raises
        ------
        - ErrorTransformUnsupportPartial
            Raw data is anything other than null.
        """
        # Conversion will vary according to raw data.
        if raw is None:
            # Create empty feature containers (categorical and continuous) and empty label
            # containers (categorical and continuous).
            return [
                pd.DataFrame([], columns=["feature-categorical"]),
                pd.DataFrame([], columns=["feature-continuous"]),
                pd.DataFrame([], columns=["label-categorical"]),
                pd.DataFrame([], columns=["label-continuous"]),
            ]
        else:
            # All the other cases are not supported.
            raise ErrorTransformUnsupportPartial(
                f"Try to formalize incompatible raw data into input domain of"
                f' "{self._IDENTIFIER:s}".'
            )

    def output(self: SelfTransformCatBoostEncodePandas, raw: Any, /) -> Output:
        r"""
        Convert raw data into output from the transformation.

        Args
        ----
        - raw
            Raw data.
            Only null (`None`) raw data is supported.

        Returns
        -------
        - process
            Processed data compatible with the transformation.

        Raises
        ------
        - ErrorTransformUnsupportPartial
            Raw data is anything other than null.
        """
        # Conversion will vary according to raw data.
        if raw is None:
            # Create empty containers corresponding to null input.
            # Pay attention that we assume all categorical columns are encoded, thus no
            # categorical feature column remains.
            return [
                pd.DataFrame([], columns=[]),
                pd.DataFrame([], columns=["feature-categorical-catboost", "feature-continuous"]),
                pd.DataFrame([], columns=["label-categorical"]),
                pd.DataFrame([], columns=["label-continuous"]),
            ]
        else:
            # All the other cases are not supported.
            raise ErrorTransformUnsupportPartial(
                f"Try to formalize incompatible raw data into output domain of"
                f' "{self._IDENTIFIER:s}".'
            )

    def transform(
        self: SelfTransformCatBoostEncodePandas, input: Input, /, *args: Any, **kwargs: Any
    ) -> Output:
        r"""
        Transform input into output without inplacement.

        Args
        ----
        - input
            Input to the transformation.
            The first two elements are categorical and continuous features, and all the other
            elements are passed through untouched.

        Returns
        -------
        - output
            Output from the transformation.
            Encoded columns are removed from categorical features and are prepended to continuous
            features with a "-catboost" suffix.

        Raises
        ------
        - ErrorTransformUnsupportPartial
            Some columns seen in fitting are missing from given categorical features.
        """
        # Every column consumed by the encoder must be provided.
        categorical, continuous, *_labels = input
        missing = sorted(set(self.columns_categorical_consume) - set(categorical.columns))
        if missing:
            # Raise error when application is different from fitting.
            # Only a few missing columns are reported to keep the message short.
            missing_ = [*missing[:3], "..."] if len(missing) > 3 else missing
            raise ErrorTransformUnsupportPartial(
                "Encoder wants to encode {:s} columns which are missing.".format(
                    ", ".join(f'"{name:s}"' for name in missing_)
                )
            )

        # Select and encode categorical data in the column order recorded for the encoder.
        # All of those columns are known to exist after the check above.
        consuming = list(self.columns_categorical_consume)
        encoded = self.handler.transform(categorical[consuming])
        columns = [f"{name:s}-catboost" for name in consuming]
        encoded = pd.DataFrame(encoded.values, columns=columns, index=encoded.index)

        # Take remaining categorical data (in sorted column order), and merge encoded categorical
        # data into continuous data.
        categorical = categorical[sorted(set(categorical.columns) - set(consuming))]
        continuous = pd.concat([encoded, continuous], axis=1)
        return [categorical, continuous, *_labels]

    def fit(
        self: SelfTransformCatBoostEncodePandas,
        input: Input,
        output: Output,
        /,
        *args: Any,
        category_encoders_init_args: Sequence[Any] = [],
        category_encoders_init_kwargs: Mapping[str, Any] = {},
        category_encoders_fit_args: Sequence[Any] = [],
        category_encoders_fit_kwargs: Mapping[str, Any] = {},
        target: Optional[Tuple[str, str]] = None,
        **kwargs: Any,
    ) -> SelfTransformCatBoostEncodePandas:
        r"""
        Fit transformation parameters by example input and output.

        Args
        ----
        - input
            Example input to the transformation.
            It must hold exactly four data frames: categorical features, continuous features,
            categorical labels and continuous labels.
        - output
            Example output from the transformation.
            It is not used by the fitting.
        - category_encoders_init_args
            Positional arguments to CatBoost encoder initialization of Category Encoder package.
        - category_encoders_init_kwargs
            Keyword arguments to CatBoost encoder initialization of Category Encoder package.
            If "random_state" is absent or null, it is filled by 42.
        - category_encoders_fit_args
            Positional arguments to CatBoost encoder parameter fitting of Category Encoder
            package.
        - category_encoders_fit_kwargs
            Keyword arguments to CatBoost encoder parameter fitting of Category Encoder package.
        - target
            Target column in given input.
            The first element should be "categorical" or "continuous" which selects the label
            data frame, and the second element should be a valid column name of that data frame.

        Returns
        -------
        - self
            Class instance itself.

        Raises
        ------
        - ErrorTransformUnsupportPartial
            Target is not defined.
        """
        # Get columns to be handled.
        (categorical, _, labels_categorical, labels_continuous) = input
        labels = {"categorical": labels_categorical, "continuous": labels_continuous}

        # Check target availability.
        if target is not None:
            # Parse and validate target definition.
            target_type, target_name = target
            assert (
                target_name in labels[target_type]
            ), f'Target column "{target_name:s}" is not in {target_type:s} data.'
        else:
            # Target must be explicitly defined.
            raise ErrorTransformUnsupportPartial(
                "Target must be explicitly defined for CatBoost encoder."
            )

        # Make a clone of essential arguments, and perform auto correction and filling.
        # Cloning also guarantees that shared default arguments are never mutated.
        category_encoders_init_args_ = [*category_encoders_init_args]
        category_encoders_init_kwargs_ = {**category_encoders_init_kwargs}
        category_encoders_fit_args_ = [*category_encoders_fit_args]
        category_encoders_fit_kwargs_ = {**category_encoders_fit_kwargs}
        if (
            "random_state" not in category_encoders_init_kwargs_
            or category_encoders_init_kwargs_["random_state"] is None
        ):
            # Autofill random seed.
            category_encoders_init_kwargs_["random_state"] = 42

        # Utilize Category Encoder handler directly.
        self.category_encoders_init_args = category_encoders_init_args_
        self.category_encoders_init_kwargs = category_encoders_init_kwargs_
        self.handler = CatBoostEncoder(
            *self.category_encoders_init_args, **self.category_encoders_init_kwargs
        )
        self.handler.fit(
            categorical,
            labels[target_type][target_name],
            *category_encoders_fit_args_,
            **category_encoders_fit_kwargs_,
        )

        # Keep fitted information which is required by application and exportation.
        # NOTE(review): `_mean` is a private attribute of the handler; confirm it against the
        # installed Category Encoder version.
        self.columns_categorical_consume = self.handler.get_feature_names_in()
        self.mean_from_fitting = self.handler._mean
        return self

    def get_metadata(self: SelfTransformCatBoostEncodePandas, /) -> Mapping[str, Any]:
        r"""
        Get metadata of the transformation.

        Args
        ----

        Returns
        -------
        - metadata
            Metadata of the transformation.
            It is always empty.
        """
        # Do nothing.
        return {}

    def get_numeric_data(self: SelfTransformCatBoostEncodePandas, /) -> Mapping[str, NPANYS]:
        r"""
        Get numeric data of the transformation.

        Args
        ----

        Returns
        -------
        - data
            Numeric data of the transformation.
            Each key is a name in handler mapping, and each value is the array of its encoding.
        """
        # Collect numeric data of encodings.
        # An empty encoding is explicitly cast to 64-bit floating numbers.
        return {
            name: encoding.values.astype(np.float64) if len(encoding) == 0 else encoding.values
            for name, encoding in self.handler.mapping.items()
        }

    def get_alphabetic_data(self: SelfTransformCatBoostEncodePandas, /) -> Mapping[str, Any]:
        r"""
        Get alphabetic data of the transformation.

        Args
        ----

        Returns
        -------
        - data
            Alphabetic data of the transformation.
            It holds consumed column namings, handler initialization arguments, handler
            parameters, the mean from fitting and category namings (encoding index) of each
            name in handler mapping.
        """
        # Collect Category Encoder parameters.
        return {
            "columns": self.columns_categorical_consume,
            "category_encoders_init_args": self.category_encoders_init_args,
            "category_encoders_init_kwargs": self.category_encoders_init_kwargs,
            "category_encoders_params": self.handler.get_params(),
            "mean_from_fitting": self.mean_from_fitting,
            "categories": {
                name: list(encoding.index) for name, encoding in self.handler.mapping.items()
            },
        }

    def set_metadata(
        self: SelfTransformCatBoostEncodePandas, metadata: Mapping[str, Any], /  # noqa: W504
    ) -> SelfTransformCatBoostEncodePandas:
        r"""
        Set metadata of the transformation.

        Args
        ----
        - metadata
            Metadata of the transformation.
            It is ignored.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Do nothing.
        return self

    def set_numeric_data(
        self: SelfTransformCatBoostEncodePandas, data: Mapping[str, NPANYS], /  # noqa: W504
    ) -> SelfTransformCatBoostEncodePandas:
        r"""
        Set numeric data of the transformation.

        The data is only cached here, and it is consumed (then dropped) by the following
        alphabetic data setting.

        Args
        ----
        - data
            Numeric data of the transformation.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Cache loaded numeric data.
        self._content = data
        return self

    def set_alphabetic_data(
        self: SelfTransformCatBoostEncodePandas, data: Mapping[str, Any], /  # noqa: W504
    ) -> SelfTransformCatBoostEncodePandas:
        r"""
        Set alphabetic data of the transformation.

        Numeric data must have been set before this call, since the handler is rebuilt from both
        cached numeric data and given alphabetic data.
        The cached numeric data is dropped afterward.

        Args
        ----
        - data
            Alphabetic data of the transformation.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Safety check.
        assert "columns" in data, "Categorical column namings are missing."
        assert (
            "category_encoders_init_args" in data
        ), "Positional arguments of Category Encoder handler initialization are missing."
        assert (
            "category_encoders_init_kwargs" in data
        ), "Keyword arguments of Category Encoder handler initialization are missing."
        assert (
            "category_encoders_params" in data
        ), "Parameters of Category Encoder handler are missing."
        assert "mean_from_fitting" in data, "Global mean from parameter fitting data is missing."
        assert "categories" in data, "Category namings are missing."
        for name, encoding in self._content.items():
            # Safety check for each numeric data block.
            assert name in data["categories"], f'Category namings of column "{name:s}" is missing.'
            assert len(data["categories"][name]) == len(
                encoding
            ), f'Sizes of category namings and numeric data of column "{name:s}" do not match.'

        # Restore initialization arguments as parameter fitting does, so that alphabetic data can
        # be collected again from a loaded transformation.
        self.category_encoders_init_args = [*data["category_encoders_init_args"]]
        self.category_encoders_init_kwargs = {**data["category_encoders_init_kwargs"]}

        # Create a Category Encoder handler of loaded data.
        self.handler = CatBoostEncoder(
            *self.category_encoders_init_args, **self.category_encoders_init_kwargs
        )
        self.handler.set_params(**data["category_encoders_params"])

        # Update arrays of the handler by loaded numeric data.
        # NOTE(review): `_dim`, `cols`, `mapping` and `_mean` are assumed to be all the fitted
        # states required by the handler; confirm it against the installed Category Encoder
        # version.
        self.columns_categorical_consume = data["columns"]
        self.mean_from_fitting = data["mean_from_fitting"]
        self.handler._dim = len(self.columns_categorical_consume)
        self.handler.cols = self.columns_categorical_consume
        self.handler.mapping = {
            name: pd.DataFrame(encoding, columns=["sum", "count"], index=data["categories"][name])
            for name, encoding in self._content.items()
        }
        self.handler._mean = self.mean_from_fitting

        # Cached numeric data is consumed.
        del self._content
        return self