# Import Python packages.
import copy
from typing import Any, List, Mapping, Optional, Tuple, cast

# Import external packages.
import numpy as np
import pandas as pd

# Import PyTest packages.
import pytest

# Import PyTest external packages.
from py._path.local import LocalPath

# Import developing library.
import fin_tech_py_toolkit as lib

# Import testing library.
from ....utils import eq_dataframe, to_eq_plural_ordered
from ...utils import template_test_io, template_test_transform


# Type aliases.
# Both the input side and the output side of a test case are lists of Pandas data frames.
Input = List[pd.DataFrame]
Output = List[pd.DataFrame]


# Runtime constants.
# Identifier of the transformation under test; it is handed to the factory and to the test
# templates throughout this module.
IDENTIFIER = lib.transforms.TransformCCAPandas._IDENTIFIER


def synthesize(
    *, nan: float, ddof: int, ood: Optional[str], normalize_degree: str
) -> Tuple[Tuple[Input, Output], Input, Output, Mapping[str, Any]]:
    r"""
    Build the fitting example and the expected input/output case for CCA tests.

    Args
    ----
    - nan
        Default value when no valued element is presented in an aggregation.
    - ddof
        Delta Degrees of Freedom used for the expected standard deviations.
    - ood
        Category representation reserved for out-of-distribution.
        If it is null, no out-of-distribution row is appended to the case.
    - normalize_degree
        Degree normalization schema applied to the expected degree columns.

    Returns
    -------
    - example
        Input and output examples.
    - input
        Input case.
    - output
        Output case.
    - supplement
        Supplementary material for synthesized test (always an empty mapping here).
    """
    # In-distribution rows of the categorical and continuous columns.
    categorical1 = ["010", "010", "010", "020", "020"]
    categorical2 = ['"X"', '"X"', '"Y"', '"Y"', '"Y"']
    categorical3 = ["A", "B", "C", "D", "E"]
    continuous1 = [1.0, 2.0, 3.0, 4.0, 5.0]
    continuous2 = [1.0, float("nan"), 2.0, float("nan"), float("nan")]
    continuous3 = [0.0, 0.0, 0.0, 0.0, 0.0]

    def padded(base: List[Any], extra: List[Any]) -> List[Any]:
        # Append the out-of-distribution row only when a symbol is reserved for it.
        return base + ([] if ood is None else extra)

    def valued(prefix: str, null: float, values: List[float]) -> Mapping[str, Any]:
        # Expected encoding of a category that holds at least one valued element.
        # The degrees of freedom are capped so that the divisor never drops below one.
        return {
            f"{prefix}-null": null,
            f"{prefix}-deg": len(values),
            f"{prefix}-min": min(values),
            f"{prefix}-mean": float(np.mean(values)),
            f"{prefix}-max": max(values),
            f"{prefix}-std": float(np.std(values, ddof=min(ddof, len(values) - 1))),
        }

    def unvalued(prefix: str, null: float, fill: Any) -> Mapping[str, Any]:
        # Expected encoding of a category without any valued element.
        return {
            f"{prefix}-null": null,
            f"{prefix}-deg": 0,
            f"{prefix}-min": fill,
            f"{prefix}-mean": fill,
            f"{prefix}-max": fill,
            f"{prefix}-std": fill,
        }

    def columns(categorical: str) -> List[str]:
        # Encoding column names of one categorical column, in expected output order.
        return [
            f"{categorical}-cca-{continuous}-{suffix}"
            for continuous in ["continuous1", "continuous2"]
            for suffix in ["null", "deg", "min", "mean", "max", "std"]
        ]

    # Expected encodings per in-distribution category.
    # NOTE(review): the "020"/continuous2 fill is a literal 0.0 rather than `nan`; confirm
    # whether it is meant to follow the `nan` argument.
    encodings = {
        "010": {
            **valued("categorical1-cca-continuous1", 0.0, [1.0, 2.0, 3.0]),
            **valued("categorical1-cca-continuous2", 1.0 / 3.0, [1.0, 2.0]),
        },
        "020": {
            **valued("categorical1-cca-continuous1", 0.0, [4.0, 5.0]),
            **unvalued("categorical1-cca-continuous2", 1.0, 0.0),
        },
        '"X"': {
            **valued("categorical2-cca-continuous1", 0.0, [1.0, 2.0]),
            **valued("categorical2-cca-continuous2", 1.0 / 2.0, [1.0]),
        },
        '"Y"': {
            **valued("categorical2-cca-continuous1", 0.0, [3.0, 4.0, 5.0]),
            **valued("categorical2-cca-continuous2", 2.0 / 3.0, [2.0]),
        },
    }

    # The out-of-distribution symbol maps every encoding dimension to the default value.
    encodings[str(ood)] = {
        **unvalued("categorical1-cca-continuous1", 0.0, nan),
        **unvalued("categorical1-cca-continuous2", 0.0, nan),
        **unvalued("categorical2-cca-continuous1", 0.0, nan),
        **unvalued("categorical2-cca-continuous2", 0.0, nan),
    }

    # Fitting example: only the first two categorical and continuous columns, no extra row.
    example_input: Input = [
        pd.DataFrame({"categorical1": categorical1, "categorical2": categorical2}),
        pd.DataFrame({"continuous1": continuous1, "continuous2": continuous2}),
        pd.DataFrame([], columns=[]),
        pd.DataFrame([], columns=[]),
    ]
    example_output: Output = []

    # Input case: the extra row carries categories that are absent from the example.
    case_input: Input = [
        pd.DataFrame(
            {
                "categorical1": padded(categorical1, ["030"]),
                "categorical2": padded(categorical2, ['"Z"']),
                "categorical3": padded(categorical3, ["F"]),
            }
        ),
        pd.DataFrame(
            {
                "continuous1": padded(continuous1, [0.0]),
                "continuous2": padded(continuous2, [0.0]),
                "continuous3": padded(continuous3, [0.0]),
            }
        ),
        pd.DataFrame([], columns=[]),
        pd.DataFrame([], columns=[]),
    ]

    # On the output side the unseen categories are looked up under the reserved symbol.
    encoded = {}
    for categorical, cells in [
        ("categorical1", padded(categorical1, [str(ood)])),
        ("categorical2", padded(categorical2, [str(ood)])),
    ]:
        for name in columns(categorical):
            encoded[name] = [encodings[cell][name] for cell in cells]

    # Output case: untouched categorical column, then encodings followed by continuous columns.
    case_output: Output = [
        pd.DataFrame({"categorical3": padded(categorical3, ["F"])}),
        pd.DataFrame(
            {
                **encoded,
                "continuous1": padded(continuous1, [0.0]),
                "continuous2": padded(continuous2, [0.0]),
                "continuous3": padded(continuous3, [0.0]),
            }
        ),
        pd.DataFrame([], columns=[]),
        pd.DataFrame([], columns=[]),
    ]

    # Normalize every expected degree column with the requested schema.
    normdeg = lib.transforms.TransformCCAPandas.get_normdeg(normalize_degree)
    for categorical in ["categorical1", "categorical2"]:
        for continuous in ["continuous1", "continuous2"]:
            name = f"{categorical}-cca-{continuous}-deg"
            case_output[1][name] = normdeg(case_output[1][name])
    return (example_input, example_output), case_input, case_output, {}


