# Import Python packages.
import copy
from typing import Any, List, Mapping, Optional, Tuple

# Import external packages.
import pandas as pd

# Import PyTest packages.
import pytest

# Import PyTest external packages.
from py._path.local import LocalPath

# Import developing library.
import fin_tech_py_toolkit as lib

# Import testing library.
from ....utils import eq_dataframe, to_eq_plural_ordered
from ...utils import template_test_io, template_test_transform


# Type aliases.
# Both the transformation input and output are ordered lists of Pandas data frames.
Input = Output = List[pd.DataFrame]


# Runtime constants.
# Identifier passed to the transformation factory to build the count-encoding transformation.
IDENTIFIER = lib.transforms.TransformCountizePandas._IDENTIFIER


def synthesize(
    *, unk: int, ood: Optional[str]
) -> Tuple[Tuple[Input, Output], Input, Output, Mapping[str, Any]]:
    r"""
    Synthesize test I/O.

    The example holds two categorical columns to be count-encoded.
    The input case additionally holds a column absent from the example ("categorical3"), which is
    expected to pass through unchanged.
    When out-of-distribution is allowed, the input case also gets one extra row whose categories
    never appear in the example.

    Args
    ----
    - unk
        Default value for unknown or rare categories.
        It is the expected encoding of out-of-distribution categories.
    - ood
        Category representation reserved for out-of-distribution.
        If it is null, out-of-distribution is not allowed and no extra row is synthesized.

    Returns
    -------
    - example
        Input and output examples.
    - input
        Input case.
    - output
        Output case.
    - supplement
        Supplementary material for synthesized test (always empty here).
    """
    # Out-of-distribution rows (and their encoding) only exist when a reserved symbol is given.
    has_ood = ood is not None

    # Create the categorical columns.
    categorical1 = ["010", "010", "010", "020", "030"]
    categorical2 = ['"X"', '"X"', '"X"', '"Y"', '"Z"']
    categorical3 = ["A", "B", "C", "D", "E"]

    # Create out-of-distribution categorical rows (empty when out-of-distribution is disallowed).
    categorical1_ood = ["040"] if has_ood else []
    categorical2_ood = ['"U"'] if has_ood else []
    categorical3_ood = ["F"] if has_ood else []

    # Create encoding parameters.
    # NOTE(review): "020"/"030" (and '"Y"'/'"Z"') each occur once yet are expected to encode as
    # 1.0 vs 1.5; presumably this reflects the transformation's default collision avoidance —
    # TODO confirm against the implementation.
    encodings = {
        "010": {"categorical1-count": 3.0},
        "020": {"categorical1-count": 1.0},
        "030": {"categorical1-count": 1.5},
        '"X"': {"categorical2-count": 3.0},
        '"Y"': {"categorical2-count": 1.0},
        '"Z"': {"categorical2-count": 1.5},
    }
    if has_ood:
        # The reserved out-of-distribution symbol is encoded as the unknown default everywhere.
        encodings[str(ood)] = {
            "categorical1-count": float(unk),
            "categorical2-count": float(unk),
        }

    # Create input and output examples.
    example_input: Input
    example_input = [
        pd.DataFrame({"categorical1": categorical1, "categorical2": categorical2}),
        pd.DataFrame([], columns=[], index=range(len(categorical1))),
        pd.DataFrame([], columns=[]),
        pd.DataFrame([], columns=[]),
    ]
    example_output: Output
    example_output = []

    # Input case.
    num_rows = len(categorical1) + len(categorical1_ood)
    input: Input
    input = [
        pd.DataFrame(
            {
                "categorical1": categorical1 + categorical1_ood,
                "categorical2": categorical2 + categorical2_ood,
                "categorical3": categorical3 + categorical3_ood,
            }
        ),
        pd.DataFrame([], columns=[], index=range(num_rows)),
        pd.DataFrame([], columns=[]),
        pd.DataFrame([], columns=[]),
    ]

    # For output case, encoded out-of-distribution categories collapse to the reserved symbol,
    # while the column absent from the example keeps its original values.
    encoded1 = categorical1 + [str(ood)] * len(categorical1_ood)
    encoded2 = categorical2 + [str(ood)] * len(categorical2_ood)

    # Output case.
    output: Output
    output = [
        pd.DataFrame({"categorical3": categorical3 + categorical3_ood}),
        pd.DataFrame(
            {
                "categorical1-count": [encodings[cell]["categorical1-count"] for cell in encoded1],
                "categorical2-count": [encodings[cell]["categorical2-count"] for cell in encoded2],
            }
        ),
        pd.DataFrame([], columns=[]),
        pd.DataFrame([], columns=[]),
    ]
    return (example_input, example_output), input, output, {}


