import argparse
import sys
from typing import Tuple

import numpy as np
import pandas as pd
from src.training.training_utils import (
    fit_regressor_CV,
    fit_regressor_holdout,
    fit_regressor_random,
)


def run_regressors(
    dataset: str,
    embedding_types: Tuple[str, ...],
    split_types: Tuple[str, ...],
    n_partitions: int,
    active: bool,
):
    """Fit baseline regressors for every embedding/split combination and save them.

    For each embedding in ``embedding_types`` and each split strategy in
    ``split_types``, the matching ``fit_regressor_*`` helper is called. The
    returned DataFrames are concatenated, sorted by ``embedding`` and
    ``split_type``, and written to
    ``results/{dataset}/baseline_regression[_active]_{dataset}.csv``.

    Args:
        dataset: Dataset name; forwarded to the fitters and used in the output path.
        embedding_types: Embedding names to fit regressors on.
        split_types: Split strategies; each must be ``"random"``, ``"CV"`` or
            ``"holdout"``.
        n_partitions: Number of partitions forwarded to ``fit_regressor_CV``.
        active: Forwarded to every fitter; also selects the ``_active`` output file.

    Raises:
        ValueError: If ``split_types`` contains an unknown split name. This is
            checked before any fitting starts, so no compute is wasted.
    """
    import os

    print(f"Fitting regressors to {dataset} dataset.")
    # Run parameters
    seeds = [0, 1, 2, 3, 4]
    eve_suffixes = ("0", "1", "2")
    threads = 20
    file_suffix = "_active" if active else ""
    out_path = f"results/{dataset}/baseline_regression{file_suffix}_{dataset}.csv"

    # Dispatch table: split name -> callable taking the embedding type. The
    # helpers have different signatures, so each entry adapts its own arguments.
    fitters = {
        "random": lambda emb: fit_regressor_random(
            dataset, emb, eve_suffixes, seeds, threads, active
        ),
        "CV": lambda emb: fit_regressor_CV(
            dataset, emb, n_partitions, eve_suffixes, threads, active
        ),
        "holdout": lambda emb: fit_regressor_holdout(
            dataset, emb, "holdout", eve_suffixes, threads, active
        ),
    }

    # Fail fast, before any (potentially long) fitting runs.
    unknown = [split for split in split_types if split not in fitters]
    if unknown:
        raise ValueError(
            f"Unknown split type(s) {unknown}; expected one of {list(fitters)}."
        )

    frames = []
    # Iterate through all embedding types sequentially
    for embedding_type in embedding_types:
        print(f"Fitting regressors to {embedding_type}.")
        # Iterate through split strategies
        for split in split_types:
            print(f"- Using split: {split}")
            frames.append(fitters[split](embedding_type))

    # Concatenate once rather than growing a DataFrame in the loop. Repeated
    # concat copies everything each iteration, and concatenating with an empty
    # frame is deprecated in recent pandas.
    df = pd.concat(frames) if frames else pd.DataFrame()
    df = df.sort_values(by=["embedding", "split_type"]).reset_index(drop=True)
    df["suffix"] = df["suffix"].replace({None: np.nan})
    # Make sure results/{dataset}/ exists so to_csv does not fail.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    df.to_csv(out_path, index_label="index")


def main(argv=None):
    """Parse command-line arguments and fit the baseline regressors.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) parses ``sys.argv[1:]``, so ``main()`` behaves as a
            normal command-line entry point.

    Exits with status 0 without fitting anything when ``--active`` is given
    for a dataset other than ``"cm"``.
    """
    parser = argparse.ArgumentParser(
        description="Fit baseline regressors for a dataset and write the results to CSV."
    )
    parser.add_argument("dataset", type=str, help="Name of the dataset to fit.")
    parser.add_argument(
        "--active",
        action="store_true",
        help=(
            "Pass active=True to the fitters and write the '_active' results "
            "file (only run for the 'cm' dataset)."
        ),
    )
    args = parser.parse_args(argv)
    dataset = args.dataset
    active = args.active

    # This script only runs the --active mode for "cm". For any other dataset,
    # exit cleanly (status 0) and say why, instead of silently doing nothing.
    if active and dataset not in ("cm",):
        print(
            f"--active is only run for dataset 'cm'; skipping '{dataset}'.",
            file=sys.stderr,
        )
        sys.exit(0)

    embedding_types = (
        "ONEHOT (MSA)",
        "ESM-1B",
        "ESM-2",
        "ESM-IF1",
        "AF2",
        "EVE (z)",
    )
    n_partitions = 3

    split_types = ("CV", "holdout", "random")

    run_regressors(dataset, embedding_types, split_types, n_partitions, active)

    print("Finished.")


if __name__ == "__main__":
    main()
