from itertools import permutations
from math import factorial
from typing import Tuple

import numpy as np
import pandas as pd
from scipy.stats import spearmanr
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.metrics import matthews_corrcoef, mean_squared_error, roc_auc_score
from sklearn.model_selection import GridSearchCV, PredefinedSplit
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
from sklearn.pipeline import Pipeline
from src.data.data_utils import extract_all_embeddings, extract_holdout_embeddings


def fit_regressor_random(
    dataset: str,
    embedding_type: str,
    eve_suffixes: Tuple[str, ...],
    seeds: Tuple[int, ...],
    threads: int = 20,
    active: bool = False,
):
    """Benchmark regressors on seeded random train/validation/test splits.

    For every EVE suffix (or once, with ``suffix=None``, for non-EVE
    embeddings) and every seed, the observations are shuffled and cut into
    train (first 50%), validation (next 25%) and test (remainder).  Each model
    family from ``setup_CV_grids("regression")`` is tuned by a grid search that
    is scored on the validation part only; the selected model is then refit on
    the training part alone and evaluated on all three parts.

    When ``embedding_type == "EVE (ELBO)"`` no model is fitted: the embedding
    values themselves are handed to ``compute_metrics`` as predictions under
    the "unsupervised" task.

    Args:
        dataset: Dataset identifier forwarded to ``extract_all_embeddings``.
        embedding_type: Name of the representation. Types starting with
            ``"EVE"`` are evaluated once per entry of ``eve_suffixes``.
        eve_suffixes: Suffixes identifying EVE variants; ignored for non-EVE
            embeddings.
        seeds: Seeds for the random permutations, one experiment per seed.
        threads: ``n_jobs`` passed to ``GridSearchCV``.
        active: Forwarded to ``extract_all_embeddings``.

    Returns:
        A ``pandas.DataFrame`` with one row per (suffix, seed, model) holding
        the model name, seed, suffix, train/val/test Spearman correlation and
        MSE, plus constant ``embedding`` and ``split_type`` ("random") columns.
    """
    task = "regression"
    target = "target_reg"
    config = setup_CV_grids(task)

    # Only EVE embeddings come in several variants; everything else is
    # loaded exactly once without a suffix.
    suffixes = list(eve_suffixes) if embedding_type.startswith("EVE") else [None]

    # Rows are accumulated in lists so the number of results can never
    # disagree with the size of a pre-allocated buffer.
    model_lst, seed_lst, suffix_lst = [], [], []
    spear_rows, mse_rows = [], []

    for suffix in suffixes:
        embeddings, y, _ = extract_all_embeddings(
            dataset=dataset,
            embedding_type=embedding_type,
            suffix=suffix,
            target=target,
            active=active,
        )
        # Fixed split sizes; the test set receives whatever is left over
        n_obs = len(y)
        n_train = int(n_obs * 0.5)
        n_val = int(n_obs * 0.25)

        # -1 keeps a sample in training for every split, 0 marks the single
        # validation fold, so the grid search scores on validation data only.
        cv = PredefinedSplit(test_fold=np.repeat([-1, 0], [n_train, n_val]))

        for seed in seeds:
            np.random.seed(seed)
            perm = np.random.permutation(n_obs)
            train_idx = perm[:n_train]
            val_idx = perm[n_train : n_train + n_val]
            test_idx = perm[n_train + n_val :]

            embedding_train, y_train = embeddings[train_idx], y[train_idx]
            embedding_val, y_val = embeddings[val_idx], y[val_idx]
            embedding_test, y_test = embeddings[test_idx], y[test_idx]

            if embedding_type == "EVE (ELBO)":
                # The embedding is used directly as the prediction
                corr, mse = compute_metrics(
                    "unsupervised",
                    embedding_train,
                    embedding_val,
                    embedding_test,
                    y_train,
                    y_val,
                    y_test,
                )
                model_lst.append("")
                seed_lst.append(seed)
                suffix_lst.append(suffix)
                spear_rows.append(corr)
                mse_rows.append(mse)
                continue

            # Train + validation stacked in the order PredefinedSplit expects
            x_search = np.concatenate((embedding_train, embedding_val))
            y_search = np.concatenate((y_train, y_val))

            for model_str, param_grid in zip(
                config["representations"], config["param_grids"]
            ):
                grid = GridSearchCV(
                    estimator=config["pipe"],
                    param_grid=param_grid,
                    scoring=config["scoring"],
                    verbose=0,
                    n_jobs=threads,
                    cv=cv,
                )
                grid.fit(x_search, y_search)

                # Refit the selected model on the training part only
                model = grid.best_estimator_["model"].fit(embedding_train, y_train)

                corr, mse = compute_metrics(
                    task,
                    model.predict(embedding_train),
                    model.predict(embedding_val),
                    model.predict(embedding_test),
                    y_train,
                    y_val,
                    y_test,
                )
                model_lst.append(model_str)
                seed_lst.append(seed)
                suffix_lst.append(suffix)
                spear_rows.append(corr)
                mse_rows.append(mse)

    # reshape keeps the (n, 3) layout even when no experiment was run
    spear_arr = np.asarray(spear_rows, dtype=float).reshape(-1, 3)
    mse_arr = np.asarray(mse_rows, dtype=float).reshape(-1, 3)

    df_results = pd.DataFrame(
        {
            "model": model_lst,
            "seed": np.asarray(seed_lst, dtype=int),
            "suffix": suffix_lst,
            "train_spearman": spear_arr[:, 0],
            "val_spearman": spear_arr[:, 1],
            "test_spearman": spear_arr[:, 2],
            "train_mse": mse_arr[:, 0],
            "val_mse": mse_arr[:, 1],
            "test_mse": mse_arr[:, 2],
        }
    )
    df_results["embedding"] = embedding_type
    df_results["split_type"] = "random"

    return df_results


