# Import Python packages.
import importlib.util as imputil
import itertools
import os
import sys
from types import ModuleType
from typing import Any, List, Mapping, Protocol, Sequence, Tuple

# Import external packages.
import more_itertools as xitertools
import numpy as np
import pandas as pd

# Import developing library.
import fin_tech_py_toolkit as lib


# Datasets considered by the transfer experiment, grouped by size.
# Small datasets: one name per line, expanded into a list of strings.
SMALL = """
    adult
    ai4i_2020_predictive_maintenance_dataset
    bank_marketing
    blastchar
    estimation_of_obesity_levels_based_on_eating_habits_and_physical_condition
    insurance_claims
    iranian_churn_dataset
    jasmine
    online_shoppers_purchasing_intention_dataset
    qsar_biodegradation
    seismic_bumps
    shrutime
    vehicle_insurance
""".split()
# Large datasets.
LARGE = "vehicle_claims".split()


def rcimport(relpath: str, /) -> ModuleType:
    r"""
    Runtime command import.

    Load a Python source file as a module and register it in ``sys.modules``
    under its file stem. ``relpath`` is resolved against the directory of this
    file (an absolute path is used as-is).

    Args
    ----
    - relpath
        Relative path of importing module w.r.t. current file.

    Returns
    -------
    - module
        Module.

    Raises
    ------
    - ImportError
        If no module spec or loader can be created for the resolved path.
    - Exception
        Anything raised while executing the module body is propagated, and the
        ``sys.modules`` entry is rolled back.
    """
    # Resolve the module path and derive the module name from the file stem.
    path = os.path.abspath(os.path.join(os.path.dirname(__file__), relpath))
    name, _ = os.path.splitext(os.path.basename(path))
    spec = imputil.spec_from_file_location(name, path)
    # Explicit errors instead of `assert`, which is stripped under `python -O`.
    if spec is None:
        raise ImportError(f"Cannot create a module spec for {path!r}.", name=name, path=path)
    loader = spec.loader
    if loader is None:
        raise ImportError(f"No loader available for {path!r}.", name=name, path=path)
    module = imputil.module_from_spec(spec)

    # Register before executing (as in the importlib "import a source file
    # directly" recipe), but do not leave a half-initialised module cached if
    # executing its body fails.
    previous = sys.modules.get(name)
    sys.modules[name] = module
    try:
        loader.exec_module(module)
    except BaseException:
        if previous is not None:
            sys.modules[name] = previous
        else:
            sys.modules.pop(name, None)
        raise
    return module


class Instance(Protocol):
    r"""
    Callable protocol for a single ML experiment runner.

    The runners registered in ``instances`` (loaded from ``mlinstance.py``)
    follow this shape. ``supervised_learning_level_4`` also passes model
    hyperparameters to them as extra keyword arguments.
    """

    def __call__(
        self: Any,
        train: Tuple[pd.DataFrame, pd.DataFrame],
        valid: Tuple[pd.DataFrame, pd.DataFrame],
        test: Tuple[pd.DataFrame, pd.DataFrame],
        /,
        *,
        weight_pos: float = 1.0,
    ) -> Mapping[str, Any]:
        r"""
        Run one experiment and report its results.

        Args
        ----
        - train
            ``(features, labels)`` frames of the training split.
        - valid
            ``(features, labels)`` frames of the validation split.
        - test
            ``(features, labels)`` frames of the test split.
        - weight_pos
            Weight applied to the positive label.

        Returns
        -------
        - profile
            Experiment profile.
        """
        ...


# Get ML instances: model runners from `../mlinstance.py`, keyed by model name.
instance_ = rcimport(os.path.join("..", "mlinstance.py"))
instances: Mapping[str, Instance] = dict(
    xgboost=instance_.instance_xgboost,
    catboost=instance_.instance_catboost,
    lightgbm=instance_.instance_lightgbm,
)

# Hyperparameter grids per model name, from `hyperparameter_transfer.py`
# next to this file.
hyperparameter_ = rcimport("hyperparameter_transfer.py")
hyperparameters: Mapping[str, Mapping[str, Any]] = hyperparameter_.hyperparameters

# Every encoding is paired with the same dimensionality reductions; a fresh
# list is built for each encoding.
focuses: Sequence[Tuple[str, Sequence[str]]] = [
    (name_encode, ["pca", "featagglo", "identity"])
    for name_encode in (
        "cca",
        "cca-degree",
        "cca-distribute",
        "cca-range",
        "count",
        "sdv",
        "catboost",
        "discard",
    )
]

