#%%
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingClassifier
from joblib.memory import Memory

# On-disk joblib cache backing the ``@memory.cache`` decorator used below.
# ``location="joblib"`` is a relative path, so the cache directory lives
# relative to the current working directory of the process; ``verbose=10``
# makes joblib print its debug messages whenever a cached function is evaluated.
memory = Memory(location="joblib", verbose=10)

#%%


def train_models(model_type: str, X_train, y_train, model_params=None):
    """
    Build a single scikit-learn classifier and fit it on the given data.

    Args:
        model_type (str): Which estimator to build. One of
            ``"decision_tree"``, ``"random_forest"`` or
            ``"hist_gradient_boosting"``.
        X_train: Training features, passed unchanged to ``model.fit``.
        y_train: Training labels, passed unchanged to ``model.fit``.
        model_params (dict[str, Any] | None): Keyword arguments forwarded to
            the estimator constructor. ``None`` (the default) means "use the
            estimator's own defaults".

    Returns:
        The fitted estimator instance.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    supported_types = ("decision_tree", "random_forest", "hist_gradient_boosting")
    # Validate first so an unknown name fails with a clear ValueError rather
    # than a KeyError from the lookup table below.
    if model_type not in supported_types:
        raise ValueError(f"Unsupported model type: {model_type}")

    estimator_classes = {
        "decision_tree": DecisionTreeClassifier,
        "random_forest": RandomForestClassifier,
        "hist_gradient_boosting": HistGradientBoostingClassifier,
    }
    # ``**None`` raises TypeError, so treat a missing params dict as "no overrides".
    params = {} if model_params is None else model_params

    model = estimator_classes[model_type](**params)
    model.fit(X_train, y_train)
    return model


@memory.cache
def train_all_models(
    model_types,
    X_train,
    y_train,
    X_test,
    y_test,
    model_params,
    *,
    test_size=0.5,
    random_state=42,
):
    """
    Fit the models ``f``, ``f_star_train`` and ``f_star_test`` and hold out an eval set.

    The training data is split into two disjoint parts: the larger-or-equal
    part (``1 - test_size``) fits ``f`` and the other part (``test_size``)
    fits ``f_star_train``. The test data is split the same way: the
    ``1 - test_size`` part fits ``f_star_test`` and the ``test_size`` part is
    returned untouched as the evaluation set.

    Results are memoised on disk via ``@memory.cache``.
    NOTE(review): joblib keys the cache on this function's own code and
    arguments; edits to ``train_models`` are presumably not detected, so clear
    the cache manually after changing it.

    Args:
        model_types (dict): Maps ``"f"``, ``"f_star_train"`` and
            ``"f_star_test"`` to a model type name accepted by ``train_models``.
        X_train, y_train: Training features and labels.
        X_test, y_test: Test features and labels.
        model_params (dict): Same keys as ``model_types``, each mapping to the
            constructor keyword arguments for that model.
        test_size (float): Fraction passed to both ``train_test_split`` calls
            (the share used for ``f_star_train`` and for the eval set).
            Defaults to 0.5.
        random_state: Seed passed to both ``train_test_split`` calls.
            Defaults to 42.

    Returns:
        tuple: ``(f, f_star_train, f_star_test, X_test_eval, y_test_eval)``.
    """
    # Split the training data: one part for f, the other for f_star_train.
    X_train_f, X_train_star, y_train_f, y_train_star = train_test_split(
        X_train, y_train, test_size=test_size, random_state=random_state
    )

    # Fit f_star_train before f to keep the original fitting order (it matters
    # when an estimator is left without its own random_state).
    f_star_train = train_models(
        model_types["f_star_train"],
        X_train_star,
        y_train_star,
        model_params["f_star_train"],
    )
    f = train_models(model_types["f"], X_train_f, y_train_f, model_params["f"])

    # Split the test data: one part fits f_star_test, the rest is the eval set.
    X_test_f_star, X_test_eval, y_test_f_star, y_test_eval = train_test_split(
        X_test, y_test, test_size=test_size, random_state=random_state
    )
    f_star_test = train_models(
        model_types["f_star_test"],
        X_test_f_star,
        y_test_f_star,
        model_params["f_star_test"],
    )

    return f, f_star_train, f_star_test, X_test_eval, y_test_eval