def fit_regressor_holdout(
    dataset: str,
    embedding_type: str,
    split_key: str,
    eve_suffixes: Tuple[str, ...],
    threads: int = 20,
    active: bool = False,
):
    """Benchmark regressors on a predefined train/validation/test hold-out split.

    The split is provided by ``extract_holdout_embeddings`` for the given
    ``split_key``.  Each model family from ``setup_CV_grids("regression")`` is
    tuned by a grid search scored on the validation part only; the selected
    model is refit on the training part alone and evaluated on all three
    parts.  For ``embedding_type == "EVE (ELBO)"`` no model is fitted and the
    embedding values are scored directly as predictions ("unsupervised").

    Args:
        dataset: Dataset identifier forwarded to ``extract_holdout_embeddings``.
        embedding_type: Name of the representation. Types starting with
            ``"EVE"`` are evaluated once per entry of ``eve_suffixes``.
        split_key: Identifier of the hold-out split; also stored in the
            ``split_type`` column of the result.
        eve_suffixes: Suffixes identifying EVE variants; ignored for non-EVE
            embeddings.
        threads: ``n_jobs`` passed to ``GridSearchCV``.
        active: Forwarded to ``extract_holdout_embeddings``.

    Returns:
        A ``pandas.DataFrame`` with one row per (suffix, model) holding the
        model name, a constant seed of 0, suffix, train/val/test Spearman
        correlation and MSE, plus ``embedding`` and ``split_type`` columns.
    """
    task = "regression"
    config = setup_CV_grids(task)
    target = "target_reg"

    # Only EVE embeddings come in several variants; everything else is
    # loaded exactly once without a suffix.
    suffixes = list(eve_suffixes) if embedding_type.startswith("EVE") else [None]

    # Rows are accumulated in lists so the number of results can never
    # disagree with the size of a pre-allocated buffer.
    model_lst, suffix_lst = [], []
    spear_rows, mse_rows = [], []

    for suffix in suffixes:
        (
            embedding_train,
            embedding_val,
            embedding_test,
            y_train,
            y_val,
            y_test,
            _train_names,
            _val_names,
            _test_names,
        ) = extract_holdout_embeddings(
            dataset=dataset,
            embedding_type=embedding_type,
            target=target,
            suffix=suffix,
            split_key=split_key,
            active=active,
        )

        if embedding_type == "EVE (ELBO)":
            # The embedding is used directly as the prediction
            corr, mse = compute_metrics(
                "unsupervised",
                embedding_train,
                embedding_val,
                embedding_test,
                y_train,
                y_val,
                y_test,
            )
            model_lst.append("")
            suffix_lst.append(suffix)
            spear_rows.append(corr)
            mse_rows.append(mse)
            continue

        # -1 keeps a sample in training for every split, 0 marks the single
        # validation fold, so the grid search scores on validation data only.
        cv = PredefinedSplit(
            test_fold=np.repeat([-1, 0], [len(y_train), len(y_val)])
        )
        # Train + validation stacked in the order PredefinedSplit expects
        x_search = np.concatenate((embedding_train, embedding_val))
        y_search = np.concatenate((y_train, y_val))

        for model_str, param_grid in zip(
            config["representations"], config["param_grids"]
        ):
            grid = GridSearchCV(
                estimator=config["pipe"],
                param_grid=param_grid,
                scoring=config["scoring"],
                verbose=0,
                n_jobs=threads,
                cv=cv,
            )
            grid.fit(x_search, y_search)

            # Refit the selected model on the training part only
            model = grid.best_estimator_["model"].fit(embedding_train, y_train)

            corr, mse = compute_metrics(
                task,
                model.predict(embedding_train),
                model.predict(embedding_val),
                model.predict(embedding_test),
                y_train,
                y_val,
                y_test,
            )
            model_lst.append(model_str)
            suffix_lst.append(suffix)
            spear_rows.append(corr)
            mse_rows.append(mse)

    # reshape keeps the (n, 3) layout even when no experiment was run
    spear_arr = np.asarray(spear_rows, dtype=float).reshape(-1, 3)
    mse_arr = np.asarray(mse_rows, dtype=float).reshape(-1, 3)

    df_results = pd.DataFrame(
        {
            "model": model_lst,
            # No seed is involved in a fixed hold-out split
            "seed": np.zeros(len(model_lst), dtype=int),
            "suffix": suffix_lst,
            "train_spearman": spear_arr[:, 0],
            "val_spearman": spear_arr[:, 1],
            "test_spearman": spear_arr[:, 2],
            "train_mse": mse_arr[:, 0],
            "val_mse": mse_arr[:, 1],
            "test_mse": mse_arr[:, 2],
        }
    )
    df_results["embedding"] = embedding_type
    df_results["split_type"] = split_key

    return df_results


