#%%
from tree_depth import run
from data_loading import get_dataset
from model_training import train_all_models
from constants import y_test_parameters, y_train_parameters
from joblib.parallel import Parallel, delayed
from tqdm import tqdm
import itertools
#%%

# Name of the dataset to load; it is also used below to look up the
# per-dataset HGB parameters in y_train_parameters / y_test_parameters.
dataset_name = "weather"
(X_train, y_train,
 X_test, y_test) = get_dataset(dataset_name)

#%%
# Every model (f, f_star_train, f_star_test) uses the same estimator family.
model_types = dict.fromkeys(
    ("f", "f_star_train", "f_star_test"),
    "hist_gradient_boosting",
)

# f uses default hyperparameters; the f_star models use the per-dataset
# HGB settings looked up from constants.
model_params = {
    "f": {},
    "f_star_train": y_train_parameters[dataset_name]["HGB"],
    "f_star_test": y_test_parameters[dataset_name]["HGB"],
}


f, f_star_train, f_star_test, X_eval, y_eval = train_all_models(
    model_types,
    X_train,
    y_train,
    X_test,
    y_test,
    model_params,
)

#%%

# Whether the values used in the computation should be simulated
# (forwarded to run() as sim_y).
sim_y = True

# Whether the trees are trained on the residuals
# (forwarded to run() as train_residuals).
train_residuals = True

# Settings for the single large run. These occupy the same positional
# slots of run() as seed / n / depth in the sweep cell below.
single_run_seed = 42
single_run_n = 2_000_000
single_run_depth = 20

# Run the experiment.
# NOTE(review): predict_proba(...)[:, 1] assumes f is a binary classifier and
# takes the second-class column — confirm this matches run()'s expectations.
run(
    X_eval,
    y_eval,
    f.predict_proba(X_eval)[:, 1],
    f_star_test,
    single_run_seed,
    f,
    single_run_n,
    single_run_depth,
    train_residuals=train_residuals,
    sim_y=sim_y,
)
# %%

# Predictions of f on the evaluation set (second predict_proba column),
# computed once and shared by every sweep job.
h = f.predict_proba(X_eval)[:, 1]
# %%
# Sweep every (seed, depth) combination in parallel at a fixed sample size.
depths = range(1, 21)
seeds = range(10)
sweep_n = 100000  # single source of truth for both the run() arg and the tqdm label

results = Parallel(n_jobs=10)(
    delayed(run)(
        X_eval,
        y_eval,
        h,
        f_star_test,
        seed,
        f,
        sweep_n,
        depth,
        train_residuals=train_residuals,
        sim_y=sim_y,
    )
    # itertools.product has no len(), so pass total= explicitly; otherwise
    # tqdm cannot show a proper bar or ETA. Note that tqdm here tracks task
    # *dispatch* to the pool, not task completion.
    for seed, depth in tqdm(
        itertools.product(seeds, depths),
        total=len(seeds) * len(depths),
        desc=f"Processing n={sweep_n}",
    )
)

# %%