@pytest.mark.parametrize(
    "raw_input,raw_output",
    [
        pytest.param(
            ...,
            None,
            marks=pytest.mark.xfail(raises=lib.transforms.ErrorTransformUnsupportPartial),
            id="unsupport-input",
        ),
        pytest.param(
            None,
            ...,
            marks=pytest.mark.xfail(raises=lib.transforms.ErrorTransformUnsupportPartial),
            id="unsupport-output",
        ),
        pytest.param(None, None, id="both-null"),
    ],
)
def test_io(*, raw_input: Any, raw_output: Any) -> None:
    r"""
    Test formalization of the transformation input and output domains.

    Args
    ----
    - raw_input
        Raw input to be formalized; a non-null placeholder is expected to be rejected.
    - raw_output
        Raw output to be formalized; a non-null placeholder is expected to be rejected.

    Returns
    -------
    """
    # Build the factory that creates the transformation under test.
    factory = lib.transforms.FactoryTransform()

    # Both domains are compared as ordered lists of data frames.
    eq_input = to_eq_plural_ordered(eq_dataframe)
    eq_output = to_eq_plural_ordered(eq_dataframe)

    # Delegate to the shared I/O test template.
    template_test_io(IDENTIFIER, factory, raw_input, raw_output, eq_input, eq_output)


@pytest.mark.parametrize(
    "unk,ood",
    [
        pytest.param(0, None, id="default"),
        pytest.param(
            0,
            "<unk>",
            marks=pytest.mark.xfail(raises=lib.transforms.ErrorTransformUnsupportPartial),
            id="ood",
        ),
    ],
)
def test_default(*, tmpdir: LocalPath, unk: int, ood: Optional[str]) -> None:
    r"""
    Test count encoding of Pandas data on a synthesized case.

    Args
    ----
    - tmpdir
        Temporary directory for this test.
        It is automatically provided by PyTest, so its value should not be explicitly defined.
    - unk
        Default value for unknown or rare categories.
    - ood
        Category representation reserved for out-of-distribution.
        If it is null, out-of-distribution is not allowed.

    Returns
    -------
    """
    # Scratch directory and the factory that creates the transformation under test.
    root = str(tmpdir)
    factory = lib.transforms.FactoryTransform()

    # Synthesize the example, the input case and its expected output.
    example, input_case, expected, _ = synthesize(unk=unk, ood=ood)

    # Both domains are compared as ordered lists of data frames.
    eq_input = to_eq_plural_ordered(eq_dataframe)
    eq_output = to_eq_plural_ordered(eq_dataframe)

    # Delegate to the shared transformation test template.
    template_test_transform(
        root,
        IDENTIFIER,
        factory,
        example,
        input_case,
        expected,
        eq_input,
        eq_output,
        fit_kwargs={"unk": unk, "ood": ood},
    )


def test_empty(*, tmpdir: LocalPath) -> None:
    r"""
    Test count encoding when the Pandas data is empty.

    Args
    ----
    - tmpdir
        Temporary directory for this test.
        It is automatically provided by PyTest, so its value should not be explicitly defined.

    Returns
    -------
    """
    # Scratch directory and the factory that creates the transformation under test.
    root = str(tmpdir)
    factory = lib.transforms.FactoryTransform()

    # The case is the transformation's own formalization of null input and null output.
    # The example pairs the null input with an empty output list.
    transform = factory.from_args(IDENTIFIER)
    empty_input = transform.input(None)
    empty_output = transform.output(None)
    example: Tuple[Input, Output] = (empty_input, [])

    # Both domains are compared as ordered lists of data frames.
    eq_input = to_eq_plural_ordered(eq_dataframe)
    eq_output = to_eq_plural_ordered(eq_dataframe)

    # Delegate to the shared transformation test template.
    template_test_transform(
        root, IDENTIFIER, factory, example, empty_input, empty_output, eq_input, eq_output
    )


