# Import Python packages.
from typing import Any, Callable, Sequence, TypeVar, cast

# Import external packages.
import numpy as np
import pandas as pd

# Import developing library.
import fin_tech_py_toolkit as lib


# Type variables.
# Independent placeholders for the two sides of a comparison: `VarTA` types the first argument and
# `VarTB` the second, so a comparator built by the `to_eq_*` factories may accept two different
# item/content types.
VarTA = TypeVar("VarTA")
VarTB = TypeVar("VarTB")


def eq_type(a: type, b: type, /) -> bool:
    r"""
    Compare two types.

    Types are matched by identity only, so a subclass never equals its base class
    (e.g. `bool` and `int` are different types here).

    Args
    ----
    - a
        One argument.
    - b
        The other argument.

    Returns
    -------
    - flag
        If two arguments are exactly the same type.
    """
    # Identity is symmetric, so a single check suffices. Subclass relations (`issubclass`) are
    # deliberately not honored: only the exact same type matches.
    return a is b


def eq_scalar(a: Any, b: Any, /, *, loose: bool = True) -> bool:
    r"""
    Compare two scalars.

    Missing values (NaN, NaT and `pd.NA`) are treated as equal to each other and as different from
    any non-missing value.

    Args
    ----
    - a
        One argument.
    - b
        The other argument.
    - loose
        If True, allow loose equality comparison (`np.isclose`) for numeric values.

    Returns
    -------
    - flag
        If two arguments are the same.
    """
    # Missing values are a special corner case.
    # `pd.NA` propagates through `!=` (and the result cannot be cast to bool), so it is identified
    # by identity; NaN and NaT are recognized as values that compare unequal to themselves.
    missing_a = a is pd.NA or bool(a != a)
    missing_b = b is pd.NA or bool(b != b)
    if missing_a or missing_b:
        # Two missing values are treated as the same thing, while a missing value never equals a
        # present one.
        return missing_a and missing_b

    # Value types will influence comparison criterion.
    if loose and pd.api.types.is_numeric_dtype(type(a)) and pd.api.types.is_numeric_dtype(type(b)):
        # For comparing both numeric values, we allow loose comparison.
        # `np.isclose` returns a NumPy boolean; convert it so callers always get a plain `bool`.
        return bool(np.isclose(a, b))

    # Otherwise, the comparison should be strict.
    return not bool(a != b)


def eq_sequence(a: Sequence[Any], b: Sequence[Any], /, *, loose: bool = True) -> bool:
    r"""
    Compare two sequences position by position.

    Args
    ----
    - a
        One argument.
    - b
        The other argument.
    - loose
        If True, allow loose equality comparison for numeric values.

    Returns
    -------
    - flag
        If both sequences have the same length and every pair of values at the same position is
        the same according to `eq_scalar`.
    """
    # Only sequences of identical length can possibly match.
    if len(a) == len(b):
        # Walk both sequences in lockstep, bailing out on the first mismatching pair.
        for left, right in zip(a, b):
            if not eq_scalar(left, right, loose=loose):
                return False
        return True
    return False


def eq_dataframe(a: pd.DataFrame, b: pd.DataFrame, /, *, loose: bool = True) -> bool:
    r"""
    Compare two dataframes.

    Columns are matched by title regardless of their order, and rows are matched by position
    regardless of their index labels.

    Args
    ----
    - a
        One argument.
    - b
        The other argument.
    - loose
        If True, allow loose equality comparison for numeric values.

    Returns
    -------
    - flag
        If two arguments are the same.
    """
    # Equal dataframes must have same column titles.
    # This runs before any sorting, so mismatching (possibly non-sortable) titles are rejected
    # without raising.
    if set(a.columns) != set(b.columns):
        # Early stop.
        return False

    # Equal dataframes must have same number of rows.
    if len(a) != len(b):
        # Early stop.
        return False

    # Use alphabetical column order for disambiguation.
    # Column positions (not titles) are sorted so that duplicated titles stay distinct; the stable
    # sort pairs duplicates up in their original relative order.
    titles_a = list(a.columns)
    titles_b = list(b.columns)
    order_a = sorted(range(len(titles_a)), key=titles_a.__getitem__)
    order_b = sorted(range(len(titles_b)), key=titles_b.__getitem__)

    # Equal title sets may still repeat titles a different number of times.
    if [titles_a[i] for i in order_a] != [titles_b[j] for j in order_b]:
        # Early stop.
        return False

    # Compare every column series between two dataframes.
    # Columns are selected by position: selecting a duplicated title would yield a sub-dataframe,
    # whose iteration produces column titles instead of cell values. Rows are compared by position
    # because iterating a series yields its values and ignores its index.
    return all(
        eq_sequence(
            cast(Sequence[Any], a.iloc[:, i]), cast(Sequence[Any], b.iloc[:, j]), loose=loose
        )
        for i, j in zip(order_a, order_b)
    )


