#%%
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from folktexts.acs import ACSDataset
from sklearn.tree import DecisionTreeRegressor
import glest
from folktexts.acs import ACSTaskMetadata
from deferral_experiment.regret_helpers import compute_regret_CL, get_constant_utilty, get_threshold_from_utility
from utils import honest_tree_pred
from sklearn.metrics import roc_auc_score, accuracy_score, brier_score_loss
import seaborn as sns
from itertools import combinations
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsRegressor
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.neural_network import MLPRegressor, MLPClassifier
#%%


# Load each LLM's per-row output table plus the shared sentence embeddings.
# Files are read in the same order as listed; tuple-unpacking binds the
# module-level names that the cost tables and the experiment loop below use.
llama1, llama3, llama8, llama70, phi4, gemma27, mixtral8x7b = (
    pd.read_csv(csv_path)
    for csv_path in (
        'deferral_experiment/llama1_instruct.csv',
        'deferral_experiment/llama3_instruct.csv',
        'deferral_experiment/llama8_instruct.csv',
        'deferral_experiment/llama70_instruct.csv',
        'deferral_experiment/phi4_instruct.csv',
        'deferral_experiment/gemma27_instruct.csv',
        'deferral_experiment/mixtral8x7b_instruct.csv',
    )
)
embeddings = np.load('deferral_experiment/sentence_embeddings_MiniLM_L12_v2.npy')
#%%

# Unit price per model. The unit is not stated in this file
# (presumably USD per million tokens — TODO confirm).
costs_per_model = {
    'Llama 1B': 0.04,
    'Llama 3B': 0.06,
    'Llama 8B': 0.18,
    'Llama 70B': 0.88,
    'Gemma 27B': 0.25,
    'Mixtral8x7B': 0.60,
    'Phi 4': 0.22,
}

# Number of rows each model was evaluated on; keys match costs_per_model and
# the `models` dict used by the experiment loop.
_n_rows_per_model = {
    'Llama 1B': len(llama1),
    'Llama 3B': len(llama3),
    'Llama 8B': len(llama8),
    'Llama 70B': len(llama70),
    'Gemma 27B': len(gemma27),
    'Mixtral8x7B': len(mixtral8x7b),
    'Phi 4': len(phi4),
}

# Total cost per model, derived from costs_per_model so the two tables cannot
# disagree. The hand-written version priced Mixtral at 0.70 (vs 0.60 above)
# and keyed Phi 4 as 'phi4'. NOTE(review): confirm 0.60 is the intended
# Mixtral price.
# The arithmetic keeps the original precedence: (price * n_rows) // 10,
# i.e. floor division applied to the float product.
costs_of_all_models = {
    name: costs_per_model[name] * n_rows // 10
    for name, n_rows in _n_rows_per_model.items()
}

#%%

# Process all models
import os

models = {
    'Llama 1B': llama1,
    'Llama 3B': llama3,
    'Llama 8B': llama8,
    'Llama 70B': llama70,
    'Gemma 27B': gemma27,
    'Mixtral8x7B': mixtral8x7b,
    'Phi 4': phi4,
}
results = {}  # NOTE(review): not populated in this file; kept so later cells that reference it still work.

# Store results for multiple seeds: all_seeds_results[seed][model_name] -> dict of artefacts.
all_seeds_results = {}
seeds = range(5)  # You can modify this list of seeds

# Single decision threshold applied to the scores. The previous version computed
# per-utility thresholds via get_threshold_from_utility on every iteration but then
# overwrote them with 0.5, so only 0.5 was ever used; it is now set once here.
# With a scalar threshold, the action matrices below have shape (n, 1).
t = 0.5

# np.save does not create missing directories.
os.makedirs('deferral_experiment/baselines', exist_ok=True)

