import argparse
from typing import Tuple

import numpy as np
import pandas as pd
from src.training.training_utils import (
    fit_classifier_CV,
    fit_classifier_holdout,
    fit_classifier_random,
)


def run_classifiers(
    dataset: str,
    embedding_types: Tuple[str, ...],
    split_types: Tuple[str, ...],
    target: str,
    n_partitions: int,
):
    """Fit classifiers for every embedding/split combination and save one CSV.

    For each embedding type and each split strategy, the matching
    ``fit_classifier_*`` helper is called and its result collected. The
    combined results are sorted by the ``embedding`` and ``split_type``
    columns, re-indexed from 0, and written to
    ``results/{dataset}/baseline_classification[_high]_{dataset}.csv``.

    Args:
        dataset: Dataset name; used in the helper calls and the output path.
        embedding_types: Embedding identifiers, forwarded to the helpers.
        split_types: Split strategies; each must be ``"random"``, ``"CV"``
            or ``"holdout"``.
        target: ``"target_class"`` or ``"target_class_2"``; selects the
            output file name and is forwarded to the helpers.
        n_partitions: Forwarded to ``fit_classifier_CV`` only.

    Raises:
        ValueError: If ``target`` is not recognised, if any entry of
            ``split_types`` is not recognised, or if ``embedding_types`` or
            ``split_types`` is empty. All checks run before any fitting.
    """
    print(f"Fitting classifiers to {dataset} dataset.")
    # Run parameters
    seeds = [0, 1, 2, 3, 4]
    eve_suffixes = ("0", "1", "2")
    threads = 20

    if target == "target_class":
        out_path = f"results/{dataset}/baseline_classification_{dataset}.csv"
    elif target == "target_class_2":
        out_path = f"results/{dataset}/baseline_classification_high_{dataset}.csv"
    else:
        raise ValueError(
            f"Unknown target {target!r}; expected 'target_class' or 'target_class_2'."
        )

    # Materialise the inputs so they can be validated here and still be
    # iterated (repeatedly, for the splits) in the nested loop below.
    embedding_types = tuple(embedding_types)
    split_types = tuple(split_types)

    # Fail fast: reject bad split names before any fitting starts, rather
    # than after earlier combinations have already been fitted and would
    # be discarded unsaved.
    known_splits = ("random", "CV", "holdout")
    unknown_splits = [s for s in split_types if s not in known_splits]
    if unknown_splits:
        raise ValueError(
            f"Unknown split type(s) {unknown_splits!r}; expected one of {known_splits!r}."
        )
    if not embedding_types or not split_types:
        raise ValueError(
            "At least one embedding type and one split type are required."
        )

    # NOTE(review): the helpers are called with different positional orders
    # (`random` takes target before embedding_type; `CV`/`holdout` take
    # embedding_type first). Kept exactly as-is -- confirm against the
    # signatures in src.training.training_utils.
    frames = []
    # Iterate through all embedding types sequentially
    for embedding_type in embedding_types:
        print(f"Fitting classifiers to {embedding_type}.")
        # Iterate through split strategies
        for split in split_types:
            print(f"- Using split: {split}")
            if split == "random":
                df_embed = fit_classifier_random(
                    dataset, target, embedding_type, eve_suffixes, seeds, threads
                )
            elif split == "CV":
                df_embed = fit_classifier_CV(
                    dataset, embedding_type, target, n_partitions, eve_suffixes, threads
                )
            else:  # "holdout" -- guaranteed by the validation above
                df_embed = fit_classifier_holdout(
                    dataset, embedding_type, split, target, eve_suffixes, threads
                )
            frames.append(df_embed)

    # Concatenate once instead of growing a DataFrame inside the loop.
    df = pd.concat(frames)
    df = df.sort_values(by=["embedding", "split_type"])
    df = df.reset_index(drop=True)
    # Replace None entries in the suffix column with NaN before writing.
    df["suffix"] = df["suffix"].replace({None: np.nan})
    df.to_csv(out_path, index_label="index")


def main():
    """Command-line entry point.

    Reads two positional arguments, ``dataset`` and ``target``, and runs
    ``run_classifiers`` with a fixed set of embedding types, the ``CV`` and
    ``holdout`` split strategies, and three partitions.
    """
    cli = argparse.ArgumentParser()
    for positional in ("dataset", "target"):
        cli.add_argument(positional, type=str)
    parsed = cli.parse_args()

    # Fixed experiment configuration.
    embeddings = (
        "ONEHOT (MSA)",
        "ESM-1B",
        "ESM-2",
        "ESM-IF1",
        "AF2",
        "EVE (z)",
    )
    splits = ("CV", "holdout")
    partition_count = 3

    run_classifiers(
        dataset=parsed.dataset,
        embedding_types=embeddings,
        split_types=splits,
        target=parsed.target,
        n_partitions=partition_count,
    )

    print("Finished.")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
