from typing import Tuple, Union

import pandas as pd
from src.visualization.visualization_funcs import show_test_results


def show_all_results(
    dataset: str,
    task: str,
    target: str,
    active: bool,
    split_types: Tuple[str, ...],
    metric: Union[None, str] = None,
) -> None:
    """Show test results for every requested split of one dataset/task.

    For each entry in ``split_types`` this calls ``show_test_results`` once,
    using a fixed set of embedding types and a model set chosen by ``task``.

    Args:
        dataset: Dataset identifier, passed through to ``show_test_results``.
        task: ``"regression"`` selects the regression setup (models KNN,
            Ridge, RandomForest; default metric ``"spearman"``). Any other
            value selects the classification setup (models KNN, LogReg,
            RandomForest; default metric ``"mcc"``).
        target: One of ``"target_reg"``, ``"target_class"`` or
            ``"target_class_2"``.
        active: Passed through for regression. For any non-regression task
            it is forced to ``False`` (the caller's value is ignored).
        split_types: Split names to display, one ``show_test_results`` call
            each. A single bare string is treated as one split name rather
            than being iterated character by character.
        metric: Metric name; ``None`` selects the task-specific default.

    Raises:
        ValueError: If ``target`` is not one of the supported names.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    valid_targets = ("target_reg", "target_class", "target_class_2")
    if target not in valid_targets:
        raise ValueError(
            f"Unsupported target {target!r}; expected one of {valid_targets}."
        )

    # Both tasks currently use the same embeddings. "ONEHOT" (and, for
    # regression, "EVE (ELBO)") were disabled in earlier versions.
    embedding_types = (
        "ONEHOT (MSA)",
        "ESM-2",
        "ESM-1B",
        "ESM-IF1",
        "EVE (z)",
        "AF2",
    )

    if task == "regression":
        regressors = ("KNN", "Ridge", "RandomForest")
        default_metric = "spearman"
    else:
        # Non-regression tasks never use the active flag.
        active = False
        regressors = ("KNN", "LogReg", "RandomForest")
        default_metric = "mcc"

    if metric is None:
        metric = default_metric

    # Guard against a bare string such as "CV", which would otherwise be
    # iterated character by character ("C", "V").
    if isinstance(split_types, str):
        split_types = (split_types,)

    for split in split_types:
        show_test_results(
            dataset,
            task,
            regressors,
            embedding_types,
            metric,
            split,
            active,
            target,
        )


def main():
    """Run the result display for the currently selected datasets.

    Every selected dataset gets regression results (target ``target_reg``,
    metric ``"spearman"``) for the "CV" and "holdout" splits. Datasets in
    ``active_datasets`` ("cm" and "gh1") additionally get a CV-only
    regression run with ``active=True``.
    """
    # Previously also run on "cm" and "ppat".
    selected = ["tim"]
    splits = ("CV", "holdout")
    active_datasets = ["cm", "gh1"]

    regression_task = "regression"
    regression_target = "target_reg"
    score = "spearman"

    for name in selected:
        show_all_results(
            name, regression_task, regression_target, False, splits, metric=score
        )
        if name not in active_datasets:
            continue
        show_all_results(
            name, regression_task, regression_target, True, ("CV",), metric=score
        )


if __name__ == "__main__":
    # Script entry point: run the full pipeline, then report completion.
    main()
    print("Finished.")
