"""Common functions used for TabBack package."""

import json
import numpy as np
import pandas as pd

from glob import glob
from configs.path_configs import path_configs
from sklearn.model_selection import train_test_split


def load_raw_data(data_name):
    """Load a locally saved raw dataset together with its configuration.

    Parameters
    ----------
    data_name : str
        Dataset name. It is used as a file-name prefix when searching the
        sub-folders of ``path_configs["data"]`` and as the key into
        ``data_configs.json``.

    Returns
    -------
    tuple
        ``(data, config)``: the DataFrame read from the first matching
        parquet file and the ``data_configs.json`` entry for ``data_name``.

    Raises
    ------
    IndexError
        If no file matching ``data_name`` is found.
    KeyError
        If ``data_name`` has no entry in ``data_configs.json``.
    """

    # Path of the folder containing data
    data_folder = path_configs["data"]

    # Dataset: first file whose name starts with ``data_name`` in any sub-folder
    data_path = glob(f"{data_folder}/*/{data_name}*")[0]
    data = pd.read_parquet(data_path)
    # Represent missing entries uniformly as NaN
    data.fillna(value=np.nan, inplace=True)

    # Strip line breaks / tabs / carriage returns embedded in cell values
    for pattern, replacement in (("\n", " "), ("\t", " "), ("\r", "")):
        data = data.replace(pattern, replacement, regex=True)

    # Configs: the context manager closes the file even if parsing fails
    config_path = f"{data_folder}/data_configs.json"
    with open(config_path, encoding="utf-8") as config_file:
        data_config = json.load(config_file)

    return data, data_config[data_name]


def load_data_config(data_name):
    """Load the configuration entry of a single dataset.

    Parameters
    ----------
    data_name : str
        Key of the dataset inside ``data_configs.json``.

    Returns
    -------
    object
        The value stored under ``data_name`` in ``data_configs.json``
        (located in the folder given by ``path_configs["data"]``).

    Raises
    ------
    KeyError
        If ``data_name`` has no entry in the configuration file.
    """

    # Path of the folder containing data
    data_folder = path_configs["data"]

    # Configs: the context manager closes the file even if parsing fails
    config_path = f"{data_folder}/data_configs.json"
    with open(config_path, encoding="utf-8") as config_file:
        data_config = json.load(config_file)

    return data_config[data_name]


def load_linked_ent_data(data_name):
    """Load a locally saved linked-entity dataset and its configuration.

    The data file is searched in the sub-folders of
    ``path_configs["linked_ent_data_folder"]``, while the configuration is
    read from ``data_configs.json`` in ``path_configs["data"]``.

    Parameters
    ----------
    data_name : str
        Dataset name. It is used as a file-name prefix for the parquet file
        and as the key into ``data_configs.json``.

    Returns
    -------
    tuple
        ``(data, config)``: the DataFrame read from the first matching
        parquet file and the configuration entry for ``data_name``.

    Raises
    ------
    IndexError
        If no file matching ``data_name`` is found.
    KeyError
        If ``data_name`` has no entry in ``data_configs.json``.
    """

    # Folder holding the configs and folder holding the linked-entity data
    data_folder_config = path_configs["data"]
    data_folder = path_configs["linked_ent_data_folder"]

    # Dataset: first file whose name starts with ``data_name`` in any sub-folder
    data_path = glob(f"{data_folder}/*/{data_name}*")[0]
    data = pd.read_parquet(data_path)
    # Represent missing entries uniformly as NaN
    data.fillna(value=np.nan, inplace=True)

    # Strip line breaks / tabs / carriage returns embedded in cell values
    for pattern, replacement in (("\n", " "), ("\t", " "), ("\r", "")):
        data = data.replace(pattern, replacement, regex=True)

    # Configs: the context manager closes the file even if parsing fails
    config_path = f"{data_folder_config}/data_configs.json"
    with open(config_path, encoding="utf-8") as config_file:
        data_config = json.load(config_file)

    return data, data_config[data_name]


