"""Functions used for hyperparameter search."""

import pandas as pd

from sklearn.model_selection import cross_validate, RandomizedSearchCV
from configs.model_parameters import param_distributions_total

from src.utils import assign_estimator, shorten_param


def _get_default_cv_score(estimator, X_train, y_train, cv, n_jobs, scoring):
    """Cross-validate ``estimator`` with its default parameters.

    The fold scores are summarised as a single-row DataFrame whose columns
    follow the naming of ``RandomizedSearchCV.cv_results_`` (``mean_*``,
    ``std_*``, ``split{i}_test_*`` and ``params``). This lets the row be
    concatenated with search results and ranked alongside them.

    Parameters
    ----------
    estimator : estimator object
        Unfitted estimator; ``cross_validate`` clones it for each fold.
    X_train, y_train : array-like
        Training features and target.
    cv, n_jobs, scoring
        Forwarded unchanged to ``sklearn.model_selection.cross_validate``.

    Returns
    -------
    pandas.DataFrame
        One row containing:

        - ``mean_``/``std_`` columns for every key returned by
          ``cross_validate`` (e.g. ``mean_fit_time``, ``std_test_score``);
        - one ``split{i}_test_*`` column per fold and test metric;
        - a ``params`` column holding an empty dict, since no parameter is
          overridden.
    """
    scores = pd.DataFrame(
        cross_validate(
            estimator,
            X_train,
            y_train,
            cv=cv,
            n_jobs=n_jobs,
            scoring=scoring,
        )
    )

    # Per-fold test scores, named like sklearn's cv_results_, e.g.
    # "split0_test_score" (or "split0_test_<metric>" for multi-metric).
    test_cols = [col for col in scores.columns if col.startswith("test_")]
    split_results = pd.DataFrame(
        {
            f"split{split_idx}_{col}": [value]
            for col in test_cols
            for split_idx, value in enumerate(scores[col])
        }
    )

    # sklearn reports the population standard deviation (ddof=0) in its
    # std_* columns; pandas defaults to ddof=1, so set it explicitly to keep
    # the default row comparable with the search rows.
    mean_results = scores.mean().add_prefix("mean_").to_frame().T
    std_results = scores.std(ddof=0).add_prefix("std_").to_frame().T

    default_results = pd.concat(
        [mean_results, std_results, split_results], axis=1
    )
    # Store one empty dict as the cell value. Assigning a bare `{}` would be
    # interpreted by pandas as a (empty) mapping, not as a value.
    default_results["params"] = [{}]
    return default_results


def run_param_search(
    X_train,
    y_train,
    task,
    estim_method,
    cv,
    n_iter,
    n_jobs,
    scoring,
    device,
):
    """Pick hyperparameters for ``estim_method`` by randomized search.

    The estimator's default configuration is always cross-validated and
    counts as one of the ``n_iter`` evaluated configurations. The remaining
    ``n_iter - 1`` are sampled by ``RandomizedSearchCV`` from
    ``param_distributions_total``.

    Parameters
    ----------
    X_train, y_train : array-like
        Training data used for cross-validation.
    task, device
        Forwarded to ``assign_estimator``.
    estim_method : str
        Estimator identifier. The part after the last ``"-"`` selects the
        parameter distribution and decides whether a search is run at all.
    cv, n_jobs, scoring
        Forwarded to the sklearn cross-validation utilities. ``scoring``
        must yield a single metric, because results are ranked on
        ``mean_test_score``.
    n_iter : int
        Total number of configurations evaluated, including the default.

    Returns
    -------
    best_params : dict
        Best configuration found; ``{}`` stands for the default parameters.
    cv_results : pandas.DataFrame or None
        Search results, with the defaults as the last row and
        ``rank_test_score`` recomputed over all rows. ``None`` when no
        search is run.
    """
    import numbers

    param_method = estim_method.split("-")[-1]
    # No search is run for these. realmlp and tabm have their own search
    # mechanism, so they are included here as well.
    no_search_estimators = ["ridge", "tabpfn", "realmlp", "tabm"]
    if param_method in no_search_estimators:
        return {}, None

    # Look the distribution up before any expensive work so an unknown
    # method fails fast.
    param_distributions = param_distributions_total[param_method]

    # An iterable of (train, test) splits (e.g. a generator) would be
    # exhausted by the first cross-validation run; materialise it so the
    # default run and the search see the same splits.
    if not (
        cv is None or isinstance(cv, numbers.Integral) or hasattr(cv, "split")
    ):
        cv = list(cv)

    # sklearn clones the estimator before every fit, so a single unfitted
    # instance can serve both the default run and the search.
    estimator = assign_estimator(
        estim_method,
        task,
        device,
        train_flag=False,
    )
    default_results = _get_default_cv_score(
        estimator,
        X_train,
        y_train,
        cv=cv,
        n_jobs=n_jobs,
        scoring=scoring,
    )

    n_search = n_iter - 1  # one iteration is used by the default config
    if n_search >= 1:
        hyperparameter_search = RandomizedSearchCV(
            estimator,
            param_distributions=param_distributions,
            n_iter=n_search,
            cv=cv,
            scoring=scoring,
            refit=False,
            n_jobs=n_jobs,
            random_state=1234,
        )
        hyperparameter_search.fit(X_train, y_train)
        search_results = pd.DataFrame(hyperparameter_search.cv_results_)
        search_results = search_results.rename(shorten_param, axis=1)
        cv_results = pd.concat([search_results, default_results], axis=0)
    else:
        # RandomizedSearchCV requires n_iter >= 1; only the defaults are
        # evaluated in this case.
        cv_results = default_results
    cv_results.reset_index(drop=True, inplace=True)

    # Re-rank including the default row. As in sklearn, ties share the best
    # rank and NaN scores (failed fits) are ranked last instead of breaking
    # the integer cast.
    cv_results["rank_test_score"] = (
        cv_results["mean_test_score"]
        .rank(method="min", ascending=False, na_option="bottom")
        .astype(int)
    )
    best_params = cv_results.loc[
        cv_results["rank_test_score"] == 1, "params"
    ].iloc[0]
    # The default row may hold NaN rather than a dict, depending on how its
    # params column was filled; both mean "no overrides".
    if not isinstance(best_params, dict):
        best_params = {}

    return best_params, cv_results