def to_eq_plural_ordered(
    eq_item: Callable[[VarTA, VarTB], bool], /  # noqa: W504
) -> Callable[[Sequence[VarTA], Sequence[VarTB]], bool]:
    r"""
    Build an ordered-collection comparator out of an item comparator.

    Args
    ----
    - eq_item
        Item comparator, applied to items found at the same position.

    Returns
    -------
    - eq_plural
        Comparator that accepts two ordered collections of items.
    """

    def eq_plural(a: Sequence[VarTA], b: Sequence[VarTB], /) -> bool:
        r"""
        Compare two ordered item collections.

        Args
        ----
        - a
            One argument.
        - b
            The other argument.

        Returns
        -------
        - flag
            If both collections have the same length and, at every position, the two items have
            exactly the same type and satisfy the item comparator.
        """
        # Collections of different sizes can never match.
        if len(a) != len(b):
            return False

        # Walk both collections in lockstep, stopping at the first pair whose exact types differ
        # or which the item comparator rejects.
        for left, right in zip(a, b):
            if not eq_type(type(left), type(right)):
                return False
            if not eq_item(left, right):
                return False
        return True

    return eq_plural


def to_eq_data(
    eq_content: Callable[[VarTA, VarTB], bool], /  # noqa: W504
) -> Callable[[lib.data.BaseData[VarTA], lib.data.BaseData[VarTB]], bool]:
    r"""
    Build a data container comparator out of a content comparator.

    Args
    ----
    - eq_content
        Content comparator, applied to the wrapped contents of two data containers.

    Returns
    -------
    - eq_data
        Comparator that accepts two data containers.
    """

    def eq_data(a: lib.data.BaseData[VarTA], b: lib.data.BaseData[VarTB], /) -> bool:
        r"""
        Compare two data containers through their wrapped contents.

        Args
        ----
        - a
            One argument.
        - b
            The other argument.

        Returns
        -------
        - flag
            If both contents have exactly the same type and satisfy the content comparator.
        """
        content_a = a._content
        content_b = b._content

        # Contents of different exact types are rejected up front, so the content comparator is
        # only ever called on same-typed arguments.
        if not eq_type(type(content_a), type(content_b)):
            return False
        return eq_content(content_a, content_b)

    return eq_data


def to_eq_dataset(
    eq_data: Callable[[lib.data.BaseData[VarTA], lib.data.BaseData[VarTB]], bool], /  # noqa: W504
) -> Callable[
    [
        lib.datasets.BaseDataset[lib.data.BaseData[VarTA]],
        lib.datasets.BaseDataset[lib.data.BaseData[VarTB]],
    ],
    bool,
]:
    r"""
    Convert data container comparator to dataset comparator.

    Args
    ----
    - eq_data
        Data container comparator, applied pairwise to the memory of two datasets.

    Returns
    -------
    - eq_dataset
        Dataset comparator.
    """
    # Get memory comparator: memory entries are compared in order, position by position.
    eq_memory = to_eq_plural_ordered(eq_data)

    def eq_dataset(
        a: lib.datasets.BaseDataset[lib.data.BaseData[VarTA]],
        b: lib.datasets.BaseDataset[lib.data.BaseData[VarTB]],
        /,
    ) -> bool:
        r"""
        Compare two datasets.

        Args
        ----
        - a
            One argument.
        - b
            The other argument.

        Returns
        -------
        - flag
            If both datasets have the same memory names (in order) and matching memory entries.
        """
        # Only memory names and memory entries are compared; other auxiliary dataset information
        # is ignored.
        # Names are checked first because that is cheap and skips the full content comparison
        # whenever the layouts already differ.
        if tuple(a.memory_names) != tuple(b.memory_names):
            return False
        return eq_memory(a.memory, b.memory)

    # Return decorated function.
    return eq_dataset
