# Import Python packages.
import functools
import importlib.util as imputil
import os
import shutil
import sys
from types import ModuleType
from typing import Any, List, Mapping, Optional, Sequence, Tuple

# Import external packages.
import pandas as pd

# Import developing packages.
import fin_tech_py_toolkit as lib


# Runtime constants.
# Dataset identity; also used as the sub-directory of the raw files and as the cache suffix.
NAME_DATASET = "adult"

# Column names handed to the reader as `names` (features first, then targets).
COLUMNS_FEATURE_DATASET = [
    "age",
    "workclass",
    "fnlwgt",
    "education",
    "education-num",
    "marital-status",
    "occupation",
    "relationship",
    "race",
    "sex",
    "capital-gain",
    "capital-loss",
    "hours-per-week",
    "native-country",
]
COLUMNS_TARGET_DATASET = ["label"]

# Name under which the dataset-specific column sort is registered.
SORT_COLUMNS_DATASET = "supervised-" + NAME_DATASET

# Named raw files, as paths relative to the data root directory.
ADDRESSES = [
    ("tune", os.path.join(NAME_DATASET, "adult.data")),
    ("test", os.path.join(NAME_DATASET, "adult.test")),
]

# Per-file positional and keyword arguments forwarded to the reader.
READ_ARGS: Sequence[Tuple[Sequence[str], Sequence[Any]]] = [(["tune", "test"], [])]
READ_KWARGS: Sequence[Tuple[Sequence[str], Mapping[str, Any]]] = [
    (["tune"], {"names": [*COLUMNS_FEATURE_DATASET, *COLUMNS_TARGET_DATASET]}),
    (
        ["test"],
        {"names": [*COLUMNS_FEATURE_DATASET, *COLUMNS_TARGET_DATASET], "skiprows": 1},
    ),
]
SORTS = (SORT_COLUMNS_DATASET, "rankable")

# Split proportions: inference (tune versus test) first, learning (train versus valid) second.
PROP_TRAIN = 7
PROP_VALID = 1
PROPS_INFER = [(["tune"], (1, 0)), (["test"], (0, 1))]
PROPS_LEARN = [(["tune", "test"], (PROP_TRAIN, PROP_VALID))]

# Columns passed as `discretizable` when the tabularize transform is fitted.
COLUMNS_CATEGORICAL = [
    "label",
    "workclass",
    "education",
    "marital-status",
    "occupation",
    "relationship",
    "race",
    "sex",
    "native-country",
]

# Raw target categories (with and without a trailing period) mapped to integer labels.
TARGET_TO_INT = {
    " <=50K": 0,
    " >50K": 1,
    " <=50K.": 0,
    " >50K.": 1,
}


def target_to_int(cell: Any, /) -> int:
    r"""
    Map a raw target cell onto its integer class label.

    Args
    ----
    - cell
        A categorical cell value; it is converted to a string before the lookup.

    Returns
    -------
    - label
        The integer label registered for that category in `TARGET_TO_INT`.

    Raises
    ------
    - KeyError
        If the stringified cell is not a key of `TARGET_TO_INT`.
    """
    # Look the cell up by its text form so non-string cells are handled uniformly.
    category = str(cell)
    label = TARGET_TO_INT[category]
    return label


# Register tabular data disambiguation sorting algorithms.
# The library's "alphabetic" column sort is specialised with `functools.partial` and registered
# under the dataset-specific name `SORT_COLUMNS_DATASET`; that name is part of `SORTS`, which is
# passed as `sorts` when the dataset is loaded from storage.
# NOTE(review): `groups` presumably lists column groups in priority order, with `None` standing
# for all remaining columns, so that target columns are ordered apart from feature columns --
# confirm against the library's sort implementation.
lib.data.DataTabular.register_sort(
    functools.partial(
        lib.data.DataTabular.get_sort("columns", "alphabetic"),
        # A dedicated target group is only added when target columns are declared.
        groups=[COLUMNS_TARGET_DATASET, None] if COLUMNS_TARGET_DATASET else [None],
    ),
    "columns",
    SORT_COLUMNS_DATASET,
)