def col_names_per_type(data, target_name):
    """Extract the feature column names grouped by type.

    The target column is excluded. A column counts as datetime when
    ``skrub.to_datetime`` turns it into a datetime64 dtype. The remaining
    columns are split into categorical (``object`` dtype) and numerical
    (everything else). Newlines in categorical/numerical column names are
    replaced by spaces in the returned names.

    Parameters
    ----------
    data : pandas.DataFrame
        Table containing the features and the target.
    target_name : str or list of str
        Column(s) to drop before inspecting the features.

    Returns
    -------
    tuple of list
        ``(num_col_names, cat_col_names, dat_col_names)``. Names are returned
        in the column order of ``data`` so the result is identical across
        runs (a plain ``set`` difference would depend on string hashing).
    """

    # Imported lazily so the module can be used without skrub installed
    from skrub import to_datetime

    # Preprocess for Datetime information
    data_ = data.drop(columns=target_name)
    dat_col_names = [
        col
        for col in data_
        if pd.api.types.is_datetime64_any_dtype(to_datetime(data_[col]))
    ]
    datetime_cols = set(dat_col_names)

    def _names_without_datetime(columns):
        # Same name clean-up as before; dict.fromkeys de-duplicates while
        # keeping the original column order.
        cleaned = columns.str.replace("\n", " ", regex=True)
        return [name for name in dict.fromkeys(cleaned) if name not in datetime_cols]

    cat_col_names = _names_without_datetime(
        data_.select_dtypes(include="object").columns
    )
    num_col_names = _names_without_datetime(
        data_.select_dtypes(exclude="object").columns
    )
    return num_col_names, cat_col_names, dat_col_names


def set_split(data, data_config, num_train, random_state, extracted_emb=False):
    """Split the data into train and test parts.

    The test part holds at most 1024 rows (fewer when the data left after
    removing ``num_train`` rows is smaller). When ``extracted_emb`` is false
    only the categorical columns are kept as features; otherwise every
    column except the target is used. The split is stratified on the target
    whenever the configured task name contains ``"clf"``.

    Returns ``(X_train, X_test, y_train, y_test)``.
    """

    # Preliminary settings from the dataset configuration
    label_col = data_config["target"]
    task_type = data_config["task"]

    # Size of the held-out part
    n_holdout = min(1024, data.shape[0] - num_train)

    # Feature table
    if not extracted_emb:
        categorical_cols = col_names_per_type(data, label_col)[1]
        features = data[categorical_cols].copy()
    else:
        features = data.drop(columns=label_col)

    # Target as a numpy array; also used for stratification in classification
    labels = np.array(data[label_col].copy())
    strata = labels if "clf" in task_type else None

    # Split
    feat_train, feat_test, lab_train, lab_test = train_test_split(
        features,
        labels,
        train_size=num_train,
        test_size=n_holdout,
        random_state=random_state,
        stratify=strata,
    )
    return feat_train, feat_test, lab_train, lab_test


def shorten_param(param_name):
    """Return the part of ``param_name`` after its last ``"__"``.

    Names without ``"__"`` are returned unchanged.
    """

    # rpartition yields ("", "", name) when the separator is absent, so the
    # last element is always the wanted suffix.
    return param_name.rpartition("__")[2]


def check_pred_output(y_train, y_pred):
    """Replace NaN predictions by the mean of the training targets.

    ``y_pred`` is modified in place and returned.
    """

    missing = np.isnan(y_pred)
    if missing.any():
        y_pred[missing] = np.mean(y_train)
    return y_pred


def reshape_pred_output(y_pred):
    """Bring the predictive output into its final shape.

    * ``(n, 2)`` input: the second column is returned.
    * ``(n, 1)`` input: flattened to one dimension.
    * any other multi-dimensional input: softmax along axis 1.
    * one-dimensional input: returned unchanged.
    """

    from scipy.special import softmax

    n_rows = len(y_pred)
    dims = y_pred.shape

    if dims == (n_rows, 2):
        return y_pred[:, 1]
    if dims == (n_rows, 1):
        return y_pred.ravel()
    if len(dims) > 1:
        return softmax(y_pred, axis=1)
    return y_pred


def set_score_criterion(task):
    """Return the CV scoring name and the list of reported criteria.

    Regression (``task == "reg"``) is scored with ``"r2"``; multi-class
    classification (``"m-clf"``) with ``"roc_auc_ovr"``; every other task
    with ``"roc_auc"``. The criteria list always ends with the four timing
    entries.
    """

    timing_fields = [
        "preprocess_time",
        "param_search_time",
        "inference_time",
        "run_time",
    ]

    if task == "reg":
        return "r2", ["r2", "rmse"] + timing_fields

    scoring = "roc_auc_ovr" if task == "m-clf" else "roc_auc"
    metrics = ["roc_auc", "brier_score_loss", "f1_weighted"]
    return scoring, metrics + timing_fields