@pytest.mark.parametrize(
    ("raw_input", "raw_output"),
    [
        pytest.param(
            ...,
            None,
            id="unsupport-input",
            marks=[pytest.mark.xfail(raises=lib.transforms.ErrorTransformUnsupportPartial)],
        ),
        pytest.param(
            None,
            ...,
            id="unsupport-output",
            marks=[pytest.mark.xfail(raises=lib.transforms.ErrorTransformUnsupportPartial)],
        ),
        pytest.param(None, None, id="both-null"),
    ],
)
def test_io(*, raw_input: Any, raw_output: Any) -> None:
    r"""
    Check how the transformation formalizes raw input and raw output domains.

    Cases passing a non-null raw value on either side are marked as expected failures
    raising `ErrorTransformUnsupportPartial`; only the both-null case is expected to pass.

    Args
    ----
    - raw_input
        Raw input.
    - raw_output
        Raw output.

    Returns
    -------
    """
    # Delegate to the shared template with a fresh factory and data frame list comparators.
    template_test_io(
        IDENTIFIER,
        lib.transforms.FactoryTransform(),
        raw_input,
        raw_output,
        to_eq_plural_ordered(eq_dataframe),
        to_eq_plural_ordered(eq_dataframe),
    )


@pytest.mark.parametrize(
    ("nan", "ddof", "ood", "normalize_degree"),
    [
        pytest.param(0, 1, None, "identity", id="default"),
        pytest.param(
            0,
            1,
            "<unk>",
            "identity",
            id="ood",
            marks=[pytest.mark.xfail(raises=lib.transforms.ErrorTransformUnsupportPartial)],
        ),
        pytest.param(0, 1, None, "log_normal", id="log-normal"),
    ],
)
def test_default(
    *, tmpdir: LocalPath, nan: float, ddof: int, ood: Optional[str], normalize_degree: str
) -> None:
    r"""
    Run the CCA transformation on synthesized Pandas data and compare with expectations.

    The case reserving an out-of-distribution symbol is marked as an expected failure
    raising `ErrorTransformUnsupportPartial`.

    Args
    ----
    - tmpdir
        Temporary directory for this test.
        It is automatically provided by PyTest, so its value should not be explicitly defined.
    - nan
        Default value when no valued element is presented in an aggregation.
    - ddof
        Means Delta Degrees of Freedom.
    - ood
        Category representation reserved for out-of-distribution.
    - normalize_degree
        Degree normalization schema.

    Returns
    -------
    """
    # Prepare the working directory and the transformation factory.
    root = str(tmpdir)
    factory = lib.transforms.FactoryTransform()

    # The same settings drive both the synthesized expectation and the fitting call.
    settings = dict(nan=nan, ddof=ddof, ood=ood, normalize_degree=normalize_degree)
    example, case_input, case_output, _ = synthesize(**settings)

    # Hand everything to the shared transformation test template.
    template_test_transform(
        root,
        IDENTIFIER,
        factory,
        example,
        case_input,
        case_output,
        to_eq_plural_ordered(eq_dataframe),
        to_eq_plural_ordered(eq_dataframe),
        fit_kwargs=settings,
    )


