"""Script for evaluating models."""

import os
import time
import joblib
import pandas as pd
import numpy as np
import argparse

from sklearn.model_selection import RepeatedKFold, RepeatedStratifiedKFold
from sklearn.decomposition import PCA
from src.utils import (
    load_data_config,
    set_score_criterion,
    assign_estimator,
    calculate_output,
    reshape_pred_output,
    check_pred_output,
    return_score,
)
from src.utils import set_split
from src.preprocess import preprocess_with_cache
from src.param_search import run_param_search
from configs.path_configs import path_configs
from configs.preprocess_configs import preprocess_configs


def _available_cpu_count():
    """Return the number of CPUs this process may run on.

    ``os.sched_getaffinity`` honours CPU pinning but is not available on every
    platform (e.g. macOS/Windows); fall back to ``os.cpu_count()`` there.
    """
    try:
        return len(os.sched_getaffinity(0))
    except AttributeError:
        return os.cpu_count() or 1


def _load_split(data_name, embed_method, data_config, num_train, random_state):
    """Return ``(X_train, X_test, y_train, y_test, duration_preprocess)``.

    When ``preprocess_configs[embed_method]["cache_embedding"]`` is truthy, the
    split is produced by ``preprocess_with_cache`` memoized on disk through
    ``joblib.Memory``. Otherwise pre-extracted embeddings are read from a
    parquet file, split with ``set_split``, and the stored total extraction
    time is pro-rated to the number of rows used by this split.
    """
    if preprocess_configs[embed_method]["cache_embedding"]:
        mem = joblib.Memory(
            f"{path_configs['emb_cache_folder']}/{embed_method}", verbose=0
        )
        X_train, X_test, y_train, y_test, duration_preprocess = mem.cache(
            preprocess_with_cache
        )(
            data_name,
            embed_method,
            num_train,
            random_state,
        )
        return X_train, X_test, y_train, y_test, duration_preprocess

    emb_data_folder = f"{path_configs['emb_data_folder']}/{embed_method}"
    emb_data_path = f"{emb_data_folder}/{data_name}.parquet"
    time_folder = f"{path_configs['results']}/embed_extraction_time/{embed_method}"
    time_path = f"{time_folder}/{embed_method}|{data_name}.npy"

    data = pd.read_parquet(emb_data_path)
    duration_preprocess_total = np.load(time_path)

    X_train, X_test, y_train, y_test = set_split(
        data, data_config, num_train, random_state, extracted_emb=True
    )
    X_train, X_test = np.array(X_train), np.array(X_test)

    # The stored time covers the whole dataset; charge only the share of
    # rows that this train/test split actually uses.
    n_used = X_train.shape[0] + X_test.shape[0]
    duration_preprocess = duration_preprocess_total * n_used / data.shape[0]
    return X_train, X_test, y_train, y_test, duration_preprocess


def _maybe_apply_pca(method, X_train, X_test, max_components=300):
    """Project features with PCA for ``*_xgb-pca`` methods on wide inputs.

    PCA is fitted on the training split only (then applied to the test split)
    when the last ``_``-segment of ``method`` is ``"xgb-pca"`` and there are
    more than ``max_components`` features. Otherwise the inputs are returned
    unchanged.
    """
    if method.split("_")[-1] == "xgb-pca" and X_train.shape[1] > max_components:
        # n_components may not exceed the number of training samples.
        n_components = min(X_train.shape[0], max_components)
        pca = PCA(n_components=n_components, random_state=1234)
        X_train = pca.fit_transform(X_train)
        X_test = pca.transform(X_test)
    return X_train, X_test