def fit_regressor_CV(
    dataset: str,
    embedding_type: str,
    n_partitions: int,
    eve_suffixes: Tuple[str, ...],
    threads: int = 20,
    active: bool = False,
):
    """Benchmark regressors over every ordered (train, val, test) partition triple.

    Partition membership is read from the ``part_<i>`` columns of
    ``data/processed/<dataset>/<dataset>.csv``.  For each ordered choice of
    three distinct partitions, every model family from
    ``setup_CV_grids("regression")`` is tuned by a grid search scored on the
    validation partition only; the selected model is refit on the training
    partition alone and evaluated on all three.  For
    ``embedding_type == "EVE (ELBO)"`` no model is fitted and the embedding
    values are scored directly as predictions ("unsupervised").

    Args:
        dataset: Dataset identifier; selects the CSV file and is forwarded to
            ``extract_all_embeddings``.
        embedding_type: Name of the representation. Types starting with
            ``"EVE"`` are evaluated once per entry of ``eve_suffixes``.
        n_partitions: Number of ``part_<i>`` columns to permute; must be >= 3.
        eve_suffixes: Suffixes identifying EVE variants; ignored for non-EVE
            embeddings.
        threads: ``n_jobs`` passed to ``GridSearchCV``.
        active: If true, only rows with a truthy ``target_class`` are kept;
            also forwarded to ``extract_all_embeddings``.

    Returns:
        A ``pandas.DataFrame`` with one row per (suffix, partition triple,
        model) holding the model name, a constant seed of 0, suffix,
        train/val/test Spearman correlation and MSE, the partition column
        names used (``train_id``/``val_id``/``test_id``), plus ``embedding``
        and ``split_type`` ("CV") columns.

    Raises:
        ValueError: If ``n_partitions`` is smaller than 3.
    """
    df = pd.read_csv(f"data/processed/{dataset}/{dataset}.csv", index_col=0)

    task = "regression"
    config = setup_CV_grids(task)
    target = "target_reg"

    if active:
        df = df.loc[df["target_class"].astype(bool)]

    # Keep rows assigned to exactly one of the first three partitions.
    # NOTE(review): the columns are hard-coded rather than derived from
    # n_partitions; presumably extract_all_embeddings applies the same filter
    # so the boolean masks below line up with the embeddings - confirm before
    # generalising.
    df = df[df[["part_0", "part_1", "part_2"]].sum(axis=1) == 1]

    if n_partitions < 3:
        raise ValueError(
            f"n_partitions must be at least 3 to form train/val/test triples, "
            f"got {n_partitions}"
        )
    partition_headers = [f"part_{i}" for i in range(n_partitions)]
    # Boolean membership mask per partition, computed once
    masks = {header: df[header].values.astype(bool) for header in partition_headers}

    # Only EVE embeddings come in several variants; everything else is
    # loaded exactly once without a suffix.
    suffixes = list(eve_suffixes) if embedding_type.startswith("EVE") else [None]

    # Rows are accumulated in lists so the number of results can never
    # disagree with the size of a pre-allocated buffer.
    model_lst, suffix_lst, train_id, val_id, test_id = [], [], [], [], []
    spear_rows, mse_rows = [], []

    for suffix in suffixes:
        embeddings, y, _ = extract_all_embeddings(
            dataset=dataset,
            embedding_type=embedding_type,
            suffix=suffix,
            target=target,
            active=active,
        )

        for train_part, val_part, test_part in permutations(partition_headers, 3):
            train_idx = masks[train_part]
            val_idx = masks[val_part]
            test_idx = masks[test_part]

            embedding_train, y_train = embeddings[train_idx], y[train_idx]
            embedding_val, y_val = embeddings[val_idx], y[val_idx]
            embedding_test, y_test = embeddings[test_idx], y[test_idx]

            if embedding_type == "EVE (ELBO)":
                # The embedding is used directly as the prediction
                corr, mse = compute_metrics(
                    "unsupervised",
                    embedding_train,
                    embedding_val,
                    embedding_test,
                    y_train,
                    y_val,
                    y_test,
                )
                # NOTE(review): the correlation sign is flipped for "gh1"
                # only, and only in this function - confirm this is intended.
                if dataset == "gh1":
                    corr *= -1.0
                model_lst.append("")
                suffix_lst.append(suffix)
                train_id.append(train_part)
                val_id.append(val_part)
                test_id.append(test_part)
                spear_rows.append(corr)
                mse_rows.append(mse)
                continue

            # -1 keeps a sample in training for every split, 0 marks the
            # single validation fold.
            cv = PredefinedSplit(
                test_fold=np.repeat(
                    [-1, 0], [len(embedding_train), len(embedding_val)]
                )
            )
            # Train + validation stacked in the order PredefinedSplit expects
            x_search = np.concatenate((embedding_train, embedding_val))
            y_search = np.concatenate((y_train, y_val))

            for model_str, param_grid in zip(
                config["representations"], config["param_grids"]
            ):
                grid = GridSearchCV(
                    estimator=config["pipe"],
                    param_grid=param_grid,
                    scoring=config["scoring"],
                    verbose=0,
                    n_jobs=threads,
                    cv=cv,
                )
                grid.fit(x_search, y_search)

                # Refit the selected model on the training partition only
                model = grid.best_estimator_["model"].fit(embedding_train, y_train)

                corr, mse = compute_metrics(
                    task,
                    model.predict(embedding_train),
                    model.predict(embedding_val),
                    model.predict(embedding_test),
                    y_train,
                    y_val,
                    y_test,
                )
                model_lst.append(model_str)
                suffix_lst.append(suffix)
                train_id.append(train_part)
                val_id.append(val_part)
                test_id.append(test_part)
                spear_rows.append(corr)
                mse_rows.append(mse)

    # reshape keeps the (n, 3) layout even when no experiment was run
    spear_arr = np.asarray(spear_rows, dtype=float).reshape(-1, 3)
    mse_arr = np.asarray(mse_rows, dtype=float).reshape(-1, 3)

    df_results = pd.DataFrame(
        {
            "model": model_lst,
            # No seed is involved in partition-based CV
            "seed": np.zeros(len(model_lst), dtype=int),
            "suffix": suffix_lst,
            "train_spearman": spear_arr[:, 0],
            "val_spearman": spear_arr[:, 1],
            "test_spearman": spear_arr[:, 2],
            "train_mse": mse_arr[:, 0],
            "val_mse": mse_arr[:, 1],
            "test_mse": mse_arr[:, 2],
            "train_id": train_id,
            "val_id": val_id,
            "test_id": test_id,
        }
    )
    df_results["embedding"] = embedding_type
    df_results["split_type"] = "CV"

    return df_results


