#%%
from tree_depth import run
from data_loading import get_dataset
from model_training import train_all_models
from constants import y_test_parameters, y_train_parameters
from joblib.parallel import Parallel, delayed
from tqdm import tqdm
#%%

# Identifier of the dataset handed to get_dataset.
dataset_name = "sanity-check2d"

# The loader's result is unpacked, in order, into the train features/targets
# and the test features/targets used by the cells below.
splits = get_dataset(dataset_name)
X_train, y_train, X_test, y_test = splits

#%%
# Every model slot is configured with the same model-type string.
model_types = dict.fromkeys(
    ("f", "f_star_train", "f_star_test"),
    "hist_gradient_boosting",
)

# f gets an empty parameter dict; the two f_star models take the "HGB"
# entries stored under the "weather" key of the imported parameter tables.
# NOTE(review): the dataset loaded above is "sanity-check2d" while the
# parameters come from the "weather" entry - confirm this is intended.
hgb_train_params = y_train_parameters["weather"]["HGB"]
hgb_test_params = y_test_parameters["weather"]["HGB"]
model_params = dict(
    f={},
    f_star_train=hgb_train_params,
    f_star_test=hgb_test_params,
)

# train_all_models returns the three models followed by the evaluation
# features and targets, unpacked here in that order.
trained = train_all_models(
    model_types, X_train, y_train, X_test, y_test, model_params
)
f, f_star_train, f_star_test, X_eval, y_eval = trained

#%%

# Experiment switches; both are forwarded to run() as keyword arguments.

# Whether the trees get trained on the residuals.
train_residuals = False

# Whether the values used for computation need to be simulated.
sim_y = True

# Reference only: a single, non-parallel invocation of the experiment.
# run(X_eval, y_eval, f.predict_proba(X_eval)[:,1], f_star_test, 42, f, 1500000, 20, train_residuals=train_residuals, sim_y=sim_y)
# %%

# Column 1 of f's predicted probabilities on the evaluation set; computed
# once and shared by every job below.
h = f.predict_proba(X_eval)[:, 1]

# NOTE(review): `depths` is not read anywhere in this cell; the jobs below
# pass a literal None in the eighth positional slot of run() - presumably
# the depth argument, confirm against tree_depth.run.
depths = [None]
seeds = range(40)

# Sample sizes swept in this run. Grids used previously are kept for reference:
# samples = [5000, 6000, 7000, 9000, 10000, 15000, 20000, 50000, 100000, 200000, 300000]
# samples = [500000, 1000000, 1500000]
# samples = [50100, 50500, 51000, 52000, 53000, 55000, 57500, 60000, 70000, 80000, 100000]
# samples = [525000, 530000]
# samples = [500000]
samples = [4100]

# One job per (sample size, seed) pair, with the sample size varying slowest.
# The grid is materialised as a list so tqdm knows the total job count.
grid = [(n, s) for n in samples for s in seeds]

run_job = delayed(run)
executor = Parallel(n_jobs=30)

# The progress bar wraps the job generator, so it advances as jobs are
# pulled from the grid by the executor.
jobs = (
    run_job(
        X_eval,
        y_eval,
        h,
        f_star_test,
        seed,
        f,
        n_samples,
        None,
        train_residuals=train_residuals,
        sim_y=sim_y,
    )
    for n_samples, seed in tqdm(grid, desc="Processing")
)

# One entry per grid pair, in grid order.
results = executor(jobs)
    
# %%