def test_empty(*, tmpdir: LocalPath) -> None:
    r"""
    Run the CCA transformation on the input and output it derives from null raw values.

    Args
    ----
    - tmpdir
        Temporary directory for this test.
        It is automatically provided by PyTest, so its value should not be explicitly defined.

    Returns
    -------
    """
    # Prepare the working directory and the transformation factory.
    root = str(tmpdir)
    factory = lib.transforms.FactoryTransform()

    # Let the transformation itself formalize null raw data into the case to test.
    probe = factory.from_args(IDENTIFIER)
    blank_input = probe.input(None)
    blank_output = probe.output(None)

    # The formalized input doubles as the fitting example; no output example is given.
    example: Tuple[Input, Output] = (blank_input, [])

    # Hand everything to the shared transformation test template.
    template_test_transform(
        root,
        IDENTIFIER,
        factory,
        example,
        blank_input,
        blank_output,
        to_eq_plural_ordered(eq_dataframe),
        to_eq_plural_ordered(eq_dataframe),
    )


@pytest.mark.xfail(raises=lib.transforms.ErrorTransformUnsupportPartial)
def test_missing_columns(*, tmpdir: LocalPath) -> None:
    r"""
    Run the CCA transformation on an input lacking a column that the example provides.

    The fitting example is a deep copy of the tested input whose first element gains an
    empty "category-missing" column; the test is marked as an expected failure raising
    `ErrorTransformUnsupportPartial`.

    Args
    ----
    - tmpdir
        Temporary directory for this test.
        It is automatically provided by PyTest, so its value should not be explicitly defined.

    Returns
    -------
    """
    # Prepare the working directory and the transformation factory.
    root = str(tmpdir)
    factory = lib.transforms.FactoryTransform()

    # Let the transformation itself formalize null raw data into the case to test.
    probe = factory.from_args(IDENTIFIER)
    blank_input = probe.input(None)
    blank_output = probe.output(None)

    # Only the example copy receives the additional column, so the tested input misses it.
    widened_input = copy.deepcopy(blank_input)
    widened_input[0]["category-missing"] = []
    example: Tuple[Input, Output] = (widened_input, [])

    # Hand everything to the shared transformation test template.
    template_test_transform(
        root,
        IDENTIFIER,
        factory,
        example,
        blank_input,
        blank_output,
        to_eq_plural_ordered(eq_dataframe),
        to_eq_plural_ordered(eq_dataframe),
    )


