# Import Python packages.
from typing import Any, Callable, List, Mapping, Sequence, Tuple, cast

# Import external packages.
import pandas as pd

# Import PyTest packages.
import pytest

# Import PyTest external packages.
from py._path.local import LocalPath

# Import developing library.
import fin_tech_py_toolkit as lib

from ...transforms.utils import template_test_io, template_test_transform

# Import testing library.
from ...utils import eq_dataframe, to_eq_data, to_eq_dataset, to_eq_plural_ordered


# Type aliases.
# Both sides of the transformation under test are annotated as lists of tabular datasets.
Input = List[lib.datasets.DatasetTabular]
Output = List[lib.datasets.DatasetTabular]


# Runtime constants.
# Identifier read from `TransdatasetSplitTabular`; every test below hands it either to the
# transformation factory or to a test template.
IDENTIFIER = lib.transdatasets.TransdatasetSplitTabular._IDENTIFIER


# Cast comparator.
# The comparator is composed from the testing-library helpers, innermost first:
# `eq_dataframe` -> `to_eq_data` -> `to_eq_dataset` -> `to_eq_plural_ordered`.
# `cast` only tells the type checker that the result takes two lists of tabular datasets and
# returns a bool; it performs no conversion at runtime.
eq_dataset_tabular = cast(
    Callable[[List[lib.datasets.DatasetTabular], List[lib.datasets.DatasetTabular]], bool],
    to_eq_plural_ordered(to_eq_dataset(to_eq_data(eq_dataframe))),
)


def synthesize() -> Tuple[Tuple[Input, Output], Input, Output, Mapping[str, Any]]:
    r"""
    Build the fixed input/output case used by the split transformation tests.

    Args
    ----

    Returns
    -------
    - example
        Pair of the input and output cases, used as the example.
    - input
        Input case: a single dataset holding all five rows.
    - output
        Output case: two datasets holding three and two of those rows respectively.
    - supplement
        Supplementary material for the synthesized test (an empty mapping here).
    """

    def wrap(frame: pd.DataFrame) -> lib.datasets.DatasetTabular:
        # All datasets of this case are built the same way: one `DataTabular` around the
        # frame, passed to `from_memalias` together with the name list `["full"]`.
        data = lib.data.DataTabular(frame, sort_columns="alphabetic", sort_rows="rankable")
        return lib.datasets.DatasetTabularSimple.from_memalias(
            [data], ["full"], sorts=("alphabetic", "rankable")
        )

    # Raw frames: the full table first, then the two parts expected after the split.
    frame_full = pd.DataFrame(
        {
            "int": [1, 1, 1, 1, 0],
            "float": [1.0, 2.0, 3.0, 4.0, 0.0],
        }
    )
    frame_head = pd.DataFrame({"int": [1, 1, 0], "float": [1.0, 4.0, 0.0]})
    frame_tail = pd.DataFrame({"int": [1, 1], "float": [2.0, 3.0]})

    # Wrap the frames into datasets, keeping the construction order full -> head -> tail.
    inputs: Input = [wrap(frame_full)]
    outputs: Output = [wrap(frame_head), wrap(frame_tail)]

    # The same objects serve both as the example pair and as the tested case.
    return (inputs, outputs), inputs, outputs, {}


