#%%
from tree_depth import run
from data_loading import get_dataset, load_data_synthetic
from model_training import train_all_models
from constants import y_test_parameters, y_train_parameters
from joblib.parallel import Parallel, delayed
from tqdm import tqdm
from sklearn.calibration import calibration_curve, CalibratedClassifierCV
from sklearn.frozen import FrozenEstimator
from glest.helpers import bins_from_strategy
from sklearn.model_selection import train_test_split
import numpy as np
from glest.core import GLEstimator, GLEstimatorResiduals, Partitioner
from scipy.special import expit
from sklearn.tree import DecisionTreeRegressor
import os
import pandas as pd
#%%


class LogisticSigmoid:
    """Fixed-weight logistic model exposing a ``predict_proba`` method.

    The model is not trained: the weight vector is supplied at construction
    time and used as-is.
    """

    def __init__(self, w):
        """Store the weight vector ``w`` used to compute the linear score."""
        self.w = w

    def predict_proba(self, X):
        """Return ``expit(X . w)``, the sigmoid of the linear score.

        The result has whatever shape ``np.dot(X, self.w)`` produces; no
        second "negative class" column is added.
        """
        linear_score = np.dot(X, self.w)
        return expit(linear_score)
    

def _save_results(results, depth, seed, n_train, true_gl, out_dir):
    """Tag ``results`` with the run configuration and append it to a CSV file.

    The ``config.*`` and ``true_gl`` entries are added to ``results`` in place.
    The file name is built from the values of every key starting with
    ``"config"`` and the row is appended (no header) to
    ``{out_dir}{filename}.csv``.

    ``out_dir`` must end with a path separator because the file name is
    appended by plain string concatenation.
    """
    results["config.depth"] = depth
    results["config.seed"] = seed
    results["config.n_samples"] = n_train
    results["config.simulated"] = True
    results["true_gl"] = true_gl

    filename = '_'.join(str(v) for k, v in results.items() if k.startswith("config"))

    # exist_ok=True: this function runs concurrently in many worker processes
    # that share output directories; a separate "exists?" check followed by
    # makedirs can raise FileExistsError when two workers race.
    os.makedirs(out_dir, exist_ok=True)
    pd.DataFrame(results, index=[0]).to_csv(f"{out_dir}{filename}.csv", mode='a', header=False)


def run_synthetic_experiment(n_samples=20000, n_dimensions=10, n_directions=3, ratio=2, var=1, alpha=1, depth=40, seed=42):
    """Run one synthetic experiment and append its metrics to CSV files.

    Two estimators are evaluated and each writes one row under
    ``tree_depth/fully_synthetic_<n_dimensions>_<n_directions>_<ratio>_<var>_<alpha>/``:

    * ``y_residuals/``: ``GLEstimatorResiduals`` fitted on the test half, with
      the partition given by the leaves of a decision tree trained on the
      training-half residuals ``y_train - H_train``.
    * ``y_perez/``: ``GLEstimator`` with a decision-tree ``Partitioner``.

    Parameters
    ----------
    n_samples, n_dimensions, n_directions, ratio, var, alpha
        Forwarded unchanged to ``load_data_synthetic``; their meaning is
        defined there.
    depth : int
        ``max_depth`` of the decision trees used for partitioning.
    seed : int
        ``random_state`` of the 50/50 train/test split.
        NOTE(review): the seed is not passed to ``load_data_synthetic``, so
        whether data generation itself is seeded cannot be told from here.

    Returns
    -------
    dict
        ``{"completed": True, "n_samples": ..., "depth": ..., "seed": ...}``.
        Deliberately small so that it is cheap to send back from a worker
        process.
    """
    X, Y, H, F, w = load_data_synthetic(n_samples, n_dimensions, n_directions, ratio, var, alpha)

    model = LogisticSigmoid(w)

    X_train, X_test, y_train, y_test, H_train, H_test, F_train, F_test = train_test_split(
        X, Y, H, F, test_size=0.5, random_state=seed
    )

    # Reference value stored next to the estimates: twice the mean squared
    # difference between F and H on the test half.
    true_gl = 2 * np.mean(np.square(F_test - H_test))

    n_train = len(X_train)
    base_dir = f"tree_depth/fully_synthetic_{n_dimensions}_{n_directions}_{ratio}_{var}_{alpha}/"

    # Residuals approach: the partition comes from a tree fitted on the
    # training-half residuals, then applied to the test half.
    c_hat_train = H_train
    c_hat_test = H_test
    residuals_train = y_train - c_hat_train
    dt = DecisionTreeRegressor(max_depth=depth, min_samples_leaf=10)
    dt.fit(X_train, residuals_train)
    leaf_ids = dt.apply(X_test)

    glest = GLEstimatorResiduals(model, None)
    glest.fit(X_test, y_test, y_scores_cal=c_hat_test, partition=leaf_ids)

    results_residuals = glest.metrics()
    _save_results(results_residuals, depth, seed, n_train, true_gl, f"{base_dir}y_residuals/")

    # Direct approach: the estimator builds its own partition through the
    # Partitioner.
    glest2 = GLEstimator(model, partitioner=Partitioner(
        estimator=DecisionTreeRegressor(max_depth=depth, min_samples_leaf=10),
        predict_method='apply'
    ))

    # NOTE(review): this fits on the full (X, Y), whereas the residuals
    # approach above uses only the test half and the recorded
    # config.n_samples is len(X_train). Confirm this is intended.
    glest2.fit(X, Y)

    results = glest2.metrics()
    _save_results(results, depth, seed, n_train, true_gl, f"{base_dir}y_perez/")

    # Return simple value to avoid serialization issues
    return {"completed": True, "n_samples": n_samples, "depth": depth, "seed": seed}

# %%

# Run experiments with different parameters
n_samples_list = [1600000, 3200000]
depth_list = list(range(1, 21))  # 1 to 20
seeds = [42, 123, 456, 789]  # 4 different seeds

# Number of worker processes. This is a fixed count, not derived from the
# machine's core count; adjust it to the host the script runs on.
n_jobs = 40

# One (n_samples, depth, seed) tuple per experiment: the full grid, with seed
# varying fastest and n_samples slowest.
param_list = [
    (n_samples, depth, seed)
    for n_samples in n_samples_list
    for depth in depth_list
    for seed in seeds
]
total_experiments = len(param_list)

# The launch is guarded so that worker processes which re-import this module
# (multiprocessing "spawn"/"forkserver" start methods) do not start the whole
# experiment grid again themselves.
if __name__ == "__main__":
    # tqdm wraps the task iterator, so the bar advances as tasks are handed to
    # joblib, not as they finish.
    results = Parallel(n_jobs=n_jobs, verbose=0, backend="multiprocessing")(
        delayed(run_synthetic_experiment)(
            n_samples=n_samples,
            depth=depth,
            seed=seed
        ) for n_samples, depth, seed in tqdm(param_list, total=total_experiments)
    )


# %%