@pytest.mark.xfail(raises=lib.transforms.ErrorTransformUnsupportPartial)
def test_incomplete_inverse(*, tmpdir: LocalPath) -> None:
    r"""
    Run the CCA test template on an output stripped of its "categorical1" columns.

    The test is marked as an expected failure raising `ErrorTransformUnsupportPartial`.

    Args
    ----
    - tmpdir
        Temporary directory for this test.
        It is automatically provided by PyTest, so its value should not be explicitly defined.

    Returns
    -------
    """
    # Prepare the working directory and the transformation factory.
    root = str(tmpdir)
    factory = lib.transforms.FactoryTransform()

    # Synthesize the regular case with fixed settings.
    settings = dict(nan=0, ddof=1, ood=None, normalize_degree="identity")
    example, case_input, full_output, _ = synthesize(**settings)

    def without_categorical1(frame: pd.DataFrame) -> pd.DataFrame:
        # Keep only the columns whose name does not mention "categorical1".
        return frame[[name for name in frame.columns if "categorical1" not in name]]

    # Strip the first two frames and leave the remaining frames untouched.
    categorical, continuous, *labels = full_output
    partial_output = [
        without_categorical1(categorical),
        without_categorical1(continuous),
        *labels,
    ]

    # Hand everything to the shared template with both transform checks disabled.
    template_test_transform(
        root,
        IDENTIFIER,
        factory,
        example,
        case_input,
        partial_output,
        to_eq_plural_ordered(eq_dataframe),
        to_eq_plural_ordered(eq_dataframe),
        require_test_transform=False,
        require_test_transform_=False,
        fit_kwargs=settings,
    )


@pytest.mark.xfail(raises=lib.transforms.ErrorTransformUnsupportPartial)
def test_nonexist_inverse(*, tmpdir: LocalPath) -> None:
    r"""
    Run the CCA test template on an output whose last row holds a non-existing encoding.

    The test is marked as an expected failure raising `ErrorTransformUnsupportPartial`.

    Args
    ----
    - tmpdir
        Temporary directory for this test.
        It is automatically provided by PyTest, so its value should not be explicitly defined.

    Returns
    -------
    """
    # Prepare the working directory and the transformation factory.
    root = str(tmpdir)
    factory = lib.transforms.FactoryTransform()

    # Synthesize the regular case with fixed settings.
    settings = dict(nan=0, ddof=1, ood=None, normalize_degree="identity")
    example, case_input, case_output, _ = synthesize(**settings)

    # Overwrite every "categorical1" dimension of the last row in the second output frame.
    _, continuous, _, _ = case_output
    last_row = continuous.iloc[-1].copy()
    tampered = [name for name in last_row.index if "categorical1" in name]

    # Infinite is always an invalid encoding value.
    last_row.loc[tampered] = float("inf")
    continuous.iloc[-1] = last_row

    # Hand everything to the shared template with both transform checks disabled.
    template_test_transform(
        root,
        IDENTIFIER,
        factory,
        example,
        case_input,
        case_output,
        to_eq_plural_ordered(eq_dataframe),
        to_eq_plural_ordered(eq_dataframe),
        require_test_transform=False,
        require_test_transform_=False,
        fit_kwargs=settings,
    )