def fit_classifier_holdout(
    dataset: str,
    embedding_type: str,
    split_key: str,
    target: str,
    eve_suffixes: Tuple[str, ...],
    threads: int = 20,
):
    """Benchmark classifiers on a predefined train/validation/test hold-out split.

    The split is provided by ``extract_holdout_embeddings`` for the given
    ``split_key``.  Each model family from
    ``setup_CV_grids("classification")`` is tuned by a grid search scored on
    the validation part only; the selected model is refit on the training part
    alone and evaluated on all three parts.

    Args:
        dataset: Dataset identifier forwarded to ``extract_holdout_embeddings``.
        embedding_type: Name of the representation. Types starting with
            ``"EVE"`` are evaluated once per entry of ``eve_suffixes``.
        split_key: Identifier of the hold-out split; also stored in the
            ``split_type`` column of the result.
        target: Name of the target forwarded to ``extract_holdout_embeddings``.
        eve_suffixes: Suffixes identifying EVE variants; ignored for non-EVE
            embeddings.
        threads: ``n_jobs`` passed to ``GridSearchCV``.

    Returns:
        A ``pandas.DataFrame`` with one row per (suffix, model) holding the
        model name, a constant seed of 0, suffix, train/val/test MCC and
        AUROC, plus ``embedding`` and ``split_type`` columns.  Both metrics
        are computed from the outputs of ``model.predict``.
    """
    task = "classification"
    config = setup_CV_grids(task)

    # Only EVE embeddings come in several variants; everything else is
    # loaded exactly once without a suffix.
    suffixes = list(eve_suffixes) if embedding_type.startswith("EVE") else [None]

    # Rows are accumulated in lists so the number of results can never
    # disagree with the size of a pre-allocated buffer.
    model_lst, suffix_lst = [], []
    mcc_rows, auroc_rows = [], []

    for suffix in suffixes:
        (
            embedding_train,
            embedding_val,
            embedding_test,
            y_train,
            y_val,
            y_test,
            _train_names,
            _val_names,
            _test_names,
        ) = extract_holdout_embeddings(
            dataset=dataset,
            embedding_type=embedding_type,
            target=target,
            suffix=suffix,
            split_key=split_key,
        )

        # -1 keeps a sample in training for every split, 0 marks the single
        # validation fold, so the grid search scores on validation data only.
        cv = PredefinedSplit(
            test_fold=np.repeat([-1, 0], [len(y_train), len(y_val)])
        )
        # Train + validation stacked in the order PredefinedSplit expects
        x_search = np.concatenate((embedding_train, embedding_val))
        y_search = np.concatenate((y_train, y_val))

        for model_str, param_grid in zip(
            config["representations"], config["param_grids"]
        ):
            grid = GridSearchCV(
                estimator=config["pipe"],
                param_grid=param_grid,
                scoring=config["scoring"],
                verbose=0,
                n_jobs=threads,
                cv=cv,
            )
            grid.fit(x_search, y_search)

            # Refit the selected model on the training part only
            model = grid.best_estimator_["model"].fit(embedding_train, y_train)

            corr, auroc = compute_metrics(
                task,
                model.predict(embedding_train),
                model.predict(embedding_val),
                model.predict(embedding_test),
                y_train,
                y_val,
                y_test,
            )
            model_lst.append(model_str)
            suffix_lst.append(suffix)
            mcc_rows.append(corr)
            auroc_rows.append(auroc)

    # reshape keeps the (n, 3) layout even when no experiment was run
    mcc_arr = np.asarray(mcc_rows, dtype=float).reshape(-1, 3)
    auroc_arr = np.asarray(auroc_rows, dtype=float).reshape(-1, 3)

    df_results = pd.DataFrame(
        {
            "model": model_lst,
            # No seed is involved in a fixed hold-out split
            "seed": np.zeros(len(model_lst), dtype=int),
            "suffix": suffix_lst,
            "train_mcc": mcc_arr[:, 0],
            "val_mcc": mcc_arr[:, 1],
            "test_mcc": mcc_arr[:, 2],
            "train_auroc": auroc_arr[:, 0],
            "val_auroc": auroc_arr[:, 1],
            "test_auroc": auroc_arr[:, 2],
        }
    )
    df_results["embedding"] = embedding_type
    df_results["split_type"] = split_key

    return df_results