@pytest.mark.parametrize(
    ("raw_input", "raw_output"),
    [
        # An Ellipsis on exactly one side is expected to raise `ErrorTransformUnsupportPartial`.
        pytest.param(
            ...,
            None,
            marks=[pytest.mark.xfail(raises=lib.transforms.ErrorTransformUnsupportPartial)],
            id="unsupport-input",
        ),
        pytest.param(
            None,
            ...,
            marks=[pytest.mark.xfail(raises=lib.transforms.ErrorTransformUnsupportPartial)],
            id="unsupport-output",
        ),
        # No xfail mark here: `None` on both sides has to go through cleanly.
        pytest.param(None, None, id="both-null"),
    ],
)
def test_io(*, raw_input: Any, raw_output: Any) -> None:
    r"""
    Check how the split transformation formalizes raw input and output domains.

    Args
    ----
    - raw_input
        Raw input handed to the test template (`None` or `...` in the parametrized cases).
    - raw_output
        Raw output handed to the test template (`None` or `...` in the parametrized cases).

    Returns
    -------
    """
    # Keyword arguments forwarded by the template to each stage of the transformation.
    kwargs_init = {"allow_alias": False}
    kwargs_fit = {"sorts": None, "groupbys": []}
    kwargs_transform = {"props": [(["full"], (1, 1))]}

    # The transformation is created by the template through this factory.
    transform_factory = lib.transforms.FactoryTransform()

    # Delegate the actual checks to the shared I/O template; the same comparator is used for
    # both the input and the output side.
    template_test_io(
        IDENTIFIER,
        transform_factory,
        raw_input,
        raw_output,
        eq_dataset_tabular,
        eq_dataset_tabular,
        init_kwargs=kwargs_init,
        fit_kwargs=kwargs_fit,
        transform_kwargs=kwargs_transform,
    )


@pytest.mark.parametrize(
    "inverse_kwargs",
    [
        pytest.param({}, id="common"),
        # Inverting with identity sorts is expected to trip an assertion.
        pytest.param(
            {"sort_columns": "identity", "sort_rows": "identity"},
            marks=[pytest.mark.xfail(raises=AssertionError)],
            id="ambiguous-inverse",
        ),
    ],
)
def test_default(*, tmpdir: LocalPath, inverse_kwargs: Mapping[str, Any]) -> None:
    r"""
    Run the shared transformation template on the synthesized split case.

    Args
    ----
    - tmpdir
        Temporary directory for this test.
        PyTest injects it as a fixture, so it must not be passed explicitly.
    - inverse_kwargs
        Keyword arguments forwarded to the template for the inversion step.

    Returns
    -------
    """
    # The template receives the temporary directory as a plain string path.
    workdir = str(tmpdir)
    transform_factory = lib.transforms.FactoryTransform()

    # Fetch the synthesized case; its supplementary mapping is not needed here.
    example, inputs, outputs, _ = synthesize()

    # Keyword arguments forwarded by the template to each stage of the transformation.
    kwargs_init = {"allow_alias": False}
    kwargs_fit = {"sorts": ("alphabetic", "rankable"), "groupbys": ["int"]}
    kwargs_transform = {"props": [(["full"], (1, 1))]}

    # Delegate the actual checks to the shared transformation template; the same comparator is
    # used for both the input and the output side.
    template_test_transform(
        workdir,
        IDENTIFIER,
        transform_factory,
        example,
        inputs,
        outputs,
        eq_dataset_tabular,
        eq_dataset_tabular,
        init_kwargs=kwargs_init,
        fit_kwargs=kwargs_fit,
        transform_kwargs=kwargs_transform,
        inverse_kwargs=inverse_kwargs,
    )


def test_empty(*, tmpdir: LocalPath) -> None:
    r"""
    Run the shared transformation template on the domains derived from `None`.

    Args
    ----
    - tmpdir
        Temporary directory for this test.
        PyTest injects it as a fixture, so it must not be passed explicitly.

    Returns
    -------
    """
    # The template receives the temporary directory as a plain string path.
    workdir = str(tmpdir)
    transform_factory = lib.transforms.FactoryTransform()

    # Ask a default-constructed transformation what it makes of a `None` input and output;
    # those values are the case under test.
    probe = transform_factory.from_args(IDENTIFIER)
    empty_input = probe.input(None)
    empty_output = probe.output(None)

    # The example pairs that input with a literal empty output list.
    example: Tuple[Input, Output] = (empty_input, [])

    # Keyword arguments forwarded by the template to each stage of the transformation.
    kwargs_init = {"allow_alias": False}
    kwargs_fit = {"sorts": None, "groupbys": []}
    kwargs_transform = {"props": [(["full"], (1, 1))]}

    # Delegate the actual checks to the shared transformation template; the same comparator is
    # used for both the input and the output side.
    template_test_transform(
        workdir,
        IDENTIFIER,
        transform_factory,
        example,
        empty_input,
        empty_output,
        eq_dataset_tabular,
        eq_dataset_tabular,
        init_kwargs=kwargs_init,
        fit_kwargs=kwargs_fit,
        transform_kwargs=kwargs_transform,
    )