@pytest.mark.xfail(raises=lib.transforms.ErrorTransformUnsupportPartial)
def test_ambiguous_inverse(*, tmpdir: LocalPath) -> None:
    r"""
    Run the CCA test template on two categories that share one identical encoding.

    The test is marked as an expected failure raising `ErrorTransformUnsupportPartial`.

    Args
    ----
    - tmpdir
        Temporary directory for this test.
        It is automatically provided by PyTest, so its value should not be explicitly defined.

    Returns
    -------
    """
    # Prepare the working directory and the transformation factory.
    root = str(tmpdir)
    factory = lib.transforms.FactoryTransform()

    # Both categories observe exactly the same continuous values.
    cells = ["X", "X", "Y", "Y"]
    values = [0.0, 2.0, 0.0, 2.0]

    def shared_encoding() -> Mapping[str, Any]:
        # Encoding of the values [0.0, 2.0], which is what each category observes.
        return {
            "categorical-cca-continuous-null": 0.0,
            "categorical-cca-continuous-deg": 2,
            "categorical-cca-continuous-min": 0.0,
            "categorical-cca-continuous-mean": float(np.mean([0.0, 2.0])),
            "categorical-cca-continuous-max": 2.0,
            "categorical-cca-continuous-std": float(np.std([0.0, 2.0])),
        }

    encodings = {cell: shared_encoding() for cell in ["X", "Y"]}
    names = list(encodings["X"])

    # Input case, which is reused as the fitting example below.
    case_input = [
        pd.DataFrame({"categorical": cells}),
        pd.DataFrame({"continuous": values}),
        pd.DataFrame([], columns=[]),
        pd.DataFrame([], columns=[]),
    ]

    # Output case: encoding columns first, then the untouched continuous column.
    encoded = {name: [encodings[cell][name] for cell in cells] for name in names}
    case_output = [
        pd.DataFrame([], columns=[]),
        pd.DataFrame({**encoded, "continuous": values}),
        pd.DataFrame([], columns=[]),
        pd.DataFrame([], columns=[]),
    ]
    example: Tuple[Input, Output] = (case_input, [])

    # Hand everything to the shared template with both transform checks disabled.
    template_test_transform(
        root,
        IDENTIFIER,
        factory,
        example,
        case_input,
        case_output,
        to_eq_plural_ordered(eq_dataframe),
        to_eq_plural_ordered(eq_dataframe),
        require_test_transform=False,
        require_test_transform_=False,
        fit_kwargs=dict(normalize_degree="identity"),
    )


def normdeg_force_error(series: "pd.Series[Any]", /) -> "pd.Series[Any]":
    r"""
    Degree column normalization algorithm that always fails.

    Args
    ----
    - series
        Degree column (never inspected).

    Returns
    -------
    - series
        Normalized degree column (never produced, since the call always raises).
    """
    # Fail unconditionally, whatever the given column is.
    failure = RuntimeError("Force error.")
    raise failure


def test_register_sort_invalid() -> None:
    r"""
    Test registering a degree normalization algorithm that always raises.

    Only the registration itself is exercised; this test never calls the registered
    algorithm directly.

    Args
    ----

    Returns
    -------
    """
    # Build the transformation under test and narrow it to its concrete class.
    transform = cast(
        lib.transforms.TransformCCAPandas,
        lib.transforms.FactoryTransform().from_args(IDENTIFIER),
    )

    # Register the always-failing normalization under a dedicated name.
    transform.register_normdeg(normdeg_force_error, "test_invalid")
