"""Script for evaluting models."""

import os
import time
import pandas as pd
import numpy as np
import argparse

from glob import glob
from sklearn.model_selection import (
    train_test_split,
    RepeatedKFold,
    RepeatedStratifiedKFold,
)
from sklearn.decomposition import PCA
from src.utils import (
    load_data_config,
    set_score_criterion,
    assign_estimator,
    calculate_output,
    reshape_pred_output,
    check_pred_output,
    return_score,
)
from src.param_search import run_param_search
from configs.path_configs import path_configs


def run_model(
    data_name,
    method,
    num_train,
    random_state,
    device,
):
    """Run one experiment setting and save its scores and search log as CSV.

    The function loads pre-computed embeddings for ``data_name``, draws a
    train/test split, optionally reduces the features with PCA, runs a
    hyperparameter search, refits the estimator with the best parameters,
    scores it on the test split and writes the outcome under
    ``./results/llm_kg_comparison_linked/<data_name>/``.

    Parameters
    ----------
    data_name : str
        Name of the dataset. Used to load the data config, to locate the
        embedding files and to name the output files.
    method : str
        Identifier of the form ``<embed_method>_<estimator>[-<suffix>]``.
        The text after the last underscore, up to the first ``-``, selects
        the estimator; everything before the last underscore is the
        embedding method. If the string contains ``"wikidata"`` the
        embeddings are read from ``.npy`` files, otherwise from a parquet
        file. A final component equal to ``"xgb-pca"`` enables PCA.
    num_train : int
        Number of samples in the training split.
    random_state : int
        Seed for the train/test split.
    device : str
        Forwarded unchanged to ``run_param_search`` and ``assign_estimator``.

    Returns
    -------
    None
        Results are written to disk: one score CSV and, when the parameter
        search returns results, one log CSV.

    Raises
    ------
    FileNotFoundError
        If no embedded-table folder matches a ``"wikidata"`` method.
    ValueError
        If ``num_train`` leaves no samples for the test split.
    """

    marker = f"{data_name}|{method}|nt-{num_train}|rs-{random_state}"
    print(marker + " start")

    # Set paths to save results
    save_path = "./results/llm_kg_comparison_linked"
    result_save_base_path = f"{save_path}/{data_name}"
    # exist_ok=True already tolerates existing folders, and creating the
    # sub-folders also creates the base folder.
    for sub_folder in ("score", "log"):
        os.makedirs(f"{result_save_base_path}/{sub_folder}", exist_ok=True)
    results_model_path = result_save_base_path + f"/score/{marker}.csv"
    log_path = result_save_base_path + f"/log/{marker}_log.csv"

    # Set preliminaries
    data_config = load_data_config(data_name)
    task = data_config["task"]
    scoring, result_criterion = set_score_criterion(task)
    estim_method = method.split("_")[-1].split("-")[0]
    embed_method = ("_").join(method.split("_")[:-1])

    # Set cross-validation settings
    if task == "reg":
        cv = RepeatedKFold(n_splits=5, n_repeats=5, random_state=1234)
    else:
        cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=5, random_state=1234)
    # os.sched_getaffinity is not available on every platform; fall back to
    # the total CPU count where it is missing.
    if hasattr(os, "sched_getaffinity"):
        num_cpus = len(os.sched_getaffinity(0))
    else:
        num_cpus = os.cpu_count() or 1
    n_iter, n_jobs = 100, min(num_cpus, 32)

    # Load extracted embeddings
    if "wikidata" in method:
        base_folder = path_configs["emb_linked_data_folder"]
        folder_pattern = (
            f"{base_folder}/{embed_method}/embedded_tables/*/{data_name}/"
        )
        # Sort so the chosen folder does not depend on filesystem order,
        # and fail with a readable message instead of a bare IndexError.
        matching_folders = sorted(glob(folder_pattern))
        if not matching_folders:
            raise FileNotFoundError(
                f"No embedded table folder matches '{folder_pattern}'."
            )
        data_folder = matching_folders[0]
        x_path = os.path.join(data_folder, "X_emb.npy")
        y_path = os.path.join(data_folder, "y.npy")
        data_X, data_y = np.load(x_path), np.load(y_path)
        # No preprocessing time is charged on this branch.
        duration_preprocess_total = 0
    else:
        # os.path.join accepts both str and path-like config values,
        # consistent with the f-string handling of path_configs elsewhere.
        emb_data_path = os.path.join(
            path_configs["emb_linked_data_folder"],
            embed_method,
            f"{data_name}.parquet",
        )
        data = pd.read_parquet(emb_data_path)
        data_X = np.array(data.drop(columns=data_config["target"]))
        data_y = np.array(data[data_config["target"]])
        # Presumably the recorded embedding-extraction time for the whole
        # table (inferred from the folder name) -- confirm with the producer.
        time_folder = (
            f"{path_configs['results']}/embed_extraction_time_linked/{embed_method}"
        )
        time_path = f"{time_folder}/{embed_method}|{data_name}.npy"
        duration_preprocess_total = np.load(time_path)

    stratify = data_y if "clf" in task else None
    num_samples = data_X.shape[0]
    num_test = min(1024, num_samples - num_train)
    if num_test <= 0:
        raise ValueError(
            f"num_train={num_train} leaves no test samples: "
            f"'{data_name}' has only {num_samples} rows."
        )

    # Set split
    X_train, X_test, y_train, y_test = train_test_split(
        data_X,
        data_y,
        train_size=num_train,
        test_size=num_test,
        random_state=random_state,
        stratify=stratify,
    )

    # Scale the total preprocessing time by the fraction of rows used.
    duration_preprocess = (
        duration_preprocess_total
        * (X_train.shape[0] + X_test.shape[0])
        / num_samples
    )

    # PCA for large features (its run time counts as preprocessing)
    start_time = time.perf_counter()

    # Apply PCA for xgb; it is fitted on the training split only
    if method.split("_")[-1] == "xgb-pca" and X_train.shape[1] > 300:
        n_components = min(X_train.shape[0], 300)
        pca = PCA(n_components=n_components, random_state=1234)
        X_train = pca.fit_transform(X_train)
        X_test = pca.transform(X_test)

    end_time = time.perf_counter()
    duration_preprocess += round(end_time - start_time, 4)

    # Hyperparameter search
    start_time = time.perf_counter()

    best_params, cv_results = run_param_search(
        X_train,
        y_train,
        task,
        estim_method,
        cv,
        n_iter,
        n_jobs,
        scoring,
        device,
    )

    end_time = time.perf_counter()
    duration_param_search = round(end_time - start_time, 4)

    # Save cv results
    if cv_results is not None:
        cv_results.to_csv(log_path, index=False)

    # Final fit and predict (the timer covers fit, predict and scoring)
    start_time = time.perf_counter()

    estimator = assign_estimator(
        estim_method,
        task,
        device,
        train_flag=True,
        best_params_estimator=best_params,
    )
    estimator.fit(X_train, y_train)
    y_prob, y_pred = calculate_output(X_test, estimator, task)

    # Reshape prediction
    if "clf" in task:
        y_prob = reshape_pred_output(y_prob)

    # Check the output
    if task == "reg":
        y_pred = check_pred_output(y_train, y_pred)

    # Obtain scores
    score = return_score(y_test, y_prob, y_pred, task)

    end_time = time.perf_counter()
    duration_inference = round(end_time - start_time, 4)

    # Format the results: the last four criteria are the three durations and
    # their total; all earlier criteria take the scores in order.
    results_ = dict()
    for i, criterion in enumerate(result_criterion[:-4]):
        results_[criterion] = score[i]
    results_[result_criterion[-4]] = duration_preprocess
    results_[result_criterion[-3]] = duration_param_search
    results_[result_criterion[-2]] = duration_inference
    results_[result_criterion[-1]] = (
        duration_preprocess + duration_param_search + duration_inference
    )
    results_model = pd.DataFrame([results_], columns=result_criterion)
    results_model["data_name"] = data_name
    results_model["method"] = method
    results_model["num_train"] = num_train
    results_model["random_state"] = random_state
    results_model["task"] = task

    # Save the results in csv
    results_model.to_csv(results_model_path, index=False)

    print(marker + " is complete")

    return None


if __name__ == "__main__":
    # Command-line options: three required experiment settings plus an
    # optional compute device that defaults to "cpu".
    cli = argparse.ArgumentParser()
    for flag, value_type in (
        ("--data_name", str),
        ("--method", str),
        ("--num_train", int),
    ):
        cli.add_argument(flag, type=value_type, required=True)
    cli.add_argument("--device", type=str, default="cpu")
    options = cli.parse_args()

    # Repeat the experiment once per split seed, 1 through 10.
    for seed in range(1, 11):
        run_model(
            data_name=options.data_name,
            method=options.method,
            num_train=options.num_train,
            random_state=seed,
            device=options.device,
        )
