# Import Python packages.
import functools
import importlib.util as imputil
import os
import shutil
import sys
from types import ModuleType
from typing import Any, List, Mapping, Optional, Sequence, Tuple

# Import external packages.
import pandas as pd

# Import developing packages.
import fin_tech_py_toolkit as lib


# Runtime constants.
NAME_DATASET = "albert"

# Column layout: 78 feature columns named "0" ... "77", followed by a single target column.
COLUMNS_FEATURE_DATASET = list(map(str, range(78)))
COLUMNS_TARGET_DATASET = ["label"]
SORT_COLUMNS_DATASET = f"supervised-{NAME_DATASET:s}"

# Storage layout: one named file, and the reading arguments attached to that name.
ADDRESSES = [("full", os.path.join(NAME_DATASET, "albert_exp.data"))]
READ_ARGS: Sequence[Tuple[Sequence[str], Sequence[Any]]] = [(["full"], [])]
READ_KWARGS: Sequence[Tuple[Sequence[str], Mapping[str, Any]]] = [
    (
        ["full"],
        {"names": COLUMNS_FEATURE_DATASET + COLUMNS_TARGET_DATASET, "header": None},
    )
]
SORTS = (SORT_COLUMNS_DATASET, "rankable")

# Split proportions: the inference split separates (train + valid) from test, and the learning
# split then separates train from valid.
PROP_TRAIN, PROP_VALID, PROP_TEST = 7, 1, 2
PROPS_INFER = [(["full"], (PROP_TRAIN + PROP_VALID, PROP_TEST))]
PROPS_LEARN = [(["full"], (PROP_TRAIN, PROP_VALID))]

# Categorical columns: the target plus the feature columns whose index falls into one of the
# inclusive spans below.
COLUMNS_CATEGORICAL = ["label"] + [
    str(index)
    for first, last in (
        (13, 38),
        (40, 40),
        (43, 48),
        (53, 57),
        (59, 62),
        (64, 65),
        (67, 67),
        (69, 70),
        (72, 73),
        (75, 77),
    )
    for index in range(first, last + 1)
]
TARGET_TO_INT = {str(label): label for label in (0, 1)}


def target_to_int(cell: Any, /) -> int:
    r"""
    Translate a target cell into its integer class label.

    Args
    ----
    - cell
        A categorical cell value.
        It is converted to a string before being looked up in the label mapping, and a value
        whose string form is not a key of the mapping raises a `KeyError`.

    Returns
    -------
    - label
        An integer label.
    """
    # The mapping is keyed by strings, so stringify the cell before the lookup.
    key = str(cell)
    return TARGET_TO_INT[key]


# Register tabular data disambiguation sorting algorithms.
# The dataset-specific column sort is the registered alphabetic column sort with its `groups`
# argument fixed in advance: the target columns (when there are any) come first, followed by
# `None` (presumably a placeholder for all remaining columns; confirm against the library).
_sort_columns_alphabetic = lib.data.DataTabular.get_sort("columns", "alphabetic")
_groups_columns = [None] if not COLUMNS_TARGET_DATASET else [COLUMNS_TARGET_DATASET, None]
lib.data.DataTabular.register_sort(
    functools.partial(_sort_columns_alphabetic, groups=_groups_columns),
    "columns",
    SORT_COLUMNS_DATASET,
)


