#%%
from tree_depth import run
from data_loading import get_dataset
from model_training import train_all_models
from constants import y_test_parameters, y_train_parameters
from joblib.parallel import Parallel, delayed
from tqdm import tqdm
#%%

# Name of the dataset to load; it is passed straight to get_dataset().
dataset_name = "sanity-check"
# get_dataset() returns four values, unpacked here in order
# (presumably train/test features and targets -- confirm in data_loading).
X_train, y_train, X_test, y_test = get_dataset(dataset_name)

#%%
# Every model role is fitted with the same estimator type.
model_types = dict.fromkeys(
    ("f", "f_star_train", "f_star_test"),
    "hist_gradient_boosting",
)

# Per-role parameter dicts: "f" gets an empty dict (no overrides passed),
# while the two f_star models use the stored "HGB" parameter sets.
# NOTE(review): the parameter sets are looked up under "weather" although
# dataset_name is "sanity-check" -- confirm this is intentional.
model_params = dict(
    f={},
    f_star_train=y_train_parameters["weather"]["HGB"],
    f_star_test=y_test_parameters["weather"]["HGB"],
)

# Train all three models; train_all_models also hands back the evaluation
# data (X_eval, y_eval) used by the experiment cells below.
(f, f_star_train, f_star_test, X_eval, y_eval) = train_all_models(
    model_types,
    X_train,
    y_train,
    X_test,
    y_test,
    model_params,
)

#%%

# Whether the values used for computation are simulated.
# Forwarded to run() as the `sim_y` keyword argument.
sim_y = True

# Whether the trees are trained on the residuals.
# Forwarded to run() as the `train_residuals` keyword argument.
train_residuals = True

# Example of a single experiment run (seed 42, 1,500,000 samples, depth 20),
# kept for reference; the parallel sweep in the next cell covers the full grid.
# run(X_eval, y_eval, f.predict_proba(X_eval)[:,1], f_star_test, 42, f, 1500000, 20, train_residuals=train_residuals, sim_y=sim_y)
# %%

# Column 1 of f's predicted probabilities on the evaluation set
# (presumably the positive-class probability -- confirm class ordering).
# Computed once and shared by every run below.
h = f.predict_proba(X_eval)[:, 1]

depths = range(1, 41)
seeds = range(4)

samples = [1000000, 1600000, 2000000]

# Every (depth, seed) combination, depth-major. It does not depend on the
# sample size, so it is built once instead of once per loop iteration.
# Kept as a list so tqdm knows the total number of tasks.
depth_seed_grid = [(d, s) for d in depths for s in seeds]

# Results of each sweep keyed by sample size. Previously `results` was
# overwritten on every iteration, so only the last sample size's results
# survived the loop. `results` is still bound (to the last sweep) for any
# later cell that relies on it.
results_by_n_samples = {}
for n_samples in samples:
    print(f"Running for {n_samples} samples")
    # One run() call per (depth, seed) pair, spread over 30 worker jobs.
    # NOTE: tqdm wraps the task generator, so the bar advances as tasks are
    # dispatched to workers, not as they complete.
    results = Parallel(n_jobs=30)(
        delayed(run)(
            X_eval,
            y_eval,
            h,
            f_star_test,
            seed,
            f,
            n_samples,
            depth,
            train_residuals=train_residuals,
            sim_y=sim_y,
        )
        for depth, seed in tqdm(depth_seed_grid, desc="Processing")
    )
    results_by_n_samples[n_samples] = results
# %%