def fit_classifier_CV(
    dataset: str,
    embedding_type: str,
    target: str,
    n_partitions: int,
    eve_suffixes: Tuple[str, ...],
    threads: int = 20,
):
    """Benchmark classifiers over every ordered (train, val, test) partition triple.

    Partition membership is read from the ``part_<i>`` columns of
    ``data/processed/<dataset>/<dataset>.csv``.  For each ordered choice of
    three distinct partitions, every model family from
    ``setup_CV_grids("classification")`` is tuned by a grid search scored on
    the validation partition only; the selected model is refit on the training
    partition alone and evaluated on all three.

    Args:
        dataset: Dataset identifier; selects the CSV file and is forwarded to
            ``extract_all_embeddings``.
        embedding_type: Name of the representation. Types starting with
            ``"EVE"`` are evaluated once per entry of ``eve_suffixes``.
        target: Name of the target forwarded to ``extract_all_embeddings``.
        n_partitions: Number of ``part_<i>`` columns to permute; must be >= 3.
        eve_suffixes: Suffixes identifying EVE variants; ignored for non-EVE
            embeddings.
        threads: ``n_jobs`` passed to ``GridSearchCV``.

    Returns:
        A ``pandas.DataFrame`` with one row per (suffix, partition triple,
        model) holding the model name, a constant seed of 0, suffix,
        train/val/test MCC and AUROC, the partition column names used
        (``train_id``/``val_id``/``test_id``), plus ``embedding`` and
        ``split_type`` ("CV") columns.  Both metrics are computed from the
        outputs of ``model.predict``.

    Raises:
        ValueError: If ``n_partitions`` is smaller than 3.
    """
    task = "classification"
    config = setup_CV_grids(task)

    if n_partitions < 3:
        raise ValueError(
            f"n_partitions must be at least 3 to form train/val/test triples, "
            f"got {n_partitions}"
        )
    partition_headers = [f"part_{i}" for i in range(n_partitions)]

    df = pd.read_csv(f"data/processed/{dataset}/{dataset}.csv", index_col=0)
    # Keep rows assigned to exactly one of the first three partitions.
    # NOTE(review): the columns are hard-coded rather than derived from
    # n_partitions; presumably extract_all_embeddings applies the same filter
    # so the boolean masks below line up with the embeddings - confirm before
    # generalising.
    df = df[df[["part_0", "part_1", "part_2"]].sum(axis=1) == 1]

    # Boolean membership mask per partition, computed once
    masks = {header: df[header].values.astype(bool) for header in partition_headers}

    # Only EVE embeddings come in several variants; everything else is
    # loaded exactly once without a suffix.
    suffixes = list(eve_suffixes) if embedding_type.startswith("EVE") else [None]

    # Rows are accumulated in lists so the number of results can never
    # disagree with the size of a pre-allocated buffer.
    model_lst, suffix_lst, train_id, val_id, test_id = [], [], [], [], []
    mcc_rows, auroc_rows = [], []

    for suffix in suffixes:
        embeddings, y, _ = extract_all_embeddings(
            dataset=dataset, embedding_type=embedding_type, suffix=suffix, target=target
        )

        for train_part, val_part, test_part in permutations(partition_headers, 3):
            train_idx = masks[train_part]
            val_idx = masks[val_part]
            test_idx = masks[test_part]

            embedding_train, y_train = embeddings[train_idx], y[train_idx]
            embedding_val, y_val = embeddings[val_idx], y[val_idx]
            embedding_test, y_test = embeddings[test_idx], y[test_idx]

            # -1 keeps a sample in training for every split, 0 marks the
            # single validation fold.
            cv = PredefinedSplit(
                test_fold=np.repeat(
                    [-1, 0], [len(embedding_train), len(embedding_val)]
                )
            )
            # Train + validation stacked in the order PredefinedSplit expects
            x_search = np.concatenate((embedding_train, embedding_val))
            y_search = np.concatenate((y_train, y_val))

            for model_str, param_grid in zip(
                config["representations"], config["param_grids"]
            ):
                grid = GridSearchCV(
                    estimator=config["pipe"],
                    param_grid=param_grid,
                    scoring=config["scoring"],
                    verbose=0,
                    n_jobs=threads,
                    cv=cv,
                )
                grid.fit(x_search, y_search)

                # Refit the selected model on the training partition only
                model = grid.best_estimator_["model"].fit(embedding_train, y_train)

                corr, auroc = compute_metrics(
                    task,
                    model.predict(embedding_train),
                    model.predict(embedding_val),
                    model.predict(embedding_test),
                    y_train,
                    y_val,
                    y_test,
                )
                model_lst.append(model_str)
                suffix_lst.append(suffix)
                train_id.append(train_part)
                val_id.append(val_part)
                test_id.append(test_part)
                mcc_rows.append(corr)
                auroc_rows.append(auroc)

    # reshape keeps the (n, 3) layout even when no experiment was run
    mcc_arr = np.asarray(mcc_rows, dtype=float).reshape(-1, 3)
    auroc_arr = np.asarray(auroc_rows, dtype=float).reshape(-1, 3)

    df_results = pd.DataFrame(
        {
            "model": model_lst,
            # No seed is involved in partition-based CV
            "seed": np.zeros(len(model_lst), dtype=int),
            "suffix": suffix_lst,
            "train_mcc": mcc_arr[:, 0],
            "val_mcc": mcc_arr[:, 1],
            "test_mcc": mcc_arr[:, 2],
            "train_auroc": auroc_arr[:, 0],
            "val_auroc": auroc_arr[:, 1],
            "test_auroc": auroc_arr[:, 2],
            "train_id": train_id,
            "val_id": val_id,
            "test_id": test_id,
        }
    )
    df_results["embedding"] = embedding_type
    df_results["split_type"] = "CV"

    return df_results