@pytest.mark.parametrize(
    ("major", "minor"),
    [
        # Each id reads "(len(major),len(minor))".
        pytest.param([], [], id="(0,0)"),
        pytest.param([0], [], id="(1,0)"),
        pytest.param([1, 0], [], id="(2,0)"),
        pytest.param([2, 1, 0], [], id="(3,0)"),
        pytest.param([3, 0], [2, 1], id="(2,2)"),
        pytest.param([4, 0, 2], [3, 1], id="(3,2)"),
        pytest.param([5, 3, 0, 2], [4, 1], id="(4,2)"),
        pytest.param([6, 4, 3, 0, 2], [5, 1], id="(5,2)"),
        pytest.param([7, 4, 5, 0, 3, 2], [6, 1], id="(6,2)"),
        pytest.param([8, 5, 6, 4, 0, 3, 2], [7, 1], id="(7,2)"),
        pytest.param([9, 6, 7, 0, 3, 2], [8, 5, 1, 4], id="(6,4)"),
        # NOTE(review): every other case is a permutation of 0..n-1, but this one lists 1 in
        # both halves and has no 2 -- confirm whether the trailing 1 of the major half was
        # meant to be 2.
        pytest.param([10, 7, 8, 5, 0, 3, 1], [9, 6, 1, 4], id="(7,4)"),
        pytest.param([11, 8, 9, 6, 0, 3, 2, 5], [10, 7, 1, 4], id="(8,4)"),
        pytest.param([12, 9, 10, 7, 6, 0, 3, 2, 5], [11, 8, 1, 4], id="(9,4)"),
    ],
)
def test_split_size(*, major: Sequence[int], minor: Sequence[int]) -> None:
    r"""
    Test split dataset transformation on tabular datasets of various sizes.

    A dataset holding the integers of both halves is transformed with
    `props=[(["full"], (2, 1))]`, and the result is compared against the two expected halves.

    Args
    ----
    - major
        Integers expected in the major half.
    - minor
        Integers expected in the minor half.

    Returns
    -------
    """

    def dataset_of(values: Sequence[int]) -> lib.datasets.DatasetTabular:
        # Single-column ("int") dataset; all three datasets of the test are built this way.
        data = lib.data.DataTabular(
            pd.DataFrame({"int": values}), sort_columns="alphabetic", sort_rows="rankable"
        )
        return lib.datasets.DatasetTabularSimple.from_memalias(
            [data], ["full"], sorts=("alphabetic", "rankable")
        )

    # Create the split transformation.
    transform = lib.transforms.FactoryTransform().from_args(IDENTIFIER, allow_alias=False)

    # The input is the concatenation of both halves; the expectation is the halves themselves.
    whole = dataset_of([*major, *minor])
    half_major = dataset_of(major)
    half_minor = dataset_of(minor)
    inputs: List[lib.datasets.DatasetTabular] = [whole]
    expected: List[lib.datasets.DatasetTabular] = [half_major, half_minor]

    # Fit and transform in one go (the example pairs the input with an empty output list),
    # then compare the produced datasets against the expected halves, in order.
    produced = transform.fit_transform((inputs, []), inputs, props=[(["full"], (2, 1))])
    assert eq_dataset_tabular(expected, produced)