def rcimport(relpath: str, /) -> ModuleType:
    r"""
    Runtime command import.

    The module is loaded from a file path, registered in `sys.modules` under the file name
    without its extension, and executed.
    If the execution fails, the registration is rolled back so that no half-initialized module
    stays importable.

    Args
    ----
    - relpath
        Relative path of importing module w.r.t. current file.

    Returns
    -------
    - module
        Module.

    Raises
    ------
    - ImportError
        No import spec or loader can be built for the path (e.g., unsupported file extension).
    """
    # Resolve the path w.r.t. the directory of this file, and name the module after its file.
    path = os.path.abspath(os.path.join(os.path.dirname(__file__), relpath))
    name, _ = os.path.splitext(os.path.basename(path))

    # Build the module from its file location.
    # Explicit errors are used rather than assertions since assertions are stripped under `-O`.
    spec = imputil.spec_from_file_location(name, path)
    if spec is None:
        raise ImportError(f"Cannot build an import spec for {path!r}.", name=name, path=path)
    loader = spec.loader
    if loader is None:
        raise ImportError(f"Import spec of {path!r} has no loader.", name=name, path=path)
    module = imputil.module_from_spec(spec)

    # Register before execution so that the module is visible by name while its body runs, but
    # restore the previous registry state if the execution fails.
    missing = object()
    previous = sys.modules.get(name, missing)
    sys.modules[name] = module
    try:
        loader.exec_module(module)
    except BaseException:
        if previous is missing:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = previous
        raise
    return module


def _map_cells(frame: pd.DataFrame, func: Any, /) -> pd.DataFrame:
    r"""
    Apply a function to every cell of a data frame.

    Args
    ----
    - frame
        Data frame to be converted.
    - func
        Function applied to each cell.

    Returns
    -------
    - frame
        Converted data frame.
    """
    # `DataFrame.applymap` is deprecated since pandas 2.1 in favor of `DataFrame.map`.
    # The check is done on the class since attribute access on an instance may resolve to a
    # column named "map".
    if hasattr(pd.DataFrame, "map"):
        return frame.map(func)
    return frame.applymap(func)


def _split_features_targets(frame: pd.DataFrame, /) -> Tuple[pd.DataFrame, pd.DataFrame]:
    r"""
    Split a data frame into its feature columns and its target columns.

    Args
    ----
    - frame
        Data frame to be split.

    Returns
    -------
    - features
        Columns of the frame which are not target columns, in sorted order.
    - targets
        Columns of the frame which are target columns, in sorted order.
    """
    # Set operations drop the column order, thus both selections are sorted explicitly.
    names = set(frame.columns)
    names_target = set(COLUMNS_TARGET_DATASET)
    return frame[sorted(names - names_target)], frame[sorted(names & names_target)]


def load(
    root_cache: str, root_data: str, /, props_infer: Optional[Tuple[int, int]] = None
) -> List[pd.DataFrame]:
    r"""
    Load dataset into memory.

    Pay attention that the cache root directory is removed entirely before loading if it exists.

    Args
    ----
    - root_cache
        Cache root directory.
    - root_data
        Dataset root directory.
    - props_infer
        Inference split proportions.
        If it is not given (or empty), the default proportions `PROPS_INFER` are used.

    Returns
    -------
    - memory
        Dataset in memory.
        It holds 12 data frames: for each of the training, validation and test splits (in this
        order), there are four consecutive frames being string-cell features, float-cell
        features, string-cell targets and float-cell targets.
    """
    # Clean cache directories.
    if os.path.isdir(root_cache):
        # Each experiment cache should be independent, thus all existing content in the cache should
        # be totally removed.
        shutil.rmtree(root_cache)

    # Load dataset.
    dataset = lib.datasets.DatasetTabularSimple.from_storage(
        [(name, os.path.join(root_data, relpath)) for name, relpath in ADDRESSES],
        cache_prefix=root_cache,
        cache_suffix=NAME_DATASET,
        link=True,
        cache_read=False,
        read_args=READ_ARGS,
        read_kwargs=READ_KWARGS,
        sorts=SORTS,
    )

    # Memory formalization.
    # Split off the test part first, then split the remaining part into training and validation.
    split = lib.transdatasets.TransdatasetSplitTabular(
        cache_prefix=root_cache, cache_suffix=NAME_DATASET, allow_alias=False
    )
    dataset_tune, dataset_test = split.fit_transform(
        ([dataset], []),
        [dataset],
        [],
        props=[(names, props_infer) for names, _ in PROPS_INFER] if props_infer else PROPS_INFER,
    )
    dataset_train, dataset_valid = split.transform([dataset_tune], [], props=PROPS_LEARN)

    # Collect each split into a single data frame.
    # The transformation is fitted on the training split only.
    unravel = lib.transdatasets.TransdatasetUnravelTabular(
        cache_prefix=root_cache, cache_suffix=NAME_DATASET, allow_alias=True
    )
    frame_train = pd.concat(
        unravel.fit_transform(([dataset_train], []), [dataset_train], []), axis=0
    )
    frame_valid = pd.concat(unravel.transform([dataset_valid], []), axis=0)
    frame_test = pd.concat(unravel.transform([dataset_test], []), axis=0)

    # Tabularize each split.
    # The transformation is fitted on the training split only, and two frames are expected per
    # split.
    tabularize = lib.transforms.TransformTabularize(
        cache_prefix=root_cache, cache_suffix=NAME_DATASET, allow_alias=True
    )
    frames = [
        *tabularize.fit_transform(
            ([frame_train], []), [frame_train], [], discretizable=COLUMNS_CATEGORICAL
        ),
        *tabularize.transform([frame_valid], []),
        *tabularize.transform([frame_test], []),
    ]
    assert len(frames) == 6

    # For each split, the first frame is converted into string cells and the second frame into
    # float cells, then both are separated into features and targets.
    memory = []
    for index in (0, 2, 4):
        frame_str = _map_cells(frames[index], str)
        frame_float = _map_cells(frames[index + 1], float)
        features_str, targets_str = _split_features_targets(frame_str)
        features_float, targets_float = _split_features_targets(frame_float)
        memory.extend([features_str, features_float, targets_str, targets_float])
    return memory