def fit_classifier_random(
    dataset: str,
    target: str,
    embedding_type: str,
    eve_suffixes: Tuple[str, ...],
    seeds: Tuple[int, ...],
    threads: int = 20,
):
    """Benchmark classifiers on seeded random train/validation/test splits.

    For every EVE suffix (or once, with ``suffix=None``, for non-EVE
    embeddings) and every seed, the observations are shuffled and cut into
    train (first 50%), validation (next 25%) and test (remainder).  Each model
    family from ``setup_CV_grids("classification")`` is tuned by a grid search
    scored on the validation part only; the selected model is refit on the
    training part alone and evaluated on all three parts.

    Args:
        dataset: Dataset identifier forwarded to ``extract_all_embeddings``.
        target: Name of the target forwarded to ``extract_all_embeddings``.
        embedding_type: Name of the representation. Types starting with
            ``"EVE"`` are evaluated once per entry of ``eve_suffixes``.
        eve_suffixes: Suffixes identifying EVE variants; ignored for non-EVE
            embeddings.
        seeds: Seeds for the random permutations, one experiment per seed.
        threads: ``n_jobs`` passed to ``GridSearchCV``.

    Returns:
        A ``pandas.DataFrame`` with one row per (suffix, seed, model) holding
        the model name, seed, suffix, train/val/test MCC and AUROC, plus
        constant ``embedding`` and ``split_type`` ("random") columns.  Both
        metrics are computed from the outputs of ``model.predict``.
    """
    task = "classification"
    config = setup_CV_grids(task)

    # Only EVE embeddings come in several variants; everything else is
    # loaded exactly once without a suffix.
    suffixes = list(eve_suffixes) if embedding_type.startswith("EVE") else [None]

    # Rows are accumulated in lists so the number of results can never
    # disagree with the size of a pre-allocated buffer.
    model_lst, seed_lst, suffix_lst = [], [], []
    mcc_rows, auroc_rows = [], []

    for suffix in suffixes:
        embeddings, y, _ = extract_all_embeddings(
            dataset=dataset, embedding_type=embedding_type, suffix=suffix, target=target
        )
        # Fixed split sizes; the test set receives whatever is left over
        n_obs = len(y)
        n_train = int(n_obs * 0.5)
        n_val = int(n_obs * 0.25)

        # -1 keeps a sample in training for every split, 0 marks the single
        # validation fold, so the grid search scores on validation data only.
        cv = PredefinedSplit(test_fold=np.repeat([-1, 0], [n_train, n_val]))

        for seed in seeds:
            np.random.seed(seed)
            perm = np.random.permutation(n_obs)
            train_idx = perm[:n_train]
            val_idx = perm[n_train : n_train + n_val]
            test_idx = perm[n_train + n_val :]

            embedding_train, y_train = embeddings[train_idx], y[train_idx]
            embedding_val, y_val = embeddings[val_idx], y[val_idx]
            embedding_test, y_test = embeddings[test_idx], y[test_idx]

            # Train + validation stacked in the order PredefinedSplit expects
            x_search = np.concatenate((embedding_train, embedding_val))
            y_search = np.concatenate((y_train, y_val))

            for model_str, param_grid in zip(
                config["representations"], config["param_grids"]
            ):
                grid = GridSearchCV(
                    estimator=config["pipe"],
                    param_grid=param_grid,
                    scoring=config["scoring"],
                    verbose=0,
                    n_jobs=threads,
                    cv=cv,
                )
                grid.fit(x_search, y_search)

                # Refit the selected model on the training part only
                model = grid.best_estimator_["model"].fit(embedding_train, y_train)

                corr, auroc = compute_metrics(
                    task,
                    model.predict(embedding_train),
                    model.predict(embedding_val),
                    model.predict(embedding_test),
                    y_train,
                    y_val,
                    y_test,
                )
                model_lst.append(model_str)
                seed_lst.append(seed)
                suffix_lst.append(suffix)
                mcc_rows.append(corr)
                auroc_rows.append(auroc)

    # reshape keeps the (n, 3) layout even when no experiment was run
    mcc_arr = np.asarray(mcc_rows, dtype=float).reshape(-1, 3)
    auroc_arr = np.asarray(auroc_rows, dtype=float).reshape(-1, 3)

    df_results = pd.DataFrame(
        {
            "model": model_lst,
            "seed": np.asarray(seed_lst, dtype=int),
            "suffix": suffix_lst,
            "train_mcc": mcc_arr[:, 0],
            "val_mcc": mcc_arr[:, 1],
            "test_mcc": mcc_arr[:, 2],
            "train_auroc": auroc_arr[:, 0],
            "val_auroc": auroc_arr[:, 1],
            "test_auroc": auroc_arr[:, 2],
        }
    )
    df_results["embedding"] = embedding_type
    df_results["split_type"] = "random"

    return df_results


