from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeRegressor
from sklearn.calibration import calibration_curve, CalibratedClassifierCV
from sklearn.frozen import FrozenEstimator
from glest.helpers import bins_from_strategy
from glest.core import GLEstimator, GLEstimatorResiduals, Partitioner
from data_generation import generate_y
import os
from scipy.interpolate import griddata
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# Common prefix for every results directory written by this module; the
# writers below append a suffix such as "_real/" or "_<sim_method>/".
RESULTPATH = (
    "tree_no_depth_limit_minsample15_calbinning20percentall"
    "/sanity-check2d/y_residuals"
)



def run(X, y, h, f_star, seed, model, n_samples, depth, train_residuals=True, sim_y=False, sim_method="predict_proba"):
    """
    Split the data in half, subsample each half, and run one experiment.

    Both halves are reduced to ``n_samples`` rows drawn without replacement.
    Each draw uses a fresh ``np.random.default_rng(seed)`` generator, so both
    draws start from the same random state.

    Args:
        X, y, h: Features, labels and model predictions, indexed row-wise
            (they are indexed with integer arrays, so array-likes are expected).
        f_star: Model forwarded to ``run_single_iteration``.
        seed (int): Seed for the split and for both subsampling draws.
        model: Classifier under evaluation.
        n_samples (int): Rows kept from each half.
        depth (int): Maximum tree depth forwarded downstream.
        train_residuals (bool): Forwarded to ``run_single_iteration``.
        sim_y (bool): Forwarded to ``run_single_iteration``.
        sim_method (str): Forwarded to ``run_single_iteration``.

    Returns:
        Whatever ``run_single_iteration`` returns (0 in this module).
    """
    X_tr, X_te, y_tr, y_te, h_tr, h_te = train_test_split(
        X, y, h, test_size=0.5, random_state=seed
    )

    def _draw(pool_size):
        # New generator on every call: train and test draws share the seed.
        rng = np.random.default_rng(seed=seed)
        return rng.choice(pool_size, size=n_samples, replace=False)

    keep_tr = _draw(len(X_tr))
    X_tr, y_tr, h_tr = X_tr[keep_tr], y_tr[keep_tr], h_tr[keep_tr]

    keep_te = _draw(len(X_te))
    X_te, y_te, h_te = X_te[keep_te], y_te[keep_te], h_te[keep_te]

    return run_single_iteration(
        depth, seed, f_star, model,
        X_tr, X_te, y_tr, y_te, h_tr,
        train_residuals, sim_method, sim_y=sim_y,
    )


def run_single_iteration(depth, seed, f_star, model, X_train, X_test, y_train, y_test, h_train, train_residuals=True, sim_method="predict_proba", sim_y=False):
    """
    Run a single iteration of the experiment; metrics are written to CSV.

    When ``sim_y`` is true, the observed labels are replaced by labels drawn
    from ``f_star`` via ``generate_y(..., method=sim_method)`` before the
    estimator runs.

    Args:
        depth (int): The maximum depth of the decision tree.
        seed (int): The random seed for reproducibility.
        f_star (Any): Model used to simulate labels when ``sim_y`` is true;
            its class name also labels the output directory of
            ``get_gl_results``.
        model (Any): The fitted classifier being evaluated.
        X_train (np.ndarray): The training data.
        X_test (np.ndarray): The test data.
        y_train (np.ndarray): The training labels (ignored when ``sim_y``).
        y_test (np.ndarray): The test labels (ignored when ``sim_y``).
        h_train (np.ndarray): The training predictions.
        train_residuals (bool): If True use ``get_results_residuals``,
            otherwise ``get_gl_results``.
        sim_method (str): Method passed to ``generate_y``; also part of the
            output directory name.
        sim_y (bool): Whether to replace labels with simulated ones.

    Returns:
        int: Always 0; results are produced only as a side effect.
    """
    f_star_name = f_star.__class__.__name__
    if sim_y:
        y_train = generate_y(f_star, X_train, method=sim_method)
        y_test = generate_y(f_star, X_test, method=sim_method)

    if train_residuals:
        get_results_residuals(depth, seed, model, X_train, X_test, y_train, y_test, h_train, sim_y=sim_y, sim_method=sim_method)
    else:
        # f_star_name only affects the output path when sim_y is true.
        get_gl_results(depth, seed, model, X_train, X_test, y_train, y_test, h_train, sim_y=sim_y, sim_method=sim_method, f_star_name=f_star_name)
    return 0