# Main program.
if __name__ == "__main__":
    # Output memory.
    root = rcimport("../root.py")
    dir_dataset = os.path.join(root.ROOT, "data", NAME_DATASET)
    path_exp = os.path.join(dir_dataset, "albert_exp.data")
    if not os.path.isfile(path_exp):
        # The raw dataset needs further processing before the experiment file exists.
        # Only training data is provided with external labels, thus we will ignore validation and
        # test data in raw dataset for experiment.
        features = pd.read_csv(
            os.path.join(dir_dataset, "albert_train.data"),
            sep=r"\s+",
            names=COLUMNS_FEATURE_DATASET + COLUMNS_TARGET_DATASET,
            header=None,
        )
        labels = pd.read_csv(
            os.path.join(dir_dataset, "albert_train.solution"),
            sep=r"\s+",
            names=COLUMNS_TARGET_DATASET,
            header=None,
        )

        # Attach the external labels as the target column, and save without index and header.
        features["label"] = labels["label"]
        features.to_csv(path_exp, index=False, header=False)
    memory = load(os.path.join(root.ROOT, "cache"), os.path.join(root.ROOT, "data"))

    # Each split takes four consecutive frames in memory.
    # For display, the two target frames are put before the two feature frames.
    for offset in (0, 4, 8):
        print(
            pd.concat(
                [memory[offset + 2], memory[offset + 3], memory[offset], memory[offset + 1]],
                axis=1,
            )
        )

    # Count integer labels over the string-cell target frame of the first split.
    assert len(COLUMNS_TARGET_DATASET) == 1
    counts = memory[2][COLUMNS_TARGET_DATASET[0]].map(target_to_int).value_counts()
    stats = counts.to_dict()
    assert set(stats.keys()) == {0, 1}
    path_stats = os.path.abspath(os.path.join(os.path.dirname(__file__), f"{NAME_DATASET:s}.txt"))
    with open(path_stats, "w") as file:
        # Save essential statistics, one per line: ratio of label 1 count to label 0 count,
        # total count, label 0 count, and label 1 count.
        for value in (stats[1] / stats[0], stats[0] + stats[1], stats[0], stats[1]):
            file.write(str(value) + "\n")