def setup_CV_grids(task: str):
    """Build the grid-search configuration for a supervised task.

    Three model families are configured per task: k-nearest neighbours, a
    linear model (Ridge for regression, L2 logistic regression for
    classification) and a random forest.  Each family gets its own parameter
    grid whose ``"model"`` entry swaps the estimator into the single
    ``"model"`` step of a pass-through pipeline.

    Args:
        task: Either ``"regression"`` or ``"classification"``.

    Returns:
        A dict with the keys ``"param_grids"`` (one grid per model family),
        ``"pipe"`` (the placeholder pipeline), ``"representations"`` (model
        family names, in the same order as the grids) and ``"scoring"`` (the
        scorer name handed to ``GridSearchCV``).

    Raises:
        NotImplementedError: If ``task`` is not one of the supported tasks.
    """
    if task not in ("regression", "classification"):
        raise NotImplementedError(
            f"No CV grid is defined for task {task!r}; "
            "expected 'regression' or 'classification'."
        )

    if task == "regression":
        model_strs = ["KNN", "Ridge", "RandomForest"]
        scoring = "neg_mean_squared_error"
        knn_model = KNeighborsRegressor()
        knn_neighbors = [1, 2, 5, 10, 25]
        linear_model = Ridge(random_state=0)
        linear_param = "model__alpha"
        linear_values = [0.0001, 0.001, 0.01, 0.1, 1, 10, 25]
        forest_model = RandomForestRegressor(random_state=0, max_features="sqrt")
    else:
        model_strs = ["KNN", "LogReg", "RandomForest"]
        scoring = "neg_log_loss"
        knn_model = KNeighborsClassifier()
        knn_neighbors = [1, 2, 5, 10]
        linear_model = LogisticRegression(random_state=0, penalty="l2", max_iter=1000)
        linear_param = "model__C"
        linear_values = [0.1, 0.15, 0.20, 1, 5, 10, 15, 20, 25, 30, 40, 50]
        forest_model = RandomForestClassifier(random_state=0, max_features="sqrt")

    # One grid per model family, in the same order as model_strs
    param_grid_list = [
        [{"model": [knn_model], "model__n_neighbors": knn_neighbors}],
        [{"model": [linear_model], linear_param: linear_values}],
        [{"model": [forest_model], "model__min_samples_split": [2, 5]}],
    ]
    # The "model" step is a placeholder that each grid fills in
    pipe = Pipeline([("model", "passthrough")])

    return {
        "param_grids": param_grid_list,
        "pipe": pipe,
        "representations": model_strs,
        "scoring": scoring,
    }


