#%%
# Imports for the post-training / grouping-loss regret experiment below.
from utils.contexts import Timer
from decision.xp.post_training import (
    GLAR,
    FineTuning,
    GLARThresholded,
    HistogramBinningRecalibration,  # was "HistogramBinnin Recalibration" (SyntaxError)
    Identity,
    PartitionerDict,
    PlattBinnerRecalibration,
    PlattRecalibration,
    PostTraining,
    PostTrainingDict,
    SigmoidFineTuning,
    SklearnRecalibration,
    Stacking,
)
from decision.xp.test_xp import compute_metrics_residuals_normal, compute_metrics_residuals, metrics_to_df
import numpy as np
import pandas as pd
from decision.xp.data.base import ds_registry, ds_rename
from decision.xp.model.base import model_registry, model_rename
from decision.xp.common import (
    # fit_predict_clf,
    get_constant_utilty,
    # get_optimal_thresholds_norecal,
    # get_optimal_thresholds_recal,
    get_threshold_from_utility,
    # recalibrate_scores,
    u_emp_from_score,
)

import glest

from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from decision.xp.regrets import compute_regret_CL, compute_regret_CL_normal, compute_regret_CL_beta, compute_regret_CL_truncated_normal
import matplotlib.pyplot as plt
#%%
# --- Experiment configuration ------------------------------------------------
ds_name = "hate_merged_en"
model_name = "cnerg1"
post_training_name = "recal_isotonic"

ds = ds_registry[ds_name]()
model = model_registry[model_name]()
# Renamed (display) versions of the names, used in the plot title later on.
ds_name2 = ds_rename[ds_name]
model_name2 = model_rename[model_name]
n_utilities = 100

post_training = PostTrainingDict()[post_training_name]
partitioners_dict = PartitionerDict()
partitioner_names = [
    "depth10",
    # "unconstrained",
]
partitioners = {n: partitioners_dict[n] for n in partitioner_names}

kwargs = dict(ds=ds_name, m=model_name, t=post_training_name)

rs = 0  # forwarded to ds.get_arrays (presumably the split seed — confirm)
# get_arrays receives this flag; when True the returned scores S are
# already post-trained (see the post-training section below).
finetuned = isinstance(post_training, FineTuning)

print(f"ds_name: {ds.ds_name}, post training {post_training}, rules {partitioners}")

# --- Data --------------------------------------------------------------------
(X, y, S, G), (idx_val1, idx_val2, idx_test) = ds.get_arrays(model, finetuned, rs)

# Both validation splits are merged; they are used to fit the post-training.
idx_val = np.concatenate([idx_val1, idx_val2])

X_val, S_val, y_val = X[idx_val], S[idx_val], y[idx_val]
X_test, S_test, y_test = X[idx_test], S[idx_test], y[idx_test]

# --- Utilities and their optimal thresholds ----------------------------------
# We add some extra values of t we would like in the plots
t_target = [0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975, 0.99]
U = get_constant_utilty(n_utilities, t_target)  # (n_utilities, 2, 2)
t = get_threshold_from_utility(U)  # (n_utilities,)

# --- Post-training -----------------------------------------------------------
if finetuned:
    # S is already post trained when finetuned is True, so the test scores are
    # used as-is. Without this branch Sp_test stayed undefined and the utility
    # computation below raised a NameError.
    # NOTE: the timer_post_training_* objects only exist in the else-branch.
    Sp_test = S_test
else:
    with Timer("post_training_fit") as timer_post_training_fit:
        post_training.fit(S=S_val, y=y_val, X=X_val)
    with Timer("post_training_predict") as timer_post_training_predict:
        Sp_test = post_training.predict_proba(S_test, X_test)

# Compute the utility associated with the post trained scores
u_test_emp = u_emp_from_score(Sp_test, y_test, t, U, return_action=False)
u_test_emp = u_test_emp.mean(axis=0)  # (n_utilities,)

df = pd.DataFrame({"u_test_emp_mean": u_test_emp}, index=t)
df.index.name = "t"

# Compute the residual metrics after post training.
# Regret curves are compared with and without the normal (variance) correction,
# one figure per partitioner.
dfs_t = {}
dfs_one = {}