def return_score(y_target, y_prob, y_pred, task):
    """Compute the score results for the given task.

    Regression (``task == "reg"``) returns ``(r2, rmse)`` computed from
    ``y_pred``. Any other task returns ``(roc_auc, brier, f1_weighted)``;
    the AUC uses macro-averaged one-vs-rest when ``y_target`` holds more
    than two distinct values.
    """

    from sklearn.metrics import (
        r2_score,
        root_mean_squared_error,
        roc_auc_score,
        brier_score_loss,
        f1_score,
    )

    # Regression: only point predictions are needed
    if task == "reg":
        return r2_score(y_target, y_pred), root_mean_squared_error(y_target, y_pred)

    # Classification: switch AUC to one-vs-rest for more than two classes
    is_multiclass = len(np.unique(y_target)) > 2
    auc_options = {"multi_class": "ovr", "average": "macro"} if is_multiclass else {}
    auc = roc_auc_score(y_target, y_prob, **auc_options)
    brier = brier_score_loss(y_target, y_prob)
    f1_weighted = f1_score(y_target, y_pred, average="weighted")
    return auc, brier, f1_weighted


def calculate_output(X_test, estimator, task, batch_size=8192):
    """Compute the predictive output of a fitted estimator.

    The estimator is identified by its class name:

    * regression (``task == "reg"``): ``y_prob`` is ``None`` and ``y_pred``
      holds ``estimator.predict`` (batched for ``TabPFNRegressor``);
    * ``RidgeClassifierCV``: ``y_prob`` is the sigmoid of the decision
      function, ``y_pred`` the predicted labels;
    * ``TabPFNClassifier``: ``y_prob`` is ``None`` and ``y_pred`` holds the
      (batched) ``predict_proba`` output;
      NOTE(review): probabilities are returned in the ``y_pred`` slot here;
      presumably the caller post-processes them - confirm before changing.
    * any other classifier: ``predict_proba`` and ``predict``.

    Parameters
    ----------
    X_test : array-like
        Test features; must support ``len`` and row slicing for TabPFN.
    estimator : object
        Fitted estimator.
    task : str
        ``"reg"`` for regression, anything else for classification.
    batch_size : int, default=8192
        Maximum number of rows passed to a TabPFN estimator at once.

    Returns
    -------
    tuple
        ``(y_prob, y_pred)``.
    """

    est_name = estimator.__class__.__name__

    if task == "reg":
        # No probabilities exist for regression, whatever the estimator
        y_prob = None
        if est_name == "TabPFNRegressor":
            y_pred = _calculate_tabpfn_output(estimator, X_test, batch_size, task)
        else:
            y_pred = estimator.predict(X_test)
    elif est_name == "RidgeClassifierCV":
        # RidgeClassifierCV has no predict_proba; squash the decision values
        y_prob = estimator.decision_function(X_test)
        y_prob = 1 / (1 + np.exp(-y_prob))
        y_pred = estimator.predict(X_test)
    elif est_name == "TabPFNClassifier":
        y_pred = _calculate_tabpfn_output(estimator, X_test, batch_size, task)
        y_prob = None
    else:
        y_prob = estimator.predict_proba(X_test)
        y_pred = estimator.predict(X_test)

    return y_prob, y_pred


def _calculate_tabpfn_output(estimator, X_test, batch_size, task):
    """Calculate the TabPFN output, batching large test sets.

    Uses ``estimator.predict`` for regression and ``estimator.predict_proba``
    otherwise. Inputs with at most ``batch_size`` rows are predicted in one
    call; larger inputs are predicted in consecutive chunks of ``batch_size``
    rows (the last chunk may be shorter, never empty) and the results are
    concatenated along the first axis, whatever the number of classes.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """

    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    predict_fn = estimator.predict if task == "reg" else estimator.predict_proba

    test_size = len(X_test)
    if test_size <= batch_size:
        return predict_fn(X_test)

    # Stepping by batch_size never yields an empty trailing chunk
    chunks = [
        predict_fn(X_test[start : start + batch_size])
        for start in range(0, test_size, batch_size)
    ]
    return np.concatenate(chunks, axis=0)