# Encodings whose train-fitted transform is reused on the test split
# instead of being refitted there.
encodes_only_once = ["catboost", "discard"]


def supervised_learning(
    root_cache: str,
    root_data: str,
    name_dataset: str,
    /,
) -> Sequence[Mapping[str, Any]]:
    r"""
    Perform supervised learning.

    Load the dataset, tag the cells of the first frame of every split with a
    string prefix, then fan out over every configured encoding, normalization,
    dimensionality reduction and model hyperparameter combination.

    Args
    ----
    - root_cache
        Cache root directory.
    - root_data
        Dataset root directory.
    - name_dataset
        Name of dataset (stem of a module file under ``../data``).

    Returns
    -------
    - profiles
        Experiment profiles, one per evaluated configuration.
    """
    # Load dataset.
    if os.path.isdir(os.path.join(root_cache, "lv0")):
        # Clean cache.
        lib.io.rmtree(os.path.join(root_cache, "lv0"))
    # Use the `name_dataset` argument, not a module-level global.
    dataset_ = rcimport(os.path.join("..", "data", f"{name_dataset:s}.py"))
    dataset = dataset_.load(os.path.join(root_cache, "lv0"), root_data, props_infer=(1, 1))
    rng = np.random.RandomState(42)

    def tag_test_cell(cell: Any) -> str:
        # Test-split tagging: cells whose character-code sum is odd get "3:".
        # Otherwise they get "1:" with probability 1/3 (seeded rng) or "2:".
        # The rng is only drawn for even-sum cells.
        text = str(cell)
        if sum(ord(char) for char in text) % 2 != 0:
            return "3:" + text
        return ("1:" if rng.randint(3) < 1 else "2:") + text

    # Four frames per split (train, valid, test); the first frame of each split
    # is tagged, the third is mapped through the dataset's `target_to_int`, and
    # the others are cast to float.
    memory = [
        *(
            dataset[0].applymap(lambda cell: "0:" + str(cell)),
            dataset[1].applymap(float),
            dataset[2].applymap(dataset_.target_to_int),
            dataset[3].applymap(float),
        ),
        *(
            dataset[4].applymap(lambda cell: "0:" + str(cell)),
            dataset[5].applymap(float),
            dataset[6].applymap(dataset_.target_to_int),
            dataset[7].applymap(float),
        ),
        *(
            dataset[8].applymap(tag_test_cell),
            dataset[9].applymap(float),
            dataset[10].applymap(dataset_.target_to_int),
            dataset[11].applymap(float),
        ),
    ]
    assert len(memory[0]) == len(memory[1]) == len(memory[2]) == len(memory[3])
    assert len(memory[4]) == len(memory[5]) == len(memory[6]) == len(memory[7])
    assert len(memory[8]) == len(memory[9]) == len(memory[10]) == len(memory[11])
    assert len(memory) == 12 and all(isinstance(slot, pd.DataFrame) for slot in memory)

    # A configuration at arbitrary level.
    (supervision,) = dataset_.COLUMNS_TARGET_DATASET
    # Expand each model's hyperparameter grid (Cartesian product of candidate
    # values) into (model name, keyword arguments) pairs.
    configs_model = [
        (model_name, dict(zip(model_kwargs.keys(), model_kwargs_values)))
        for model_name, model_kwargs in hyperparameters.items()
        for model_kwargs_values in itertools.product(*model_kwargs.values())
    ]
    return list(
        xitertools.flatten(
            [
                supervised_learning_level_1(
                    root_cache,
                    root_data,
                    name_dataset,
                    memory,
                    name_encode,
                    names_normalize=["quantile"],
                    supervision=supervision,
                    names_dimereduce=names_dimereduces,
                    configs_model=configs_model,
                )
                for name_encode, names_dimereduces in focuses
            ]
        )
    )