# Inputs of the residual analysis: the post-trained *test* scores only.
# Dedicated names are used so that neither the full-dataset arrays (X, y, S)
# nor the outer test arrays (X_test, y_test, S_test) are overwritten. The
# previous version rebound them inside the loop, so a second partitioner
# would pair a halved y_test with a full-length score vector and crash.
# NOTE(review): column 0 is taken when the post-trained scores are 2-D —
# confirm this is the intended class (sklearn's predict_proba puts the
# positive class in column 1).
S_eval = Sp_test[:, 0] if np.ndim(Sp_test) == 2 else Sp_test
y_eval = y_test
X_eval = X_test

# The splits and the recalibration below do not depend on the partitioner,
# so they are computed once rather than on every loop iteration.
# First half: learns the residual partition; second half: held out for the
# grouping-loss estimator and the regret curves.
X_train, X_res_test, y_train, y_res_test, S_train, S_res_test = train_test_split(
    X_eval, y_eval, S_eval, test_size=0.5, random_state=0
)

# Carve a calibration subset (at least 4000 samples) out of the first half.
# NOTE(review): train_test_split raises when this int test_size is >= the
# number of samples, i.e. when the first half has 4000 samples or fewer.
n_cal = max(int(len(X_train) * 0.2), 4000)
X_train, X_cal, y_train, y_cal, S_train, S_cal = train_test_split(
    X_train, y_train, S_train, test_size=n_cal, random_state=0
)

# Platt-style recalibration: a logistic regression of the labels on the score.
calibrated_classifier = LogisticRegression()
calibrated_classifier.fit(S_cal.reshape(-1, 1), y_cal)
c_hat_train = calibrated_classifier.predict_proba(S_train.reshape(-1, 1))[:, 1]
c_hat_test = calibrated_classifier.predict_proba(S_res_test.reshape(-1, 1))[:, 1]

# Residuals of the labels w.r.t. the recalibrated scores; the tree learns them.
residuals_train = y_train - c_hat_train

# a[i, k] is True when the recalibrated score of sample i reaches threshold t[k].
a = c_hat_test[:, None] >= t[None, :]

for rule_name, rule in partitioners.items():
    # NOTE(review): `rule` is not used — the partition is always a depth-10
    # regression tree on the residuals (consistent with the "depth10" name).
    # Confirm whether the configured partitioner should drive this instead.
    dt = DecisionTreeRegressor(max_depth=10, min_samples_leaf=10)
    dt.fit(X_train, residuals_train)
    leaf_ids = dt.apply(X_res_test)  # leaf index per held-out sample

    gle = glest.core.GLEstimatorResiduals(None, None)
    gle.fit(X_res_test, y_res_test, y_scores_cal=c_hat_test, partition=leaf_ids)

    r_hat = gle.honest_tree_pred
    var_x = gle.var_x

    # Plain version: regret evaluated at the corrected score c_hat_test + r_hat.
    RGL = compute_regret_CL(c_hat_test + r_hat, t, a)  # (n, k)
    RGL = RGL.mean(axis=0)  # (k,)
    # Corrected version: additionally uses the estimator's var_x
    # (truncated-normal model).
    RGL_normal = compute_regret_CL_truncated_normal(c_hat_test, t, a, var_x, r_hat)  # (n, k)
    RGL_normal = RGL_normal.mean(axis=0)  # (k,)
    print("Without normal correction ", RGL)

    print("With normal correction ", RGL_normal)
    print(t)
    plt.plot(t, RGL, label="RGL with determinate distribution")
    plt.plot(t, RGL_normal, label="RGL with normal distribution")
    plt.xlabel("t")
    plt.ylabel("RGL")
    plt.legend()
    plt.title(f"RGL for {ds_name2} - {model_name2} - {post_training_name}")
    plt.show()
    # TODO: collect metrics per rule once metrics_t / metrics_one are computed:
    # metrics_one.update(timer_post_training_fit.to_dict())
    # metrics_one.update(timer_post_training_predict.to_dict())
    # df_t, df_one = metrics_to_df(metrics_t, metrics_one, t)
    # dfs_t[rule_name] = df_t
    # dfs_one[rule_name] = df_one

# %%
# metrics_t["RGL_residuals"]
# %%
