#%%
from data_loading import load_data, load_data_fair, get_dataset
from model_training import train_models, train_all_models
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from utils import gl_true, sampling
from glest.core import GLEstimator
from data_generation import generate_y

#%%

# Per-dataset "HGB" settings, looked up by dataset name further down and
# passed as the parameters of the "f_star_train" model.
# Each row is expanded into {"HGB": {<setting name>: <value>, ...}}.
y_train_parameters = {
    dataset: {
        "HGB": dict(
            zip(
                ("learning_rate", "max_depth", "max_iter", "min_samples_leaf"),
                settings,
            )
        )
    }
    for dataset, settings in (
        # dataset, (learning_rate, max_depth, max_iter, min_samples_leaf)
        ("weather", (0.1, None, 200, 5)),
        ("delivery-eta", (0.3, None, 200, 20)),
        ("cooking-time", (0.3, 5, 200, 20)),
        ("maps-routing", (0.01, 3, 100, 10)),
    )
}

# Per-dataset "HGB" settings, looked up by dataset name further down and
# passed as the parameters of the "f_star_test" model.
# Each row is expanded into {"HGB": {<setting name>: <value>, ...}}.
y_test_parameters = {
    dataset: {
        "HGB": dict(
            zip(
                ("learning_rate", "max_depth", "max_iter", "min_samples_leaf"),
                settings,
            )
        )
    }
    for dataset, settings in (
        # dataset, (learning_rate, max_depth, max_iter, min_samples_leaf)
        ("weather", (0.1, None, 200, 20)),
        ("delivery-eta", (0.3, 5, 200, 20)),
        ("cooking-time", (0.3, 5, 200, 20)),
        ("maps-routing", (0.1, 10, 100, 20)),
    )
}

# %%
# model_type = "decision_tree"
# model_params = {
#     "max_depth": 5,
#     "min_samples_split": 2,
#     "min_samples_leaf": 1,
#     "max_features": None,
# }
# model = train_models(model_type, X_train, y_train, {})

# %%
# The entire pipeline

# 1. Load the data

# Must be a key of both y_train_parameters and y_test_parameters, because it
# is used to index them below.
dataset_name = "weather"

X_train, y_train, X_test, y_test = get_dataset(dataset_name)

# 2. Train the models f and f_star

# All three models are requested with the same model type string.
model_types = dict.fromkeys(
    ("f", "f_star_train", "f_star_test"),
    "hist_gradient_boosting",
)

# f gets an empty parameter dict; each f_star model gets the "HGB" settings
# registered for the selected dataset (the dict objects themselves, not copies).
model_params = {
    "f": {},
    "f_star_train": y_train_parameters[dataset_name]["HGB"],
    "f_star_test": y_test_parameters[dataset_name]["HGB"],
}

f, f_star_train, f_star_test, X_test_eval, y_test_eval = train_all_models(
    model_types,
    X_train,
    y_train,
    X_test,
    y_test,
    model_params,
)

#%%
# 4. Evaluate the model f

# Proportion handed to `sampling`. The drawn sample is then split in half:
# one half fits the tree, the other half is assigned to the tree's leaves and
# used for the GL estimate.
n_samples = 0.5

## Reference ("true") GL: labels are simulated from f_star_test and given to
## gl_true together with f and f_star_test.
y_test_eval_sim = generate_y(f_star_test, X_test_eval, method="predict_proba")
true_gl = gl_true(
    f,
    f_star_test,
    X_test_eval,
    y_test_eval_sim,
    n_bins=1000,
    strategy="quantile",
)

# Estimated GL on the observed labels
X_sampled_eval, y_sampled_eval = sampling(
    X_test_eval, y_test_eval, proportion=n_samples
)
X_eval_train, X_eval_test, y_eval_train, y_eval_test = train_test_split(
    X_sampled_eval,
    y_sampled_eval,
    test_size=0.5,
    random_state=42,
)

# The leaf index of each held-out sample serves as its partition label.
# NOTE(review): the tree is created without random_state, so its splits may
# differ between runs -- confirm whether a fixed seed is wanted here.
solver = DecisionTreeRegressor(max_depth=None)
partition = solver.fit(X_eval_train, y_eval_train).apply(X_eval_test)

# The second constructor argument is None; the partition is handed to fit()
# explicitly instead (presumably that argument is the partitioner -- verify
# against the glest API).
glest = GLEstimator(f, None)
glest.fit(X_eval_test, y_eval_test, partition=partition)
results = glest.metrics()


#%%

# Evaluate on the simulated data

# Same sampling / split / tree-partition / estimation steps as for the observed
# labels, but using the simulated labels y_test_eval_sim.
X_sampled_eval_sim, y_sampled_eval_sim = sampling(
    X_test_eval, y_test_eval_sim, proportion=n_samples
)
X_eval_train_sim, X_eval_test_sim, y_eval_train_sim, y_eval_test_sim = train_test_split(
    X_sampled_eval_sim,
    y_sampled_eval_sim,
    test_size=0.5,
    random_state=42,
)

# Re-fits the existing `solver` object, so from here on it no longer holds
# the tree that produced `partition`.
partition_sim = solver.fit(X_eval_train_sim, y_eval_train_sim).apply(
    X_eval_test_sim
)

glest_sim = GLEstimator(f, None)
glest_sim.fit(X_eval_test_sim, y_eval_test_sim, partition=partition_sim)
results_sim = glest_sim.metrics()

#%%

# Report the GL metrics estimated from the observed labels, the metrics
# estimated from the simulated labels, and the reference value from gl_true.
print("Results on the true data", results)
print("Results on the simulated data", results_sim)
print("True GL", true_gl)

# %%