def supervised_learning_level_1(
    root_cache: str,
    root_data: str,
    name_dataset: str,
    memory: List[pd.DataFrame],
    name_encode: str,
    /,
    *args: Any,
    names_normalize: Sequence[str] = (),
    supervision: str = "",
    **kwargs: Any,
) -> Sequence[Mapping[str, Any]]:
    r"""
    Perform supervised learning after encoding.

    Encode all four frames of every split, then fan out over every
    normalization in ``names_normalize``.

    Args
    ----
    - root_cache
        Cache root directory.
    - root_data
        Dataset root directory.
    - name_dataset
        Name of dataset.
    - memory
        Runtime memory: 12 data frames, four per split (train, valid, test).
    - name_encode
        Name of encoding.
    - names_normalize
        Names of normalization.
    - supervision
        Supervision column, passed to the encoders as ``label``.
    - args, kwargs
        Forwarded unchanged to ``supervised_learning_level_2``.

    Returns
    -------
    - profiles
        Experiment profiles, one per evaluated configuration.
    """
    # Perform encoding.
    if os.path.isdir(os.path.join(root_cache, "lv1")):
        # Clean cache.
        lib.io.rmtree(os.path.join(root_cache, "lv1"))
    # Use the `name_dataset` argument, not a module-level global.
    print("transfer", name_dataset, name_encode)
    encode_ = rcimport(os.path.join("..", "encode.py"))
    # Separate encoders with separate caches for the train and test splits.
    encode1, encode1_args, encode1_kwargs = encode_.prepare(
        name_encode, os.path.join(root_cache, "lv1", "train"), name_dataset, label=supervision
    )
    encode2, encode2_args, encode2_kwargs = encode_.prepare(
        name_encode, os.path.join(root_cache, "lv1", "test"), name_dataset, label=supervision
    )
    memory = [
        # Train split: fit (if parametric) and transform with `encode1`.
        *(
            encode1.fit_transform(
                ([memory[0], memory[1], memory[2], memory[3]], []),
                [memory[0], memory[1], memory[2], memory[3]],
                *encode1_args,
                **encode1_kwargs,
            )
            if encode1.tags.parametric
            else encode1.transform([memory[0], memory[1], memory[2], memory[3]])
        ),
        # Validation split: transform with the train encoder.
        *encode1.transform([memory[4], memory[5], memory[6], memory[7]]),
        # Test split: encodings in `encodes_only_once` reuse the train encoder;
        # every other encoding is refitted on the test split with `encode2`.
        *(
            encode1.transform([memory[8], memory[9], memory[10], memory[11]])
            if name_encode in encodes_only_once
            else (
                encode2.fit_transform(
                    ([memory[8], memory[9], memory[10], memory[11]], []),
                    [memory[8], memory[9], memory[10], memory[11]],
                    *encode2_args,
                    **encode2_kwargs,
                )
                if encode2.tags.parametric
                else encode2.transform([memory[8], memory[9], memory[10], memory[11]])
            )
        ),
    ]
    assert len(memory[0]) == len(memory[1]) == len(memory[2]) == len(memory[3])
    assert len(memory[4]) == len(memory[5]) == len(memory[6]) == len(memory[7])
    assert len(memory[8]) == len(memory[9]) == len(memory[10]) == len(memory[11])
    assert len(memory) == 12 and all(isinstance(slot, pd.DataFrame) for slot in memory)
    return list(
        xitertools.flatten(
            [
                supervised_learning_level_2(
                    root_cache,
                    root_data,
                    name_dataset,
                    memory,
                    name_encode,
                    name_normalize,
                    *args,
                    **kwargs,
                )
                for name_normalize in names_normalize
            ]
        )
    )