def run_model(
    data_name: str,
    method: str,
    num_train: int,
    random_state: int,
    device: str,
) -> None:
    """Run model for specific experiment setting and save its scores.

    Pipeline: obtain the (embedded) train/test split, optionally reduce
    features with PCA, run a hyperparameter search with repeated 5x5 K-fold
    CV, refit the estimator with the best parameters, score it on the test
    split and write one result row to
    ``./results/llm_kg_comparison/<data_name>/score/<marker>.csv``. CV
    results, when returned by the search, are written to
    ``./results/llm_kg_comparison/<data_name>/log/<marker>_log.csv``.

    Args:
        data_name: Dataset name, passed to ``load_data_config``.
        method: ``<embed_method>_<estimator>[-suffix]``. Everything before
            the last ``_`` selects the embedding; the part of the last
            segment before ``-`` selects the estimator.
        num_train: Passed to the data split / preprocessing.
        random_state: Passed to the data split / preprocessing.
        device: Forwarded to the parameter search and the estimator.
    """
    marker = f"{data_name}|{method}|nt-{num_train}|rs-{random_state}"
    print(marker + " start")

    # Output locations (makedirs also creates the shared parent directory).
    result_save_base_path = f"./results/llm_kg_comparison/{data_name}"
    os.makedirs(f"{result_save_base_path}/score", exist_ok=True)
    os.makedirs(f"{result_save_base_path}/log", exist_ok=True)
    results_model_path = f"{result_save_base_path}/score/{marker}.csv"
    log_path = f"{result_save_base_path}/log/{marker}_log.csv"

    # Set preliminaries
    data_config = load_data_config(data_name)
    task = data_config["task"]
    scoring, result_criterion = set_score_criterion(task)
    embed_method = "_".join(method.split("_")[:-1])
    estim_method = method.split("_")[-1].split("-")[0]

    # Stratified folds for non-regression tasks; the fixed seed makes the fold
    # assignment deterministic for a given training set.
    cv_class = RepeatedKFold if task == "reg" else RepeatedStratifiedKFold
    cv = cv_class(n_splits=5, n_repeats=5, random_state=1234)
    n_iter, n_jobs = 100, min(_available_cpu_count(), 32)

    X_train, X_test, y_train, y_test, duration_preprocess = _load_split(
        data_name, embed_method, data_config, num_train, random_state
    )

    # PCA time is accounted as part of preprocessing.
    start_time = time.perf_counter()
    X_train, X_test = _maybe_apply_pca(method, X_train, X_test)
    duration_preprocess += round(time.perf_counter() - start_time, 4)

    # Hyperparameter search
    start_time = time.perf_counter()
    best_params, cv_results = run_param_search(
        X_train,
        y_train,
        task,
        estim_method,
        cv,
        n_iter,
        n_jobs,
        scoring,
        device,
    )
    duration_param_search = round(time.perf_counter() - start_time, 4)

    if cv_results is not None:
        cv_results.to_csv(log_path, index=False)

    # Final fit, predict and score (all counted as inference time).
    start_time = time.perf_counter()
    estimator = assign_estimator(
        estim_method,
        task,
        device,
        train_flag=True,
        best_params_estimator=best_params,
    )
    estimator.fit(X_train, y_train)
    y_prob, y_pred = calculate_output(X_test, estimator, task)
    if "clf" in task:
        y_prob = reshape_pred_output(y_prob)
    if task == "reg":
        y_pred = check_pred_output(y_train, y_pred)
    score = return_score(y_test, y_prob, y_pred, task)
    duration_inference = round(time.perf_counter() - start_time, 4)

    # result_criterion lists the score names followed by four timing columns.
    results_ = {name: score[i] for i, name in enumerate(result_criterion[:-4])}
    results_[result_criterion[-4]] = duration_preprocess
    results_[result_criterion[-3]] = duration_param_search
    results_[result_criterion[-2]] = duration_inference
    results_[result_criterion[-1]] = (
        duration_preprocess + duration_param_search + duration_inference
    )
    results_model = pd.DataFrame([results_], columns=result_criterion)
    results_model["data_name"] = data_name
    results_model["method"] = method
    results_model["num_train"] = num_train
    results_model["random_state"] = random_state
    results_model["task"] = task

    results_model.to_csv(results_model_path, index=False)

    print(marker + " is complete")


def main():
    """Parse command-line arguments and call ``run_model`` once per random state."""
    parser = argparse.ArgumentParser(description="Evaluate a model across random states.")
    parser.add_argument("--data_name", type=str, required=True)
    parser.add_argument("--method", type=str, required=True)
    parser.add_argument("--num_train", type=int, required=True)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument(
        "--random_states",
        type=int,
        nargs="+",
        default=list(range(1, 11)),
        help="Values passed as random_state to run_model (default: 1 to 10).",
    )
    args = parser.parse_args()

    for rs in args.random_states:
        run_model(
            data_name=args.data_name,
            num_train=args.num_train,
            method=args.method,
            random_state=rs,
            device=args.device,
        )


if __name__ == "__main__":
    main()
