# Import Python packages.
from enum import IntEnum
from typing import Any, Iterable, List, Mapping, Optional, Sequence, TypeVar

# Import external packages.
import pandas as pd

# Import relatively from other modules.
from ...types import NPANYS
from ..base import BaseTransformPandas, ErrorTransformUnsupportPartial


# Value types of tabular columns, numbered from 1 in declaration order.
class _Value(IntEnum):
    CATEGORICAL = 1
    CONTINUOUS = 2


# Transformation input is a single generic dataframe; output holds the categorical and continuous
# dataframes.
Input = List[pd.DataFrame]
Output = List[pd.DataFrame]


# Type variable for methods returning the instance itself.
SelfTransformTabularize = TypeVar("SelfTransformTabularize", bound="TransformTabularize")


class TransformTabularize(BaseTransformPandas):
    r"""
    Transformation splitting a generic Pandas dataframe into categorical and continuous dataframes.

    Column value types are decided in `fit`: columns enforced as discretizable, fully-null columns
    and non-numeric columns are categorical; the remaining numeric columns are continuous.
    `transform` maps `[generic]` to `[categorical, continuous]`, and `inverse` merges them back
    following the fitted column order.
    """
    # Transformation unique identifier.
    _IDENTIFIER = "tabularize"

    # Value types.
    # NOTE(review): these are 0-based, while `_Value.CATEGORICAL` / `_Value.CONTINUOUS` (the values
    # stored in `columns` and returned by `get_value_type`) are 1 and 2, so they must not be
    # compared against each other. They may be used elsewhere as positions in the output list
    # `[categorical, continuous]` -- confirm before changing them.
    CATEGORICAL = 0
    CONTINUOUS = 1

    def input(self: SelfTransformTabularize, raw: Any, /) -> Input:
        r"""
        Convert raw data into input to the transformation.

        Args
        ----
        - raw
            Raw data. Only `None` is supported.

        Returns
        -------
        - process
            Processed data compatible with the transformation: a list holding one empty dataframe
            with a single column named "generic".

        Raises
        ------
        - ErrorTransformUnsupportPartial
            If raw data is not `None`.
        """
        # All non-null raw data is not supported.
        if raw is not None:
            raise ErrorTransformUnsupportPartial(
                f"Try to formalize incompatible raw data into input domain of"
                f' "{self._IDENTIFIER:s}".'
            )

        # Create a single empty generic container.
        return [pd.DataFrame([], columns=["generic"])]

    def output(self: SelfTransformTabularize, raw: Any, /) -> Output:
        r"""
        Convert raw data into output from the transformation.

        Args
        ----
        - raw
            Raw data. Only `None` is supported.

        Returns
        -------
        - process
            Processed data compatible with the transformation: a list of two empty dataframes, the
            first with a single column named "generic" and the second without columns.

        Raises
        ------
        - ErrorTransformUnsupportPartial
            If raw data is not `None`.
        """
        # All non-null raw data is not supported.
        if raw is not None:
            raise ErrorTransformUnsupportPartial(
                f"Try to formalize incompatible raw data into output domain of"
                f' "{self._IDENTIFIER:s}".'
            )

        # Create two empty containers at the categorical and continuous positions of the output.
        # NOTE(review): the first container carries a "generic" column; confirm this placeholder
        # layout is what downstream consumers expect.
        return [pd.DataFrame([], columns=["generic"]), pd.DataFrame([], columns=[])]

    def _forbid_columns(
        self: SelfTransformTabularize, names: Iterable[Any], template: str, /  # noqa: W504
    ) -> None:
        r"""
        Raise if any offending (redundant or missing) column names are given.

        Args
        ----
        - names
            Offending column names.
        - template
            Error message template formatted with the transformation identifier and the rendered
            column names.

        Raises
        ------
        - ErrorTransformUnsupportPartial
            If `names` is not empty.
        """
        # Sort by string form so that non-string or mixed-type column labels can still be sorted.
        ordered = sorted(names, key=str)
        if not ordered:
            return

        # Only show the first three names to keep the message short.
        shown = [*ordered[:3], "..."] if len(ordered) > 3 else ordered
        raise ErrorTransformUnsupportPartial(
            template.format(self._IDENTIFIER, ", ".join([f'"{name!s}"' for name in shown]))
        )

    def transform(
        self: SelfTransformTabularize, input: Input, /, *args: Any, **kwargs: Any
    ) -> Output:
        r"""
        Transform input into output without inplacement.

        Args
        ----
        - input
            Input to the transformation: a list holding exactly one generic dataframe.

        Returns
        -------
        - output
            Output from the transformation: `[categorical, continuous]` dataframes holding the
            fitted categorical and continuous columns in fitted order.

        Raises
        ------
        - ErrorTransformUnsupportPartial
            If the generic dataframe has columns that were not fitted, or misses fitted columns.
        """
        # Collect generic dataframe.
        (generic,) = input
        present = set(generic.columns)
        expected_categorical = set(self._columns_categorical)
        expected_continuous = set(self._columns_continuous)

        # Forbid automatically handle redundant columns for safety.
        self._forbid_columns(
            present - expected_categorical - expected_continuous,
            '"{:s}" transformation gets redundant columns {:s}.',
        )

        # Forbid automatically ignore missing columns for safety.
        self._forbid_columns(
            expected_categorical - present,
            '"{:s}" transformation is missing categorical columns {:s}.',
        )
        self._forbid_columns(
            expected_continuous - present,
            '"{:s}" transformation is missing continuous columns {:s}.',
        )

        # Separate generic dataframe into categorical and continuous dataframes.
        categorical = generic[self._columns_categorical]
        continuous = generic[self._columns_continuous]
        return [categorical, continuous]

    def transform_(
        self: SelfTransformTabularize, input: Input, /, *args: Any, **kwargs: Any
    ) -> SelfTransformTabularize:
        r"""
        Transform input with inplacement.

        Args
        ----
        - input
            Input to the transformation. Its single generic dataframe is replaced in place by the
            categorical and continuous dataframes.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Replace the generic dataframe at position 0 by the categorical and continuous dataframes.
        categorical, continuous = self.transform(input, *args, **kwargs)
        input[0:1] = [categorical, continuous]
        return self

    def inverse(
        self: SelfTransformTabularize, output: Output, /, *args: Any, **kwargs: Any
    ) -> Input:
        r"""
        Inverse output into input without inplacement.

        Args
        ----
        - output
            Output from the transformation: `[categorical, continuous]` dataframes.

        Returns
        -------
        - input
            Input to the transformation: a list holding one generic dataframe whose columns follow
            the fitted column order.

        Raises
        ------
        - ErrorTransformUnsupportPartial
            If either dataframe has columns that were not fitted for it, or misses fitted columns.
        """
        # Collect categorical and continuous dataframes.
        categorical, continuous = output
        present_categorical = set(categorical.columns)
        present_continuous = set(continuous.columns)
        expected_categorical = set(self._columns_categorical)
        expected_continuous = set(self._columns_continuous)

        # Forbid automatically handle redundant or ignore missing categorical columns for safety.
        self._forbid_columns(
            present_categorical - expected_categorical,
            '"{:s}" transformation gets redundant categorical columns {:s}.',
        )
        self._forbid_columns(
            expected_categorical - present_categorical,
            '"{:s}" transformation is missing categorical columns {:s}.',
        )

        # Forbid automatically handle redundant or ignore missing continuous columns for safety.
        self._forbid_columns(
            present_continuous - expected_continuous,
            '"{:s}" transformation gets redundant continuous columns {:s}.',
        )
        self._forbid_columns(
            expected_continuous - present_continuous,
            '"{:s}" transformation is missing continuous columns {:s}.',
        )

        # Merge columns together following original order.
        # `pd.concat(..., axis=1)` aligns rows on index labels (outer join by default).
        # NOTE(review): this presumably relies on both dataframes sharing the same index.
        names = [name for name, _ in self.columns]
        return [pd.concat([categorical, continuous], axis=1)[names]]

    def inverse_(
        self: SelfTransformTabularize, output: Output, /, *args: Any, **kwargs: Any
    ) -> SelfTransformTabularize:
        r"""
        Inverse output with inplacement.

        Args
        ----
        - output
            Output from the transformation. Its categorical and continuous dataframes are replaced
            in place by the single generic dataframe.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Replace the dataframes at positions 0 and 1 by the merged generic dataframe.
        (generic,) = self.inverse(output, *args, **kwargs)
        output[0:2] = [generic]
        return self

    def _update_column_groups(self: SelfTransformTabularize, /) -> None:
        r"""
        Regenerate categorical and continuous column name lists from `columns`.

        Args
        ----

        Returns
        -------
        """
        self._columns_categorical = [
            name for name, vtype in self.columns if vtype == _Value.CATEGORICAL
        ]
        self._columns_continuous = [
            name for name, vtype in self.columns if vtype == _Value.CONTINUOUS
        ]

    def fit(
        self: SelfTransformTabularize,
        input: Input,
        output: Output,
        /,
        *args: Any,
        discretizable: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> SelfTransformTabularize:
        r"""
        Fit transformation parameters by example input and output.

        Args
        ----
        - input
            Example input to the transformation: a list holding exactly one generic dataframe.
        - output
            Example output from the transformation. It is not used.
        - discretizable
            Columns that will be enforced as categorical regardless of value types.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Get columns to be handled.
        (generic,) = input

        # Save essential parameters directly from input.
        # This must happen before `get_value_type` is called below, since it reads them.
        self._discretizable = [] if discretizable is None else discretizable

        # Collect column value types; column names are stored as strings.
        self.columns = [
            (str(name), self.get_value_type(str(name), series)) for name, series in generic.items()
        ]
        self._update_column_groups()
        return self

    def get_value_type(
        self: SelfTransformTabularize, name: str, series: "pd.Series[Any]", /  # noqa: W504
    ) -> _Value:
        r"""
        Get value type from column information.

        Args
        ----
        - name
            Column title.
        - series
            Column data.

        Returns
        -------
        - vtype
            `_Value.CATEGORICAL` if the column is enforced as discretizable, fully null, or has a
            non-numeric dtype once nulls are dropped; `_Value.CONTINUOUS` otherwise.
        """
        # Column is enforced to be categorical.
        if name in self._discretizable:
            return _Value.CATEGORICAL

        # Only valued cells matter in value type analysis.
        valued = series.dropna()
        if len(valued) == 0:
            # A fully null column is treated as categorical.
            return _Value.CATEGORICAL

        # NOTE: `pd.api.types.is_numeric_dtype` also reports boolean dtypes as numeric, so boolean
        # columns are continuous unless enforced as categorical.
        if pd.api.types.is_numeric_dtype(valued):
            return _Value.CONTINUOUS
        return _Value.CATEGORICAL

    def get_metadata(self: SelfTransformTabularize, /) -> Mapping[str, Any]:
        r"""
        Get metadata of the transformation.

        Args
        ----

        Returns
        -------
        - metadata
            Metadata of the transformation. It is always empty.
        """
        # Do nothing.
        return {}

    def get_numeric_data(self: SelfTransformTabularize, /) -> Mapping[str, NPANYS]:
        r"""
        Get numeric data of the transformation.

        Args
        ----

        Returns
        -------
        - data
            Numeric data of the transformation. It is always empty.
        """
        # Do nothing.
        return {}

    def get_alphabetic_data(self: SelfTransformTabularize, /) -> Mapping[str, Any]:
        r"""
        Get alphabetic data of the transformation.

        Args
        ----

        Returns
        -------
        - data
            Alphabetic data of the transformation: `{"columns": columns}` with the fitted
            `(name, value type)` pairs.
        """
        # Collect all columns and their value types.
        return {"columns": self.columns}

    def set_metadata(
        self: SelfTransformTabularize, metadata: Mapping[str, Any], /  # noqa: W504
    ) -> SelfTransformTabularize:
        r"""
        Set metadata of the transformation.

        Args
        ----
        - metadata
            Metadata of the transformation. It is ignored.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Do nothing.
        return self

    def set_numeric_data(
        self: SelfTransformTabularize, data: Mapping[str, NPANYS], /  # noqa: W504
    ) -> SelfTransformTabularize:
        r"""
        Set numeric data of the transformation.

        Args
        ----
        - data
            Numeric data of the transformation. It is ignored.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Do nothing.
        return self

    def set_alphabetic_data(
        self: SelfTransformTabularize, data: Mapping[str, Any], /  # noqa: W504
    ) -> SelfTransformTabularize:
        r"""
        Set alphabetic data of the transformation.

        Args
        ----
        - data
            Alphabetic data of the transformation. It must contain "columns", a sequence of
            `(name, value type)` pairs whose value types are `_Value` members or their values.

        Returns
        -------
        - self
            Class instance itself.
        """
        # Safety check.
        assert "columns" in data, "Columns and their value types are missing."

        # Validate loaded parameters before touching instance state, so that a failed load does
        # not leave invalid columns behind.
        columns = list(data["columns"])
        vtypes = set(_Value)
        invalid = [f'"{name!s}"' for name, vtype in columns if vtype not in vtypes]
        invalid_ = [*invalid[:3], "..."] if len(invalid) > 3 else invalid
        assert (
            not invalid
        ), '"{:s}" transformation loads invalid value types for columns {:s}.'.format(
            self._IDENTIFIER, ", ".join(invalid_)
        )

        # Load columns and their value types, normalized to the `(str, _Value)` pairs that `fit`
        # produces (serialized data may presumably carry plain integers or lists instead).
        self.columns = [(str(name), _Value(vtype)) for name, vtype in columns]

        # Generate categorical and continuous column names.
        self._update_column_groups()
        return self