def supervised_learning_level_2(
    root_cache: str,
    root_data: str,
    name_dataset: str,
    memory: List[pd.DataFrame],
    name_encode: str,
    name_normalize: str,
    /,
    *args: Any,
    names_dimereduce: Sequence[str] = (),
    **kwargs: Any,
) -> Sequence[Mapping[str, Any]]:
    r"""
    Perform supervised learning after normalization.

    Normalize the feature frames (slots 1, 5 and 9 of ``memory``), then fan
    out over every dimensionality reduction in ``names_dimereduce``.

    Args
    ----
    - root_cache
        Cache root directory.
    - root_data
        Dataset root directory.
    - name_dataset
        Name of dataset.
    - memory
        Runtime memory: 12 data frames, four per split (train, valid, test).
    - name_encode
        Name of encoding.
    - name_normalize
        Name of normalization.
    - names_dimereduce
        Names of dimensionality reduction.
    - args, kwargs
        Forwarded unchanged to ``supervised_learning_level_3``.

    Returns
    -------
    - profiles
        Experiment profiles, one per evaluated configuration.
    """
    # Perform normalization.
    if os.path.isdir(os.path.join(root_cache, "lv2")):
        # Clean cache.
        lib.io.rmtree(os.path.join(root_cache, "lv2"))
    # Use the `name_dataset` argument, not a module-level global.
    print("transfer", name_dataset, name_encode, name_normalize)
    normalize_ = rcimport(os.path.join("..", "normalize.py"))
    # Separate normalizers with separate caches for the train and test splits.
    normalize1, normalize1_args, normalize1_kwargs = normalize_.prepare(
        name_normalize, os.path.join(root_cache, "lv2", "train"), name_dataset
    )
    normalize2, normalize2_args, normalize2_kwargs = normalize_.prepare(
        name_normalize, os.path.join(root_cache, "lv2", "test"), name_dataset
    )
    memory = [
        # Train split: fit (if parametric) and transform features with `normalize1`.
        *(
            memory[0],
            *(
                normalize1.fit_transform(
                    ([memory[1]], []), [memory[1]], *normalize1_args, **normalize1_kwargs
                )
                if normalize1.tags.parametric
                else normalize1.transform([memory[1]])
            ),
            memory[2],
            memory[3],
        ),
        # Validation split: transform features with the train normalizer.
        *(memory[4], *normalize1.transform([memory[5]]), memory[6], memory[7]),
        # Test split: reuse the train normalizer for `encodes_only_once`,
        # otherwise refit on the test features with `normalize2`.
        *(
            memory[8],
            *(
                normalize1.transform([memory[9]])
                if name_encode in encodes_only_once
                else (
                    normalize2.fit_transform(
                        ([memory[9]], []), [memory[9]], *normalize2_args, **normalize2_kwargs
                    )
                    if normalize2.tags.parametric
                    else normalize2.transform([memory[9]])
                )
            ),
            memory[10],
            memory[11],
        ),
    ]
    assert len(memory[0]) == len(memory[1]) == len(memory[2]) == len(memory[3])
    assert len(memory[4]) == len(memory[5]) == len(memory[6]) == len(memory[7])
    assert len(memory[8]) == len(memory[9]) == len(memory[10]) == len(memory[11])
    assert len(memory) == 12 and all(isinstance(slot, pd.DataFrame) for slot in memory)
    return list(
        xitertools.flatten(
            [
                supervised_learning_level_3(
                    root_cache,
                    root_data,
                    name_dataset,
                    memory,
                    name_encode,
                    name_normalize,
                    name_dimereduce,
                    *args,
                    **kwargs,
                )
                for name_dimereduce in names_dimereduce
            ]
        )
    )


def supervised_learning_level_3(
    root_cache: str,
    root_data: str,
    name_dataset: str,
    memory: List[pd.DataFrame],
    name_encode: str,
    name_normalize: str,
    name_dimereduce: str,
    /,
    *args: Any,
    configs_model: Sequence[Tuple[str, Mapping[str, Any]]] = (),
    **kwargs: Any,
) -> Sequence[Mapping[str, Any]]:
    r"""
    Perform supervised learning after dimensionality reduction.

    Reduce the feature frames (slots 1, 5 and 9 of ``memory``), then run one
    experiment per model configuration in ``configs_model``.

    Args
    ----
    - root_cache
        Cache root directory.
    - root_data
        Dataset root directory.
    - name_dataset
        Name of dataset.
    - memory
        Runtime memory: 12 data frames, four per split (train, valid, test).
    - name_encode
        Name of encoding.
    - name_normalize
        Name of normalization.
    - name_dimereduce
        Name of dimensionality reduction.
    - configs_model
        Configurations of models as ``(model name, keyword arguments)`` pairs.
    - args, kwargs
        Forwarded unchanged to ``supervised_learning_level_4``.

    Returns
    -------
    - profiles
        Experiment profiles, one per model configuration. The list is empty
        when a non-identity reduction is skipped.
    """
    # Perform dimensionality reduction.
    if os.path.isdir(os.path.join(root_cache, "lv3")):
        # Clean cache.
        lib.io.rmtree(os.path.join(root_cache, "lv3"))
    # Use the `name_dataset` argument, not a module-level global.
    print("transfer", name_dataset, name_encode, name_normalize, name_dimereduce)
    dimereduce_ = rcimport(os.path.join("..", "dimereduce.py"))
    if len(memory[1].columns) < dimereduce_.MAX_DIMS and name_dimereduce != "identity":
        # If features are not large, there is no need for dimensionality reduction.
        return []
    dimereduce, dimereduce_args, dimereduce_kwargs = dimereduce_.prepare(
        name_dimereduce, os.path.join(root_cache, "lv3"), name_dataset
    )
    # Fit on the train features only; validation and test features reuse
    # the train-fitted reducer.
    memory = [
        *(
            memory[0],
            *(
                dimereduce.fit_transform(
                    ([memory[1]], []), [memory[1]], *dimereduce_args, **dimereduce_kwargs
                )
                if dimereduce.tags.parametric
                else dimereduce.transform([memory[1]])
            ),
            memory[2],
            memory[3],
        ),
        *(memory[4], *dimereduce.transform([memory[5]]), memory[6], memory[7]),
        *(memory[8], *dimereduce.transform([memory[9]]), memory[10], memory[11]),
    ]
    assert len(memory[0]) == len(memory[1]) == len(memory[2]) == len(memory[3])
    assert len(memory[4]) == len(memory[5]) == len(memory[6]) == len(memory[7])
    assert len(memory[8]) == len(memory[9]) == len(memory[10]) == len(memory[11])
    assert len(memory) == 12 and all(isinstance(slot, pd.DataFrame) for slot in memory)
    return [
        supervised_learning_level_4(
            root_cache,
            root_data,
            name_dataset,
            memory,
            name_encode,
            name_normalize,
            name_dimereduce,
            config_model,
            *args,
            **kwargs,
        )
        for config_model in configs_model
    ]


