# Import Python packages.
from typing import Any, List, Mapping, Optional, Sequence, TypeVar

# Import external packages.
import numpy as np
import pandas as pd
from sklearn.cluster import FeatureAgglomeration  # type: ignore[import-untyped]

# Import relatively from other modules.
from ....types import NPANYS
from ...base import BaseTransformPandas, ErrorTransformUnsupportPartial


# Type aliases: the transformation consumes a list of dataframes and produces a list of
# dataframes.
Input = Output = List[pd.DataFrame]


# Self type, so that fluent methods returning the instance keep the concrete subclass type.
SelfTransformFeatAggloPandas = TypeVar(
    "SelfTransformFeatAggloPandas",
    bound="TransformFeatAggloPandas",
)


class TransformFeatAggloPandas(BaseTransformPandas):
    r"""
    Transformation for Feature Agglomeration on Pandas data.

    The transformation wraps a Scikit Learn `FeatureAgglomeration` handler and maps a single
    continuous dataframe into a dataframe with one column per feature cluster.
    Only mean pooling is supported.
    """
    # Transformation unique identifier.
    _IDENTIFIER = "cluster.featagglo.pandas"

    def input(self: SelfTransformFeatAggloPandas, raw: Any, /) -> Input:
        r"""
        Convert raw data into input to the transformation.

        Args
        ----
        - raw
            Raw data.
            Only `None` is supported, which produces a default example input.

        Returns
        -------
        - process
            Processed data compatible with the transformation.

        Raises
        ------
        - ErrorTransformUnsupportPartial
            If raw data is anything other than `None`.
        """
        # Conversion will vary according to raw data.
        if raw is None:
            # Default example: two distinct samples with two continuous features, so that two
            # feature clusters can be formed.
            return [pd.DataFrame([[0.0, 0.0], [1.0, 1.0]], columns=["continuous1", "continuous2"])]
        else:
            # All the other cases are not supported.
            raise ErrorTransformUnsupportPartial(
                f"Try to formalize incompatible raw data into input domain of"
                f' "{self._IDENTIFIER:s}".'
            )

    def output(self: SelfTransformFeatAggloPandas, raw: Any, /) -> Output:
        r"""
        Convert raw data into output from the transformation.

        Args
        ----
        - raw
            Raw data.
            Only `None` is supported, which produces the output matching the default input.

        Returns
        -------
        - process
            Processed data compatible with the transformation.

        Raises
        ------
        - ErrorTransformUnsupportPartial
            If raw data is anything other than `None`.
        """
        # Conversion will vary according to raw data.
        if raw is None:
            # Expected output is Feature Agglomeration with two clusters applied to the default
            # input, with column titles matching the default titles produced by `transform`.
            return [
                pd.DataFrame(
                    FeatureAgglomeration(n_clusters=2).fit_transform(
                        np.array([[0.0, 0.0], [1.0, 1.0]])
                    ),
                    columns=["0", "1"],
                )
            ]
        else:
            # All the other cases are not supported.
            raise ErrorTransformUnsupportPartial(
                f"Try to formalize incompatible raw data into output domain of"
                f' "{self._IDENTIFIER:s}".'
            )

    def transform(
        self: SelfTransformFeatAggloPandas,
        input: Input,
        /,
        *args: Any,
        columns: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> Output:
        r"""
        Transform input into output without inplacement.

        Args
        ----
        - input
            Input to the transformation.
            It must contain exactly one dataframe.
        - columns
            Dataframe column titles after projection.
            If it is not given (or empty), it is the string representations of column integer
            indices.

        Returns
        -------
        - output
            Output from the transformation.
        """
        # Collect continuous dataframe to be projected.
        (dataframe,) = input

        # Utilize Scikit Learn handler directly.
        columns_ = list(columns) if columns else [str(i) for i in range(self.n_clusters)]
        return [pd.DataFrame(self.handler.transform(dataframe.values), columns=columns_)]

    def transform_(
        self: SelfTransformFeatAggloPandas, input: Input, /, *args: Any, **kwargs: Any
    ) -> SelfTransformFeatAggloPandas:
        r"""
        Transform input with inplacement.

        Args
        ----
        - input
            Input to the transformation.
            Its only dataframe is replaced by the transformed dataframe.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Get the output and replace input of corresponding positions by the output.
        output = self.transform(input, *args, **kwargs)
        (dataframe,) = output
        input[0] = dataframe
        return self

    def fit(
        self: SelfTransformFeatAggloPandas,
        input: Input,
        output: Output,
        /,
        *args: Any,
        scikit_learn_init_args: Sequence[Any] = [],
        scikit_learn_init_kwargs: Mapping[str, Any] = {},
        scikit_learn_fit_args: Sequence[Any] = [],
        scikit_learn_fit_kwargs: Mapping[str, Any] = {},
        **kwargs: Any,
    ) -> SelfTransformFeatAggloPandas:
        r"""
        Fit transformation parameters by example input and output.

        Args
        ----
        - input
            Example input to the transformation.
        - output
            Example output from the transformation.
            It is not used.
        - scikit_learn_init_args
            Positional arguments to Feature Agglomeration handler initialization of Scikit Learn
            package.
        - scikit_learn_init_kwargs
            Keyword arguments to Feature Agglomeration handler initialization of Scikit Learn
            package.
            It must explicitly contain "n_clusters".
            If "pooling_func" is given, it must be "mean", `np.mean` or None.
        - scikit_learn_fit_args
            Positional arguments to Feature Agglomeration handler parameter fitting of Scikit Learn
            package.
        - scikit_learn_fit_kwargs
            Keyword arguments to Feature Agglomeration handler parameter fitting of Scikit Learn
            package.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Get the dataframe to be handled.
        (dataframe,) = input

        # Make a clone of essential arguments, and perform auto correction and filling.
        # Cloning also guarantees that the (shared) default argument objects are never modified.
        scikit_learn_init_args_ = [*scikit_learn_init_args]
        scikit_learn_init_kwargs_ = {**scikit_learn_init_kwargs}
        scikit_learn_fit_args_ = [*scikit_learn_fit_args]
        scikit_learn_fit_kwargs_ = {**scikit_learn_fit_kwargs}
        assert (
            "n_clusters" in scikit_learn_init_kwargs_
        ), "Number of targeting clusters must be explicitly defined."

        # Pooling function argument should be specially saved by its named index, since a callable
        # object can not be serialized.
        # Only mean pooling is supported: it may be given by name, as `np.mean` itself, or be
        # omitted (or None) to use the default.
        pool_func = scikit_learn_init_kwargs_.pop("pooling_func", None)
        pool_func_name = "mean" if pool_func is None or pool_func is np.mean else str(pool_func)
        assert (
            pool_func_name == "mean"
        ), f'Only "mean" pooling function is supported, but got "{pool_func_name:s}".'
        self._pool_func = pool_func_name
        scikit_learn_init_kwargs_["pooling_func"] = np.mean

        # Utilize Scikit Learn handler directly.
        self.scikit_learn_init_args = scikit_learn_init_args_
        self.scikit_learn_init_kwargs = scikit_learn_init_kwargs_
        self.handler = FeatureAgglomeration(
            *self.scikit_learn_init_args, **self.scikit_learn_init_kwargs
        )
        self.handler.fit(dataframe.values, *scikit_learn_fit_args_, **scikit_learn_fit_kwargs_)
        self.n_clusters = self.handler.n_clusters_
        return self

    def get_metadata(self: SelfTransformFeatAggloPandas, /) -> Mapping[str, Any]:
        r"""
        Get metadata of the transformation.

        Args
        ----

        Returns
        -------
        - metadata
            Metadata of the transformation.
            It is always empty.
        """
        # Do nothing.
        return {}

    def get_numeric_data(self: SelfTransformFeatAggloPandas, /) -> Mapping[str, NPANYS]:
        r"""
        Get numeric data of the transformation.

        Args
        ----

        Returns
        -------
        - data
            Numeric data of the transformation.
            It contains the fitted feature cluster labels under key "labels".
        """
        # Export fitted feature cluster labels of the handler.
        return {"labels": self.handler.labels_}

    def get_alphabetic_data(self: SelfTransformFeatAggloPandas, /) -> Mapping[str, Any]:
        r"""
        Get alphabetic data of the transformation.

        Args
        ----

        Returns
        -------
        - data
            Alphabetic data of the transformation.
        """
        # Collect Scikit Learn parameters.
        # Pay attention that pooling function is a callable object that can not be serialized, thus
        # we have to save its named index instead.
        scikit_learn_init_kwargs = {**self.scikit_learn_init_kwargs}
        scikit_learn_init_kwargs.pop("pooling_func", None)
        scikit_learn_params = {**self.handler.get_params()}
        scikit_learn_params.pop("pooling_func", None)
        return {
            "scikit_learn_init_args": self.scikit_learn_init_args,
            "scikit_learn_init_kwargs": scikit_learn_init_kwargs,
            "scikit_learn_params": scikit_learn_params,
            "pool_func": self._pool_func,
            "n_clusters": self.n_clusters,
        }

    def set_metadata(
        self: SelfTransformFeatAggloPandas, metadata: Mapping[str, Any], /  # noqa: W504
    ) -> SelfTransformFeatAggloPandas:
        r"""
        Set metadata of the transformation.

        Args
        ----
        - metadata
            Metadata of the transformation.
            It is ignored.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Do nothing.
        return self

    def set_numeric_data(
        self: SelfTransformFeatAggloPandas, data: Mapping[str, NPANYS], /  # noqa: W504
    ) -> SelfTransformFeatAggloPandas:
        r"""
        Set numeric data of the transformation.

        The data is only cached here; it is attached to the handler by `set_alphabetic_data`,
        which therefore must be called afterwards.

        Args
        ----
        - data
            Numeric data of the transformation.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Safety check.
        assert "labels" in data, "Numeric array of feature cluster labels is missing."

        # Cache loaded numeric data.
        self._content = data
        return self

    def set_alphabetic_data(
        self: SelfTransformFeatAggloPandas, data: Mapping[str, Any], /  # noqa: W504
    ) -> SelfTransformFeatAggloPandas:
        r"""
        Set alphabetic data of the transformation.

        It rebuilds the Scikit Learn handler and restores every attribute that `fit` would set, so
        that the loaded transformation can be used and saved again.
        Numeric data must have been set before.

        Args
        ----
        - data
            Alphabetic data of the transformation.
            It is not modified.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Safety check.
        assert (
            "scikit_learn_init_args" in data
        ), "Positional arguments of Scikit Learn handler initialization are missing."
        assert (
            "scikit_learn_init_kwargs" in data
        ), "Keyword arguments of Scikit Learn handler initialization are missing."
        assert "scikit_learn_params" in data, "Parameters of Scikit Learn handler are missing."
        assert (
            "pool_func" in data
        ), "Named index of pooling function of Scikit Learn handler is missing."
        assert (
            "n_clusters" in data
        ), "Number of targeting clusters of Scikit Learn handler is missing."
        assert hasattr(
            self, "_content"
        ), "Numeric data must be set before alphabetic data of Scikit Learn handler."

        # Parse loaded data.
        # Work on copies so that the caller's mappings are never mutated.
        scikit_learn_init_args = [*data["scikit_learn_init_args"]]
        scikit_learn_init_kwargs = {**data["scikit_learn_init_kwargs"]}
        scikit_learn_params = {**data["scikit_learn_params"]}

        # Handle pooling function specially: it is saved by named index, thus restore the callable.
        assert data["pool_func"] == "mean"
        scikit_learn_init_kwargs["pooling_func"] = np.mean
        scikit_learn_params["pooling_func"] = np.mean

        # Restore the same attributes as fitting does, so that alphabetic data can be got again.
        self.scikit_learn_init_args = scikit_learn_init_args
        self.scikit_learn_init_kwargs = scikit_learn_init_kwargs
        self._pool_func = data["pool_func"]

        # Create a Scikit Learn handler of loaded numeric data.
        self.handler = FeatureAgglomeration(
            *self.scikit_learn_init_args, **self.scikit_learn_init_kwargs
        )
        self.handler.set_params(**scikit_learn_params)
        self.n_clusters = data["n_clusters"]

        # Update arrays of the handler by loaded numeric data.
        self.handler.labels_ = self._content["labels"]
        return self