def rcimport(relpath: str, /) -> ModuleType:
    r"""
    Runtime command import.

    Load a Python source file as a module from a path given relative to the directory of this
    file. On success the module is registered in `sys.modules` under the base name of the file
    (extension removed). If executing the module fails, `sys.modules` is put back into the state
    it had before the call, so no half-initialised module stays importable.

    Args
    ----
    - relpath
        Relative path of importing module w.r.t. current file.

    Returns
    -------
    - module
        Module.

    Raises
    ------
    - ImportError
        If no import spec or loader can be created for the resolved path.
    """
    # Load module from path.
    path = os.path.abspath(os.path.join(os.path.dirname(__file__), relpath))
    name, _ = os.path.splitext(os.path.basename(path))
    spec = imputil.spec_from_file_location(name, path)

    # Explicit checks instead of `assert`, which is stripped when Python runs with `-O`.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create an import spec for: {path:s}", name=name, path=path)
    module = imputil.module_from_spec(spec)

    # The module must be registered before execution so that code inside it can resolve itself
    # through `sys.modules`; remember what was there to be able to roll back.
    missing = object()
    previous = sys.modules.get(name, missing)
    sys.modules[name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Roll back the registration, then propagate the original failure unchanged.
        if previous is missing:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = previous
        raise
    return module


def load(
    root_cache: str, root_data: str, /, props_infer: Optional[Tuple[int, int]] = None
) -> List[pd.DataFrame]:
    r"""
    Load dataset into memory.

    The raw files are read into one tabular dataset, split into tune and test parts, the tune
    part is split again into train and valid parts, and each of the three resulting parts is
    unravelled into a data frame and tabularized into a pair of frames.

    Args
    ----
    - root_cache
        Cache root directory. An existing directory at this path is deleted first.
    - root_data
        Dataset root directory.
    - props_infer
        Inference split proportions. When given, it replaces the proportions of every entry of
        `PROPS_INFER`; otherwise `PROPS_INFER` is used unchanged.

    Returns
    -------
    - memory
        Dataset in memory as twelve data frames: four for each of train, valid and test, in that
        order. Within one part the four frames are the non-target columns of the string-cast
        frame, the non-target columns of the float-cast frame, the target columns of the
        string-cast frame and the target columns of the float-cast frame, each with its columns
        in sorted order.
    """
    # Each experiment cache should be independent, so whatever an earlier run left behind in the
    # cache directory is removed before anything is loaded.
    if os.path.isdir(root_cache):
        shutil.rmtree(root_cache)

    # Read the raw files from the data root.
    addresses = [(name, os.path.join(root_data, relpath)) for name, relpath in ADDRESSES]
    dataset = lib.datasets.DatasetTabularSimple.from_storage(
        addresses,
        cache_prefix=root_cache,
        cache_suffix=NAME_DATASET,
        link=True,
        cache_read=False,
        read_args=READ_ARGS,
        read_kwargs=READ_KWARGS,
        sorts=SORTS,
    )

    # Inference split (tune versus test), followed by the learning split (train versus valid).
    splitter = lib.transdatasets.TransdatasetSplitTabular(
        cache_prefix=root_cache, cache_suffix=NAME_DATASET, allow_alias=False
    )
    if props_infer:
        props = [(names, props_infer) for names, _ in PROPS_INFER]
    else:
        props = PROPS_INFER
    dataset_tune, dataset_test = splitter.fit_transform(
        ([dataset], []), [dataset], [], props=props
    )
    dataset_train, dataset_valid = splitter.transform([dataset_tune], [], props=PROPS_LEARN)

    # Unravel every part into one data frame; the transform is fitted on the train part only.
    unraveller = lib.transdatasets.TransdatasetUnravelTabular(
        cache_prefix=root_cache, cache_suffix=NAME_DATASET, allow_alias=True
    )
    frame_train = pd.concat(
        unraveller.fit_transform(([dataset_train], []), [dataset_train], []), axis=0
    )
    frame_valid = pd.concat(unraveller.transform([dataset_valid], []), axis=0)
    frame_test = pd.concat(unraveller.transform([dataset_test], []), axis=0)
    frames = [frame_train, frame_valid, frame_test]
    assert len(frames) == 3

    # Tabularize every part; again only the train part is used for fitting.
    # NOTE(review): each call is expected to yield two frames, presumably the discretized and the
    # continuous columns -- confirm against the library.
    tabularizer = lib.transforms.TransformTabularize(
        cache_prefix=root_cache, cache_suffix=NAME_DATASET, allow_alias=True
    )
    tables = list(
        tabularizer.fit_transform(
            ([frames[0]], []), [frames[0]], [], discretizable=COLUMNS_CATEGORICAL
        )
    )
    tables.extend(tabularizer.transform([frames[1]], []))
    tables.extend(tabularizer.transform([frames[2]], []))
    assert len(tables) == 6

    # Cast the first frame of every pair to strings and the second one to floats.
    for index in range(0, 6, 2):
        tables[index] = tables[index].applymap(str)
        tables[index + 1] = tables[index + 1].applymap(float)

    # Separate feature columns from target columns, part by part.
    targets = set(COLUMNS_TARGET_DATASET)
    memory = []
    for index in range(0, 6, 2):
        pair = (tables[index], tables[index + 1])
        memory.extend(table[sorted(set(table.columns) - targets)] for table in pair)
        memory.extend(table[sorted(set(table.columns) & targets)] for table in pair)
    return memory


# Main program.
if __name__ == "__main__":
    # Load the dataset using the cache and data directories below the project root.
    root = rcimport("../root.py")
    root_cache = os.path.join(root.ROOT, "cache")
    root_data = os.path.join(root.ROOT, "data")
    memory = load(root_cache, root_data)

    # Show every part (four consecutive frames each) with target columns ahead of features.
    for offset in (0, 4, 8):
        print(
            pd.concat(
                [memory[offset + 2], memory[offset + 3], memory[offset], memory[offset + 1]],
                axis=1,
            )
        )

    # Count the integer labels found in the string-cast target frame of the first part.
    assert len(COLUMNS_TARGET_DATASET) == 1
    labels = memory[2][COLUMNS_TARGET_DATASET[0]].map(target_to_int)
    stats = labels.value_counts().to_dict()
    assert set(stats.keys()) == {0, 1}

    # Save essential statistics next to this file, one value per line.
    path_stats = os.path.abspath(
        os.path.join(os.path.dirname(__file__), f"{NAME_DATASET:s}.txt")
    )
    with open(path_stats, "w") as file:
        essentials = [
            len(memory[0].columns),
            len(memory[1].columns),
            stats[1] / stats[0],
            stats[0] + stats[1],
            stats[0],
            stats[1],
        ]
        for value in essentials:
            file.write(str(value) + "\n")