def supervised_learning_level_4(
    root_cache: str,
    root_data: str,
    name_dataset: str,
    memory: List[pd.DataFrame],
    name_encode: str,
    name_normalize: str,
    name_dimereduce: str,
    config_model: Tuple[str, Mapping[str, Any]],
    /,
    *args: Any,
    **kwargs: Any,
) -> Mapping[str, Any]:
    r"""
    Run a single model configuration and build its profile.

    Args
    ----
    - root_cache
        Cache root directory.
    - root_data
        Dataset root directory.
    - name_dataset
        Name of dataset.
    - memory
        Runtime memory: 12 data frames, four per split (train, valid, test).
        Slots 1/2, 5/6 and 9/10 are passed to the model as
        ``(features, labels)``.
    - name_encode
        Name of encoding.
    - name_normalize
        Name of normalization.
    - name_dimereduce
        Name of dimensionality reduction.
    - config_model
        Configuration of model as ``(model name, keyword arguments)``.

    Returns
    -------
    - profile
        Experiment profile: pipeline identifiers, model keyword arguments and
        the metrics returned by the model runner.
    """
    # Collect label imbalance. The reshape also enforces a single label column.
    train_labels = np.reshape(memory[2].values, (len(memory[2]),))
    num_negatives = int(np.sum(train_labels == 0))
    num_positives = int(np.sum(train_labels == 1))
    # Weight positives by the negative/positive ratio. Fall back to the neutral
    # weight (the protocol default 1.0) when the training split has no positive
    # label, instead of failing with ZeroDivisionError.
    weight_pos = float(num_negatives) / float(num_positives) if num_positives > 0 else 1.0

    # Run experiments.
    model_name, model_kwargs = config_model
    print("+", model_name, model_kwargs)
    profile = instances[model_name](
        (memory[1], memory[2]),
        (memory[5], memory[6]),
        (memory[9], memory[10]),
        weight_pos=weight_pos,
        **model_kwargs,
    )

    # Construct final profile. Model arguments and reported metrics must not
    # collide, otherwise one would silently overwrite the other.
    assert len(set(model_kwargs) & set(profile)) == 0
    return {
        "task": "transfer",
        "dataset": name_dataset,
        "encode": name_encode,
        "normalize": name_normalize,
        "dimereduce": name_dimereduce,
        "model": model_name,
        **model_kwargs,
        **profile,
    }


# Main program.
if __name__ == "__main__":
    # Validate command-line arguments before doing any work. Use explicit exits
    # rather than `assert`, which is stripped under `python -O`.
    if len(sys.argv) != 2:
        sys.exit("Usage: python transfer.py <name_dataset>")
    _, name_dataset = sys.argv
    if name_dataset not in SMALL + LARGE:
        sys.exit(
            f"Unknown dataset: {name_dataset:s}. "
            f"Expected one of: {', '.join(SMALL + LARGE):s}."
        )

    # Output memory.
    root = rcimport(os.path.join("..", "root.py"))
    profile = pd.DataFrame(
        supervised_learning(
            os.path.join(root.ROOT, "cache"), os.path.join(root.ROOT, "data"), name_dataset
        )
    )
    lib.io.mkdirs(os.path.join(root.ROOT, "profile"))
    profile.to_csv(
        os.path.join(root.ROOT, "profile", f"transfer_{name_dataset:s}.csv"), index=False
    )