for seed in seeds:
    print(f"\n=== Processing seed {seed} ===")
    seed_results = {}

    for model_name, model_df in models.items():
        print(f"Processing {model_name} with seed {seed}...")

        # Features are every column except the model's risk score and the label.
        X = model_df.drop(columns=['risk_score', 'label']).values
        y = model_df['label'].values
        S = model_df['risk_score'].values

        calibrated_classifier = LogisticRegression()

        # Hold out 10% of rows ("leftover") for final evaluation. The split is
        # row-aligned across X, y, S and the embeddings, which assumes embedding
        # rows correspond to model_df rows — TODO confirm.
        X, X_leftover, y, y_leftover, S, S_leftover, embeddings_kept, embeddings_leftover = train_test_split(
            X, y, S, embeddings, test_size=0.1, random_state=seed
        )

        # Baselines: predict from the sentence embeddings whether the model's
        # thresholded decision (S >= 0.5) was correct. Models are always retrained
        # (no cache lookup); their predictions on the leftover split are saved.
        predictions_file_mlp = f'deferral_experiment/baselines/{model_name}_MLP_correct_predictions_seed_{seed}.npy'
        predictions_file_knn = f'deferral_experiment/baselines/{model_name}_KNN_correct_predictions_seed_{seed}.npy'

        print(f"Training new model for {model_name}, seed {seed}")
        correct = (y == (S >= 0.5).astype(int)).astype(int)

        # Seeded so the per-seed MLP predictions saved below are reproducible.
        mlp = MLPRegressor(hidden_layer_sizes=(100, 100), max_iter=200, random_state=seed, verbose=True, batch_size=512)
        mlp.fit(embeddings_kept, correct)
        correct_pred_leftover_mlp = mlp.predict(embeddings_leftover)

        nn = KNeighborsRegressor(n_neighbors=40, metric="cosine", n_jobs=40)
        nn.fit(embeddings_kept, correct)
        correct_pred_leftover_knn = nn.predict(embeddings_leftover)
        print(f"Finished training new model for {model_name}, seed {seed}")

        # Save correct predictions for this seed
        np.save(predictions_file_mlp, correct_pred_leftover_mlp)
        np.save(predictions_file_knn, correct_pred_leftover_knn)

        # Split the remaining rows 50/50 into train/test, then carve a calibration
        # set out of train: 20% of train but at least 4000 rows. train_test_split
        # raises if that exceeds the available training rows.
        X_train, X_test, y_train, y_test, S_train, S_test = train_test_split(
            X, y, S, test_size=0.5, random_state=seed
        )
        n_cal = max(int(len(X_train) * 0.2), 4000)
        X_train, X_cal, y_train, y_cal, S_train, S_cal = train_test_split(
            X_train, y_train, S_train, test_size=n_cal, random_state=seed
        )

        # Recalibrate the raw score with a 1-D logistic regression fit on the calibration split.
        calibrated_classifier.fit(S_cal.reshape(-1, 1), y_cal)

        c_hat_train = calibrated_classifier.predict_proba(S_train.reshape(-1, 1))[:, 1]
        c_hat_test = calibrated_classifier.predict_proba(S_test.reshape(-1, 1))[:, 1]

        # Partition feature space with a tree fit on train residuals (y - c_hat).
        # The leaves are then used as the partition for glest on the test split.
        residuals_train = y_train - c_hat_train
        dt = DecisionTreeRegressor(max_depth=None, min_samples_leaf=15)
        dt.fit(X_train, residuals_train)
        leaf_ids = dt.apply(X_test)

        gle = glest.core.GLEstimatorResiduals(None, None)
        gle.fit(X_test, y_test, y_scores_cal=c_hat_test, partition=leaf_ids)

        # Evaluate on the held-out leftover rows.
        c_hat_leftover = calibrated_classifier.predict_proba(S_leftover.reshape(-1, 1))[:, 1]
        r_hat_leftover = honest_tree_pred(dt, gle.honest_rj, X_leftover)

        # RCL: actions taken on the raw score S, scored against c_hat.
        a = (S_leftover[:, None] >= t).astype(int)  # (n, 1)
        RCL = compute_regret_CL(c_hat_leftover, t, a)

        # RGL: actions taken on c_hat, scored against c_hat + r_hat.
        a = (c_hat_leftover[:, None] >= t).astype(int)  # (n, 1)
        RGL = compute_regret_CL(c_hat_leftover + r_hat_leftover, t, a)

        # Store results for this seed
        seed_results[model_name] = {
            'X_test': X_leftover,
            'y_test': y_leftover,
            'S_test': S_leftover,
            'embeddings_test': embeddings_leftover,
            'c_hat_test': c_hat_leftover,
            'r_hat': r_hat_leftover,
            'RCL': RCL,
            'RGL': RGL,
            'gle': gle,
            'tree': dt,
            'knn': nn
        }

    # Store results for this seed
    all_seeds_results[seed] = seed_results