def compute_metrics(
    task: str,
    preds_train: np.ndarray,
    preds_val: np.ndarray,
    preds_test: np.ndarray,
    y_train: np.ndarray,
    y_val: np.ndarray,
    y_test: np.ndarray,
):
    """Score predictions on the train, validation and test parts.

    Args:
        task: ``"classification"``, ``"regression"`` or ``"unsupervised"``.
        preds_train: Predictions for the training part.
        preds_val: Predictions for the validation part.
        preds_test: Predictions for the test part.
        y_train: Targets for the training part.
        y_val: Targets for the validation part.
        y_test: Targets for the test part.

    Returns:
        A pair of length-3 arrays ordered (train, val, test):

        * classification: Matthews correlation coefficient and ROC AUC.
        * regression: Spearman correlation and mean squared error.
        * unsupervised: Spearman correlation and an all-NaN array, since no
          error is computed for raw scores.

    Raises:
        ValueError: If ``task`` is not one of the supported tasks.
    """
    if task not in ("classification", "regression", "unsupervised"):
        raise ValueError(
            f"Unknown task {task!r}; expected 'classification', "
            "'regression' or 'unsupervised'."
        )

    # (targets, predictions) for each part, in output order
    parts = ((y_train, preds_train), (y_val, preds_val), (y_test, preds_test))

    if task == "classification":
        corr = np.array([matthews_corrcoef(y, preds) for y, preds in parts])
        auroc = np.array([roc_auc_score(y, preds) for y, preds in parts])
        return corr, auroc

    rhos = []
    for y, preds in parts:
        rho, _ = spearmanr(y, preds)
        rhos.append(rho)
    corr = np.array(rhos)

    if task == "regression":
        mse = np.array([mean_squared_error(y, preds) for y, preds in parts])
    else:
        mse = np.array((np.nan, np.nan, np.nan))
    return corr, mse