def assign_estimator(
    estim_method,
    task,
    device,
    train_flag,
    best_params_estimator=None,
):
    """Build the (unfitted) estimator for the requested method and task.

    Parameters
    ----------
    estim_method : str
        One of ``"xgb"``, ``"histgb"``, ``"randomforest"``, ``"ridge"``,
        ``"tabpfn"``, ``"realmlp"`` or ``"tabm"``.
    task : str
        ``"reg"`` selects the regressor; any other value the classifier.
    device : str
        Passed as ``device`` to the TabPFN, RealMLP and TabM estimators.
    train_flag : bool
        Currently unused; kept so existing callers keep working.
    best_params_estimator : dict, optional
        Extra keyword arguments for the estimator constructor. Ignored for
        ``"ridge"`` and ``"tabpfn"``, which only receive their fixed
        parameters. Defaults to no extra parameters.

    Returns
    -------
    object
        The instantiated estimator.

    Raises
    ------
    ValueError
        If ``estim_method`` is not one of the supported names.
    """

    # ``None`` default instead of a shared mutable ``{}``; copy so the
    # caller's mapping is never touched.
    tuned_params = {} if best_params_estimator is None else dict(best_params_estimator)
    is_regression = task == "reg"

    # Estimator libraries are imported lazily inside each branch so that only
    # the package of the selected method has to be installed.
    if estim_method == "xgb":

        from xgboost import XGBRegressor, XGBClassifier

        estimator_cls = XGBRegressor if is_regression else XGBClassifier
        return estimator_cls(**tuned_params)

    if estim_method == "histgb":

        from sklearn.ensemble import (
            HistGradientBoostingRegressor,
            HistGradientBoostingClassifier,
        )

        fixed_params = {"early_stopping": True, "n_iter_no_change": 50}
        estimator_cls = (
            HistGradientBoostingRegressor
            if is_regression
            else HistGradientBoostingClassifier
        )
        return estimator_cls(**fixed_params, **tuned_params)

    if estim_method == "randomforest":

        from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier

        estimator_cls = RandomForestRegressor if is_regression else RandomForestClassifier
        return estimator_cls(**tuned_params)

    if estim_method == "ridge":

        from sklearn.linear_model import RidgeCV, RidgeClassifierCV

        fixed_params = {"alphas": [1e-2, 1e-1, 1, 10, 100]}
        estimator_cls = RidgeCV if is_regression else RidgeClassifierCV
        return estimator_cls(**fixed_params)

    if estim_method == "tabpfn":

        from tabpfn import TabPFNRegressor, TabPFNClassifier

        fixed_params = {"device": device}
        estimator_cls = TabPFNRegressor if is_regression else TabPFNClassifier
        return estimator_cls(**fixed_params)

    if estim_method in ("realmlp", "tabm"):

        import uuid

        if estim_method == "realmlp":
            from pytabkit import RealMLP_HPO_Regressor, RealMLP_HPO_Classifier

            regressor_cls = RealMLP_HPO_Regressor
            classifier_cls = RealMLP_HPO_Classifier
        else:
            from pytabkit import TabM_HPO_Regressor, TabM_HPO_Classifier

            regressor_cls = TabM_HPO_Regressor
            classifier_cls = TabM_HPO_Classifier

        fixed_params = {
            "device": device,
            "n_cv": 5,
            "n_repeats": 3,
            "n_hyperopt_steps": 50,
            "random_state": 20122024,
            # A fresh folder per estimator instance
            "tmp_folder": "./pytabkit/" + str(uuid.uuid4()),
        }
        if is_regression:
            return regressor_cls(**fixed_params, **tuned_params)
        fixed_params["val_metric_name"] = "1-auc_ovr"
        return classifier_cls(**fixed_params, **tuned_params)

    # Previously an unknown name fell through to ``return estimator_`` and
    # surfaced as an opaque UnboundLocalError.
    raise ValueError(
        f"Unknown estim_method {estim_method!r}; expected one of 'xgb', "
        "'histgb', 'randomforest', 'ridge', 'tabpfn', 'realmlp', 'tabm'."
    )


def _clean_entity_names(data_entity_name, lowercase=False):
    """Clean the strings of a Series and of its index.

    Removes ``<``, ``>`` and newlines and turns underscores into spaces, in
    both the values and the index labels. With ``lowercase=True`` values and
    labels are lower-cased as well. A new Series is returned.
    """

    def _scrub(str_accessor):
        # Same replacement chain for values and index labels
        result = str_accessor.replace("<", "")
        for old, new in ((">", ""), ("\n", ""), ("_", " ")):
            result = result.str.replace(old, new)
        return result

    cleaned = _scrub(data_entity_name.str)
    cleaned.index = _scrub(cleaned.index.str)

    if not lowercase:
        return cleaned

    lowered = cleaned.str.lower()
    lowered.index = lowered.index.str.lower()
    return lowered