@pytest.mark.xfail(raises=lib.transforms.ErrorTransformUnsupportPartial)
def test_missing_columns(*, tmpdir: LocalPath) -> None:
    r"""
    Test count encoding when the input case lacks a (categorical) column present in the example.

    Args
    ----
    - tmpdir
        Temporary directory for this test.
        It is automatically provided by PyTest, so its value should not be explicitly defined.

    Returns
    -------
    """
    # Scratch directory and the factory that creates the transformation under test.
    root = str(tmpdir)
    factory = lib.transforms.FactoryTransform()

    # Start from the formalized null input/output of the transformation.
    transform = factory.from_args(IDENTIFIER)
    empty_input = transform.input(None)
    empty_output = transform.output(None)

    # The example gets one extra column in its first frame; the case itself does not have it.
    # A deep copy keeps the case untouched.
    extended_input = copy.deepcopy(empty_input)
    extended_input[0]["category-missing"] = []
    example: Tuple[Input, Output] = (extended_input, [])

    # Both domains are compared as ordered lists of data frames.
    eq_input = to_eq_plural_ordered(eq_dataframe)
    eq_output = to_eq_plural_ordered(eq_dataframe)

    # Delegate to the shared transformation test template.
    template_test_transform(
        root, IDENTIFIER, factory, example, empty_input, empty_output, eq_input, eq_output
    )


@pytest.mark.xfail(raises=lib.transforms.ErrorTransformUnsupportPartial)
def test_nonexist_inverse(*, tmpdir: LocalPath) -> None:
    r"""
    Test inverting a count encoding that does not exist among the fitted encodings.

    Args
    ----
    - tmpdir
        Temporary directory for this test.
        It is automatically provided by PyTest, so its value should not be explicitly defined.

    Returns
    -------
    """
    # Scratch directory and the factory that creates the transformation under test.
    root = str(tmpdir)
    factory = lib.transforms.FactoryTransform()

    # Synthesize a case without out-of-distribution support.
    unk, ood = 0, None
    example, input_case, expected, _ = synthesize(unk=unk, ood=ood)

    # Corrupt the last encoded row in place (it is shared with `expected`).
    # Infinity is never a valid encoding value, so that row matches no fitted encoding.
    _, encoded, _, _ = expected
    last_row = encoded.iloc[-1].copy()
    corrupted = [label for label in last_row.index if "categorical" in label]
    last_row.loc[corrupted] = float("inf")
    encoded.iloc[-1] = last_row

    # Both domains are compared as ordered lists of data frames.
    eq_input = to_eq_plural_ordered(eq_dataframe)
    eq_output = to_eq_plural_ordered(eq_dataframe)

    # Delegate to the shared transformation test template, skipping the forward checks.
    template_test_transform(
        root,
        IDENTIFIER,
        factory,
        example,
        input_case,
        expected,
        eq_input,
        eq_output,
        require_test_transform=False,
        require_test_transform_=False,
        fit_kwargs={"unk": unk, "ood": ood},
    )


@pytest.mark.xfail(raises=lib.transforms.ErrorTransformUnsupportPartial)
def test_ambiguous_inverse(*, tmpdir: LocalPath) -> None:
    r"""
    Test inverting a count encoding whose parameters are ambiguous.

    Args
    ----
    - tmpdir
        Temporary directory for this test.
        It is automatically provided by PyTest, so its value should not be explicitly defined.

    Returns
    -------
    """
    # Scratch directory and the factory that creates the transformation under test.
    root = str(tmpdir)
    factory = lib.transforms.FactoryTransform()

    # "X" and "Y" occur equally often and therefore share the same count.
    # Collision avoidance is disabled below, so the inversion is expected to fail.
    categorical = ["X", "X", "Y", "Y"]
    encodings = {cell: {"categorical-count": 2} for cell in ("X", "Y")}

    # The input case has only the categorical frame populated.
    # The remaining frames are distinct empty objects.
    input_case = [pd.DataFrame({"categorical": categorical})]
    input_case += [pd.DataFrame([], columns=[]) for _ in range(3)]

    # The expected output has only the continuous (count) frame populated.
    counts = [encodings[cell]["categorical-count"] for cell in categorical]
    expected = [
        pd.DataFrame([], columns=[]),
        pd.DataFrame({"categorical-count": counts}),
        pd.DataFrame([], columns=[]),
        pd.DataFrame([], columns=[]),
    ]
    example: Tuple[Input, Output] = (input_case, [])

    # Both domains are compared as ordered lists of data frames.
    eq_input = to_eq_plural_ordered(eq_dataframe)
    eq_output = to_eq_plural_ordered(eq_dataframe)

    # Delegate to the shared transformation test template, skipping the forward checks.
    template_test_transform(
        root,
        IDENTIFIER,
        factory,
        example,
        input_case,
        expected,
        eq_input,
        eq_output,
        require_test_transform=False,
        require_test_transform_=False,
        fit_kwargs={"avoid_collide": False},
    )