def get_gl_results(depth, seed, model, X_train, X_test, y_train, y_test, h_train, sim_y, sim_method, f_star_name=None):
    """
    Estimate the grouping loss with a tree partitioner and append it to a CSV.

    Train and test rows are concatenated and passed to ``GLEstimator`` with
    ``train_size=len(X_train)``. Presumably ``GLEstimator`` fits the
    partitioner on the first ``train_size`` rows and evaluates on the rest;
    confirm in ``glest.core``.

    Args:
        depth (int): ``max_depth`` of the partitioning DecisionTreeRegressor.
        seed (int): Recorded in the results and the output file name.
        model: Classifier whose grouping loss is estimated.
        X_train, X_test: Feature rows for the two parts.
        y_train, y_test: Labels for the two parts.
        h_train: Unused here; kept for signature compatibility.
        sim_y (bool): Whether labels were simulated; selects the output dir.
        sim_method (str): Part of the output directory name when ``sim_y``.
        f_star_name (str | None): Part of the output directory name when
            ``sim_y``.

    Output:
        Appends one row (index column, no header) to
        ``<RESULTPATH>_<suffix>/<depth>_<seed>_<n_samples>_<simulated>.csv``.
        The file name assumes ``glest.metrics()`` adds no ``config*`` keys of
        its own.
    """
    partitioner = Partitioner(
        estimator=DecisionTreeRegressor(max_depth=depth, min_samples_leaf=15),
        predict_method='apply',
    )
    train_size = len(X_train)
    X = np.concatenate([X_train, X_test])
    y = np.concatenate([y_train, y_test])
    glest = GLEstimator(model, partitioner=partitioner, train_size=train_size)
    glest.fit(X, y)

    results = glest.metrics()
    results["config.depth"] = depth
    results["config.seed"] = seed
    results["config.n_samples"] = train_size
    results["config.simulated"] = sim_y

    filename = '_'.join(str(v) for k, v in results.items() if k.startswith("config"))
    if sim_y:
        path = f"{RESULTPATH}_{sim_method}_{f_star_name}/"
    else:
        path = f"{RESULTPATH}_real/"

    # exist_ok avoids the exists()/makedirs() race between concurrent workers.
    os.makedirs(path, exist_ok=True)
    pd.DataFrame(results, index=[0]).to_csv(f"{path}{filename}.csv", mode='a', header=False)


def get_results_residuals(depth, seed, model, X_train, X_test, y_train, y_test, h_train, sim_y, sim_method, plot=False):
    """
    Estimate the grouping loss from calibrated residuals and append it to a CSV.

    Procedure:
      1. Hold out ``max(int(0.2 * len(X_train)), 4000)`` training rows as a
         calibration split.
      2. Recalibrate the frozen ``model`` on that split with
         ``CalibratedClassifierCV(method="sigmoid")``.
      3. Fit a regression tree (``max_depth=depth``, ``min_samples_leaf=15``)
         to the residuals ``y - c_hat`` on the remaining training rows.
      4. Use the tree's leaf ids on ``X_test`` as the partition for
         ``GLEstimatorResiduals``, fitted on the test rows with the calibrated
         test scores.

    Args:
        depth (int): Maximum depth of the residual tree.
        seed (int): Seed for the calibration split; recorded in the results.
        model: Fitted classifier being evaluated.
        X_train, y_train: Rows split into tree-fitting and calibration parts.
        X_test, y_test: Rows on which the grouping loss is estimated.
        h_train: Unused here; kept for signature compatibility.
        sim_y (bool): Whether labels were simulated; selects the output dir.
        sim_method (str): Part of the output directory name when ``sim_y``.
        plot (bool): Unused here; kept for signature compatibility.

    Raises:
        ValueError: If ``X_train`` has no more rows than the calibration split
            (i.e. 4000 rows or fewer), leaving nothing to fit the tree on.
    """
    n_total = len(X_train)
    cal_size = max(int(n_total * 0.2), 4000)
    if cal_size >= n_total:
        # train_test_split rejects this too, but with a less specific message.
        raise ValueError(
            f"Need more than {cal_size} training rows to hold out a "
            f"calibration split of {cal_size}; got {n_total}."
        )
    X_fit, X_cal, y_fit, y_cal = train_test_split(
        X_train, y_train, test_size=cal_size, random_state=seed
    )

    # FrozenEstimator keeps `model` as-is; only the sigmoid mapping is fitted.
    calibrated_classifier = CalibratedClassifierCV(estimator=FrozenEstimator(model), method="sigmoid")
    calibrated_classifier.fit(X_cal, y_cal)
    c_hat_fit = calibrated_classifier.predict_proba(X_fit)[:, 1]
    c_hat_test = calibrated_classifier.predict_proba(X_test)[:, 1]

    dt = DecisionTreeRegressor(max_depth=depth, min_samples_leaf=15)
    dt.fit(X_fit, y_fit - c_hat_fit)
    leaf_ids = dt.apply(X_test)

    glest = GLEstimatorResiduals(model, None)
    glest.fit(X_test, y_test, y_scores_cal=c_hat_test, partition=leaf_ids)

    results = glest.metrics()
    results["config.depth"] = depth
    results["config.seed"] = seed
    results["config.n_samples"] = n_total  # fit + calibration rows
    results["config.simulated"] = sim_y
    filename = '_'.join(str(v) for k, v in results.items() if k.startswith("config"))

    # NOTE(review): unlike get_gl_results, the simulated path omits the f_star
    # name, so runs with different f_star append to the same CSV; confirm
    # this is intended.
    if sim_y:
        path = f"{RESULTPATH}_{sim_method}/"
    else:
        path = f"{RESULTPATH}_real/"

    # exist_ok avoids the exists()/makedirs() race between concurrent workers.
    os.makedirs(path, exist_ok=True)
    pd.DataFrame(results, index=[0]).to_csv(f"{path}{filename}.csv", mode='a', header=False)

    #