#%%
 
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from folktexts.acs import ACSDataset
from sklearn.tree import DecisionTreeRegressor
import glest
from folktexts.acs import ACSTaskMetadata
from deferral_experiment.regret_helpers import compute_regret_CL, get_constant_utilty, get_threshold_from_utility
from utils import honest_tree_pred
from sklearn.metrics import roc_auc_score, accuracy_score, brier_score_loss
import seaborn as sns
from itertools import combinations
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsRegressor
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.neural_network import MLPRegressor, MLPClassifier
#%%

# Load each LLM's per-sample outputs from deferral_experiment/<tag>_instruct.csv,
# in the same order as before, and bind them to the module-level names used below.
_model_tags = ('llama1', 'llama3', 'llama8', 'llama70', 'phi4', 'gemma27', 'mixtral8x7b')
_frames = {}
for _tag in _model_tags:
    _frames[_tag] = pd.read_csv(f'deferral_experiment/{_tag}_instruct.csv')
llama1, llama3, llama8, llama70, phi4, gemma27, mixtral8x7b = (_frames[_tag] for _tag in _model_tags)
del _frames, _model_tags, _tag

# Precomputed sentence embeddings (MiniLM-L12-v2 per the file name).
embeddings = np.load('deferral_experiment/sentence_embeddings_MiniLM_L12_v2.npy')
#%%

# Cost charged per query of each model; keys match the display names used in
# `models` below and are plotted on the "Average Cost per Sample ($)" axis.
costs_per_model = dict([
    ('Llama 1B', 0.04),
    ('Llama 3B', 0.06),
    ('Llama 8B', 0.18),
    ('Llama 70B', 0.88),
    ('Gemma 27B', 0.25),
    ('Mixtral8x7B', 0.60),
    ('Phi 4', 0.22),
])

# Cost of running each model over its whole CSV, evaluated left to right as
# (price * n_rows) // 10 (floor division on a float).
# NOTE(review): Mixtral is priced 0.70 here but 0.60 in costs_per_model, and the
# Phi key is 'phi4' rather than 'Phi 4' — confirm which is intended.
costs_of_all_models = dict([
    ('Llama 1B', (0.04 * len(llama1)) // 10),
    ('Llama 3B', (0.06 * len(llama3)) // 10),
    ('Llama 8B', (0.18 * len(llama8)) // 10),
    ('Llama 70B', (0.88 * len(llama70)) // 10),
    ('Gemma 27B', (0.25 * len(gemma27)) // 10),
    ('Mixtral8x7B', (0.70 * len(mixtral8x7b)) // 10),
    ('phi4', (0.22 * len(phi4)) // 10),
])

#%%

# Process all models
# `models` maps each display name (same keys as `costs_per_model`) to that
# model's per-sample dataframe (feature columns + 'risk_score' + 'label').
# NOTE(review): the cascade evaluators index every model's results by the same
# position j, which assumes all CSVs and `embeddings` are row-aligned and of
# equal length, so that the identically seeded splits below select the same
# rows for every model — confirm.
models = {'Llama 1B' : llama1, 'Llama 3B': llama3, 'Llama 8B': llama8, 'Llama 70B': llama70, 'Gemma 27B': gemma27, 'Mixtral8x7B': mixtral8x7b, 'Phi 4': phi4}
results = {}
# NOTE(review): overwritten by the `for seed in seeds` loop below.
seed = 0

# Store results for multiple seeds
# all_seeds_results[seed][model_name] -> dict of evaluation-set arrays and fitted models.
all_seeds_results = {}
seeds = range(5)  # You can modify this list of seeds

for seed in seeds:
    print(f"\n=== Processing seed {seed} ===")
    seed_results = {}
    
    for model_name, model_df in models.items():
        print(f"Processing {model_name} with seed {seed}...")
        
        # Keep only half of the samples at random
        # half_size = len(model_df)
        # model_df_half = model_df.sample(n=half_size, random_state=seed)
        # embeddings_half = embeddings[model_df_half.index]
        # Features: every column except the model's score and the true label.
        X = model_df.drop(columns=['risk_score', 'label']).values
        y = model_df['label'].values
        S = model_df['risk_score'].values
        
        # NOTE(review): `t` computed here is overwritten with 0.5 before it is
        # used, and `U` / `t_target` are not used later in this loop — these look
        # like leftovers.
        t_target = [0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975, 0.99]
        U = get_constant_utilty(100, t_target)  # (n_utilities, 2, 2)
        t = get_threshold_from_utility(U) 

        # Maps the raw score S to a recalibrated probability; fit on the
        # calibration split below.
        calibrated_classifier = LogisticRegression()

        # Hold out 10% ("leftover") as the evaluation set used by the cascades;
        # the other 90% ("kept") is used for all fitting.
        X_kept, X_leftover, y_kept, y_leftover, S_kept, S_leftover, embeddings_kept, embeddings_leftover = train_test_split(
            X, y, S, embeddings, test_size=0.1, random_state=seed
        )


        # 1 where thresholding the raw score at 0.5 yields the true label.
        correct = (y_kept == (S_kept >= 0.5).astype(int)).astype(int)
        # Regressor of per-sample correctness on X. It is stored under the 'knn'
        # key and queried by cascade_eval_knn, despite being a decision tree.
        dtcorrect = DecisionTreeRegressor(max_depth = None, min_samples_leaf= 15)

        # dtcorrect = KNeighborsRegressor(n_neighbors=40, metric='cosine', n_jobs=40)
        dtcorrect.fit(X_kept, correct)

        # Precomputed baseline correctness predictions for this model and seed.
        mlp_pred = np.load(f'deferral_experiment/baselines/{model_name}_MLP_correct_predictions_seed_{seed}.npy')
        knn_pred = np.load(f'deferral_experiment/baselines/{model_name}_KNN_correct_predictions_seed_{seed}.npy')

        # Split "kept" in half: a train half (tree + calibration) and a test half
        # (used for the residual estimates below).
        X_train, X_test, y_train, y_test, S_train, S_test = train_test_split(
            X_kept, y_kept, S_kept, test_size=0.5, random_state=seed
        )

        # Carve the calibration set out of the train half: an absolute count of
        # 20% of it, but at least 4000 samples.
        X_train, X_cal, y_train, y_cal, S_train, S_cal = train_test_split(
            X_train, y_train, S_train, test_size=max(int(len(X_train) * 0.2),4000), random_state=seed
        )

        calibrated_classifier.fit(S_cal.reshape(-1,1), y_cal)

        # Recalibrated P(y=1) from the raw score.
        c_hat_train = calibrated_classifier.predict_proba(S_train.reshape(-1,1))[:, 1]
        c_hat_test = calibrated_classifier.predict_proba(S_test.reshape(-1,1))[:, 1]

        # Residuals of the recalibrated score; the tree partitions X by them.
        residuals_train = y_train - c_hat_train
        residuals_test = y_test - c_hat_test  # NOTE(review): not used below.
        dt = DecisionTreeRegressor(max_depth = None, min_samples_leaf= 15)
        dt.fit(X_train, residuals_train)
        leaf_ids = dt.apply(X_test)

        # Fit the residual estimator on the test half, with the tree's leaves
        # (fit on the train half) as the partition.
        gle = glest.core.GLEstimatorResiduals(None, None)
        gle.fit(X_test, y_test, y_scores_cal = c_hat_test, partition = leaf_ids)

        c_hat_leftover = calibrated_classifier.predict_proba(S_leftover.reshape(-1,1))[:, 1]
        # Per-sample residual estimate on the evaluation set, looked up through
        # the tree's leaves (presumably gle.honest_rj holds one value per leaf —
        # confirm in utils.honest_tree_pred).
        r_hat_leftover = honest_tree_pred(dt, gle.honest_rj, X_leftover)

        # Single decision threshold used for both regret terms.
        t = 0.5
        # R_CL: action taken from the raw score S, regret evaluated with c_hat
        # (assumes compute_regret_CL(p, t, a) = regret of action a when P(y=1)=p — confirm).
        a = (S_leftover[:, None] >= t).astype(int)
        RCL = compute_regret_CL(c_hat_leftover, t, a)  # (n, k)

        # R_GL: action taken from c_hat, regret evaluated with c_hat + r_hat.
        a = (c_hat_leftover[:, None] >= t).astype(int)  # (n, k)
        RGL = compute_regret_CL(c_hat_leftover + r_hat_leftover, t, a)  # (n, k)
        
        # Store results for this seed
        # NOTE(review): 'embeddings_test' holds X_leftover, not embeddings_leftover.
        # That matches dtcorrect (stored as 'knn'), which was fit on X_kept and is
        # queried with this array in cascade_eval_knn; the key name is misleading.
        seed_results[model_name] = {
            'X_test': X_leftover,
            'y_test': y_leftover, 
            'S_test': S_leftover,
            'embeddings_test': X_leftover,
            'c_hat_test': c_hat_leftover,
            'r_hat': r_hat_leftover,
            'RCL': RCL,
            'RGL': RGL,
            'gle': gle,
            'tree': dt,
            'knn': dtcorrect,
            'mlp_predictions': mlp_pred,
            'knn_predictions': knn_pred
        }
    
    # Store results for this seed
    all_seeds_results[seed] = seed_results

# Set results to the first seed for backward compatibility
results = all_seeds_results[seeds[0]]

#%%

# Notebook cell: display the nested {seed: {model_name: results}} dict.
all_seeds_results
# %%
def cascade_eval_rcl(cascade, costs_model, results, threshold=0.05):
    """Evaluate a cascade that stops at the first model with a small estimated R_CL.

    For every evaluation sample ``j`` the models in ``cascade`` are queried in
    order, and each query adds that model's price to the running cost. The
    first model whose ``results[model]['RCL'][j]`` is ``<= threshold`` supplies
    the prediction. If no model qualifies, every model has already been paid
    for, and the prediction of the model with the lowest ``RCL`` is used.

    Args:
        cascade: ordered, non-empty sequence of model names (keys of
            ``results`` and ``costs_model``).
        costs_model: mapping model name -> cost charged each time it is queried.
        results: per-model dict with 'y_test', 'S_test' and 'RCL' arrays,
            aligned by sample position.
        threshold: largest estimated regret that is accepted.

    Returns:
        ``(avg_cost, accuracy)``: total cost divided by the number of samples,
        and the accuracy of the cascade scores thresholded at 0.5 over samples
        that received a prediction. Returns ``(0, 0)`` if no sample did.
    """
    y_test = results[cascade[0]]['y_test']
    cascade_predictions = np.full(len(results[cascade[0]]['S_test']), np.nan)
    total_cost = 0

    for j in range(len(y_test)):
        for model_name in cascade:
            total_cost += costs_model[model_name]
            if results[model_name]['RCL'][j] <= threshold:
                cascade_predictions[j] = results[model_name]['S_test'][j]
                break
        else:
            # No model was accepted. Fall back on the lowest estimated R_CL, the
            # same signal used for acceptance. The previous version ranked by
            # RCL + RGL here, which leaked the R_GL estimate into this
            # R_CL-only baseline.
            min_regret = float('inf')
            best_model_name = None
            for model_name in cascade:
                regret = results[model_name]['RCL'][j]
                if regret < min_regret:  # NaN regrets never win
                    min_regret = regret
                    best_model_name = model_name
            if best_model_name is not None:
                cascade_predictions[j] = results[best_model_name]['S_test'][j]

    valid_predictions = ~np.isnan(cascade_predictions)
    if not valid_predictions.any():
        return 0, 0
    accuracy = accuracy_score(y_test[valid_predictions],
                              (cascade_predictions[valid_predictions] >= 0.5).astype(int))
    avg_cost = total_cost / len(y_test)
    return avg_cost, accuracy



def cascade_eval_rcl_rgl(cascade, costs_model, results, threshold=0.05):
    """Evaluate a cascade that stops once the estimated R_CL + R_GL is small.

    Models are tried in ``cascade`` order for every sample, and each try adds
    its price to the running cost. The first model whose
    ``RCL[j] + RGL[j] <= threshold`` supplies the prediction. Otherwise, the
    model with the smallest combined regret is used (all models have been
    paid for by then).

    Returns:
        ``(avg_cost, accuracy)`` with accuracy computed on scores thresholded
        at 0.5, or ``(0, 0)`` when no sample received a prediction.
    """
    labels = results[cascade[0]]['y_test']
    n_samples = len(labels)
    chosen_scores = np.full(n_samples, np.nan)
    spent = 0

    for j in range(n_samples):
        for name in cascade:
            combined = results[name]['RCL'][j] + results[name]['RGL'][j]
            spent += costs_model[name]
            if combined <= threshold:
                chosen_scores[j] = results[name]['S_test'][j]
                break
        else:
            # Nobody met the threshold: keep the smallest combined regret
            # (the first one wins on ties; NaN never wins).
            best_regret, best_name = float('inf'), None
            for name in cascade:
                combined = results[name]['RCL'][j] + results[name]['RGL'][j]
                if combined < best_regret:
                    best_regret, best_name = combined, name
            if best_name is not None:
                chosen_scores[j] = results[best_name]['S_test'][j]

    has_prediction = ~np.isnan(chosen_scores)
    if not has_prediction.any():
        return 0, 0
    accuracy = accuracy_score(labels[has_prediction],
                              (chosen_scores[has_prediction] >= 0.5).astype(int))
    return spent / len(labels), accuracy

def cascade_eval_knn(cascade, costs_model, results, lamb=1):
    """Route each sample to a single model chosen by a learned correctness score.

    For each model in ``cascade``, the regressor stored under
    ``results[model]['knn']`` scores the evaluation features taken from
    ``results[cascade[0]]['embeddings_test']``. Each sample goes to the model
    maximizing ``lamb * prediction - costs_model[model]`` (first one on ties),
    and only that model's cost is charged.

    Returns:
        ``(avg_cost, accuracy)`` with accuracy on the chosen scores thresholded
        at 0.5, or ``(0, 0)`` when no sample received a prediction.
    """
    reference = results[cascade[0]]
    labels = reference['y_test']
    features = reference['embeddings_test']

    # One column per cascade entry: trade-off between predicted correctness and price.
    utility = np.column_stack([
        lamb * results[name]['knn'].predict(features) - costs_model[name]
        for name in cascade
    ]).astype(float)
    picks = np.argmax(utility, axis=1)

    chosen_scores = np.full(len(labels), np.nan)
    spent = 0
    for j, pick in enumerate(picks):
        chosen_name = cascade[pick]
        chosen_scores[j] = results[chosen_name]['S_test'][j]
        spent += costs_model[chosen_name]

    has_prediction = ~np.isnan(chosen_scores)
    if not has_prediction.any():
        return 0, 0
    accuracy = accuracy_score(labels[has_prediction],
                              (chosen_scores[has_prediction] >= 0.5).astype(int))
    return spent / len(labels), accuracy

def plot_threshold_lambda_analysis(cascade, costs_model, all_seeds_results):
    """Plot accuracy vs. average cost for the three routing strategies over all seeds.

    Sweeps the regret threshold for the R_CL cascade (``cascade_eval_rcl``) and
    the R_CL+R_GL cascade (``cascade_eval_rcl_rgl``), and the trade-off weight
    ``lamb`` for decision-tree routing (``cascade_eval_knn``). Each curve point
    is the mean over seeds, with error bars giving the standard deviation over
    seeds of both cost and accuracy.

    Each model in ``cascade`` is also drawn as a reference star at
    ``costs_model[model]``. A model whose mean accuracy falls below the 0.62
    lower y-limit is drawn as an arrow instead. The figure is saved to
    ``cascade_parameter_analysistestloadmlp.pdf``, shown, and a text summary
    is printed.

    Args:
        cascade: ordered list of model names.
        costs_model: mapping model name -> per-query cost, used for the
            cascades and for the reference stars.
        all_seeds_results: mapping seed -> {model name -> per-model results dict}.
    """
    # Define parameter ranges
    thresholds = [0, 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.2, 0.3, 0.5]
    lambdas = [0.1, 1, 2, 5, 10, 20, 50, 100, 200]

    def _summarize(per_seed, params):
        """Reduce per-seed lists of (cost, accuracy) to per-parameter means and stds.

        Returns ``(means, stds)`` with ``means[i] = (mean_cost, mean_acc, params[i])``
        and ``stds[i] = (std_cost, std_acc)``, both taken across seeds.
        """
        means, stds = [], []
        for i, param in enumerate(params):
            seed_costs = [seed_rows[i][0] for seed_rows in per_seed]
            seed_accs = [seed_rows[i][1] for seed_rows in per_seed]
            means.append((np.mean(seed_costs), np.mean(seed_accs), param))
            stds.append((np.std(seed_costs), np.std(seed_accs)))
        return means, stds

    # Run every parameter value on every seed.
    all_threshold_results_rcl = []
    all_threshold_results_rcl_rgl = []
    all_lambda_results_knn = []
    for results in all_seeds_results.values():
        all_threshold_results_rcl.append(
            [cascade_eval_rcl(cascade, costs_model, results, threshold) for threshold in thresholds]
        )
        all_threshold_results_rcl_rgl.append(
            [cascade_eval_rcl_rgl(cascade, costs_model, results, threshold) for threshold in thresholds]
        )
        all_lambda_results_knn.append(
            [cascade_eval_knn(cascade, costs_model, results, lamb) for lamb in lambdas]
        )

    threshold_results_rcl_mean, threshold_results_rcl_std = _summarize(all_threshold_results_rcl, thresholds)
    threshold_results_rcl_rgl_mean, threshold_results_rcl_rgl_std = _summarize(all_threshold_results_rcl_rgl, thresholds)
    lambda_results_knn_mean, lambda_results_knn_std = _summarize(all_lambda_results_knn, lambdas)

    # Create the plot
    plt.figure(figsize=(4,3))

    costs_rcl_mean, accs_rcl_mean, _ = zip(*threshold_results_rcl_mean)
    costs_rcl_std, accs_rcl_std = zip(*threshold_results_rcl_std)
    costs_rcl_rgl_mean, accs_rcl_rgl_mean, _ = zip(*threshold_results_rcl_rgl_mean)
    costs_rcl_rgl_std, accs_rcl_rgl_std = zip(*threshold_results_rcl_rgl_std)
    costs_knn_mean, accs_knn_mean, _ = zip(*lambda_results_knn_mean)
    costs_knn_std, accs_knn_std = zip(*lambda_results_knn_std)

    # Raw strings keep the LaTeX backslashes without relying on invalid-escape
    # fallback (a SyntaxWarning since Python 3.12). The newline stays in a
    # regular literal so it is still a real line break.
    plt.errorbar(costs_rcl_rgl_mean, accs_rcl_rgl_mean,
                 xerr=costs_rcl_rgl_std, yerr=accs_rcl_rgl_std,
                 color="tab:orange", marker='P', alpha=0.8,
                 label='Cascade with \n' r'$\widehat{R}_{CL}+\widehat{R}_{GL}$', capsize=3)

    plt.errorbar(costs_rcl_mean, accs_rcl_mean,
                 xerr=costs_rcl_std, yerr=accs_rcl_std,
                 color="tab:green", marker='o', alpha=0.8,
                 label=r'Cascade with $\widehat{R}_{CL}$', capsize=3)

    plt.errorbar(costs_knn_mean, accs_knn_mean,
                 xerr=costs_knn_std, yerr=accs_knn_std,
                 color='tab:blue', marker='^', alpha=0.8, label='Routing with DT', capsize=3)

    # Annotate every fourth point (indices 1, 5, 9, ...) to avoid clutter.
    for i, (cost, acc, thresh) in enumerate(threshold_results_rcl_mean):
        if i % 4 == 1:
            plt.annotate(f't={thresh}', (cost, acc),
                         textcoords="offset points", xytext=(-10,5),
                         ha='right', fontsize=8, alpha=0.7, color='tab:green')

    for i, (cost, acc, thresh) in enumerate(threshold_results_rcl_rgl_mean):
        if i % 4 == 1:
            plt.annotate(f't={thresh}', (cost, acc),
                         textcoords="offset points", xytext=(-18,10),
                         ha='left', fontsize=8, alpha=0.7, color='tab:orange')

    for i, (cost, acc, lamb) in enumerate(lambda_results_knn_mean):
        if i % 4 == 1:
            plt.annotate(f'λ={lamb}', (cost, acc),
                         textcoords="offset points", xytext=(10,-5),
                         ha='left', fontsize=8, alpha=0.7, color='tab:blue')

    cascade_str = ' → '.join(cascade)

    # Fix the lower y-limit once. set_ylim turns off y-autoscaling, so later
    # scatter calls cannot move it.
    plt.ylim(0.62, None)
    y_min, _ = plt.ylim()

    for model_name in cascade:
        # Accuracy of the model on its own (score thresholded at 0.5), averaged over seeds.
        mean_acc = np.mean([
            accuracy_score(seed_results[model_name]['y_test'],
                           (seed_results[model_name]['S_test'] >= 0.5).astype(int))
            for seed_results in all_seeds_results.values()
        ])
        avg_cost = costs_model[model_name]

        if mean_acc < y_min:
            # Out of view: draw an arrow at the bottom edge. `color` already sets
            # the edge colour, so no separate `edgecolor` is passed (it would be
            # overridden and trigger a warning).
            plt.annotate('', xy=(avg_cost, y_min), xytext=(avg_cost, y_min + 0.01),
                         arrowprops=dict(arrowstyle='->', color='gold', lw=2))
            plt.annotate(model_name, (avg_cost + 0.15, y_min + 0.0025),
                         ha='center', fontsize=10, alpha=0.8, color='black')
            continue

        # Only the Llama 70B star carries a legend label, so 'Individual Models'
        # appears once. `label` is the valid kwarg (scatter has no `legend` kwarg).
        is_70b = model_name == 'Llama 70B'
        plt.scatter(avg_cost, mean_acc,
                    color='gold', marker='*', s=100, alpha=1,
                    edgecolors='black', linewidth=0.5,
                    zorder=5, label='Individual Models' if is_70b else None)
        plt.annotate(model_name, (avg_cost, mean_acc),
                     textcoords="offset points", xytext=(-33,-8) if is_70b else (30,0),
                     ha='center', fontsize=10, alpha=0.8, color='black')

    plt.xlabel('Average Cost per Sample ($)', fontsize=14)
    plt.ylabel('Accuracy', fontsize=14)
    plt.title(f'{cascade_str}', fontsize=11)
    plt.grid(True, alpha=0.4)
    plt.legend(loc='lower right', fontsize = 9)
    plt.savefig('cascade_parameter_analysistestloadmlp.pdf', dpi=800, bbox_inches='tight')
    plt.show()

    # Print detailed results with means and standard deviations
    print(f"\nParameter Analysis for Cascade: {cascade_str}")
    print("=" * 80)

    print("\nRCL Method (varying threshold):")
    for (cost_mean, acc_mean, thresh), (cost_std, acc_std) in zip(threshold_results_rcl_mean, threshold_results_rcl_std):
        print(f"  Threshold {thresh:6.3f}: Accuracy = {acc_mean:.4f}±{acc_std:.4f}, Cost = ${cost_mean:.3f}±{cost_std:.3f}")

    print("\nRCL+RGL Method (varying threshold):")
    for (cost_mean, acc_mean, thresh), (cost_std, acc_std) in zip(threshold_results_rcl_rgl_mean, threshold_results_rcl_rgl_std):
        print(f"  Threshold {thresh:6.3f}: Accuracy = {acc_mean:.4f}±{acc_std:.4f}, Cost = ${cost_mean:.3f}±{cost_std:.3f}")

    print("\nDT Method (varying lambda):")
    for (cost_mean, acc_mean, lamb), (cost_std, acc_std) in zip(lambda_results_knn_mean, lambda_results_knn_std):
        print(f"  Lambda {lamb:8.1f}: Accuracy = {acc_mean:.4f}±{acc_std:.4f}, Cost = ${cost_mean:.3f}±{cost_std:.3f}")
#%%

# Llama-family cascade, ordered from cheapest to most expensive per costs_per_model.
test_cascade = ['Llama 1B', 'Llama 3B', 'Llama 8B', 'Llama 70B']
plot_threshold_lambda_analysis(cascade=test_cascade,
                               costs_model=costs_per_model,
                               all_seeds_results=all_seeds_results)
# %%

def cascade_eval_confidence(cascade, costs_model, results, confidence_threshold=0.9,
                            average_predictions=False):
    """Evaluate a confidence-based cascade.

    Models are queried in ``cascade`` order, and each query adds
    ``costs_model[model_name]`` to the running cost. At step ``i`` a score ``p``
    is formed, and its confidence ``max(p, 1 - p)`` is compared with
    ``confidence_threshold``. The first step that reaches it supplies ``p`` as
    the prediction. If no step does, every model has been paid for and the
    score from the last step is used.

    Args:
        cascade: ordered, non-empty sequence of model names.
        costs_model: mapping model name -> cost charged per query.
        results: per-model dict with 'y_test' and 'S_test' arrays aligned by sample.
        confidence_threshold: minimum ``max(p, 1 - p)`` needed to stop early.
        average_predictions: if False (default), ``p`` is the current model's
            ``S_test`` score. This is what the previous implementation
            effectively computed: its averaging loop summed the *current*
            model's score ``i + 1`` times, so the "average" equalled that score
            (up to rounding). If True, ``p`` is the mean of the ``S_test``
            scores of ``cascade[0..i]``, as that loop's comment described.

    Returns:
        ``(avg_cost, accuracy)``: total cost divided by the number of samples,
        and the accuracy of the final scores thresholded at 0.5. Returns
        ``(0, 0)`` when no sample received a prediction.
    """
    y_test = results[cascade[0]]['y_test']
    cascade_predictions = np.full(len(y_test), np.nan)
    total_cost = 0

    for j in range(len(y_test)):
        score = np.nan
        running_sum = 0.0
        for i, model_name in enumerate(cascade):
            prob = results[model_name]['S_test'][j]
            running_sum += prob
            score = running_sum / (i + 1) if average_predictions else prob
            total_cost += costs_model[model_name]
            # Confidence = distance of the score from the 0.5 decision boundary.
            if max(score, 1 - score) >= confidence_threshold:
                break
        # Whether the loop stopped early or ran out of models, the prediction
        # is the last score formed.
        cascade_predictions[j] = score

    valid_predictions = ~np.isnan(cascade_predictions)
    if not valid_predictions.any():
        return 0, 0
    accuracy = accuracy_score(y_test[valid_predictions],
                              (cascade_predictions[valid_predictions] >= 0.5).astype(int))
    avg_cost = total_cost / len(y_test)
    return avg_cost, accuracy


def plot_confidence_analysis(cascade, costs_model, all_seeds_results):
    """Plot accuracy vs. average cost of the confidence-based cascade across seeds.

    For each confidence threshold, ``cascade_eval_confidence`` is run on every
    seed. The curve shows the mean over seeds, with standard-deviation error
    bars on cost and accuracy.

    Each model in ``cascade`` is drawn as a reference star at
    ``costs_model[model]``. A model whose mean accuracy falls below the 0.62
    lower y-limit is drawn as an arrow at the bottom edge instead. The figure
    is shown (saving is commented out), and a text summary is printed.

    Args:
        cascade: ordered list of model names.
        costs_model: mapping model name -> per-query cost, used for the cascade
            and for the reference stars.
        all_seeds_results: mapping seed -> {model name -> per-model results dict}.
    """
    confidence_thresholds = [0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95, 0.99]

    # One list of (avg_cost, accuracy) per seed, indexed like confidence_thresholds.
    all_confidence_results = [
        [cascade_eval_confidence(cascade, costs_model, results, conf_thresh)
         for conf_thresh in confidence_thresholds]
        for results in all_seeds_results.values()
    ]

    # Means and standard deviations across seeds for each threshold.
    confidence_results_mean = []
    confidence_results_std = []
    for i, conf in enumerate(confidence_thresholds):
        seed_costs = [seed_rows[i][0] for seed_rows in all_confidence_results]
        seed_accs = [seed_rows[i][1] for seed_rows in all_confidence_results]
        confidence_results_mean.append((np.mean(seed_costs), np.mean(seed_accs), conf))
        confidence_results_std.append((np.std(seed_costs), np.std(seed_accs)))

    plt.figure(figsize=(4,3))

    costs_mean, accs_mean, _ = zip(*confidence_results_mean)
    costs_std, accs_std = zip(*confidence_results_std)

    plt.errorbar(costs_mean, accs_mean,
                 xerr=costs_std, yerr=accs_std,
                 color="tab:purple", marker='o', alpha=0.8, label='Confidence-based Cascade', capsize=3)

    # Annotate every second point to avoid clutter.
    for i, (cost, acc, conf) in enumerate(confidence_results_mean):
        if i % 2 == 1:
            plt.annotate(f'c={conf}', (cost, acc),
                         textcoords="offset points", xytext=(-10,5),
                         ha='right', fontsize=8, alpha=0.7)

    cascade_str = ' → '.join(cascade)

    # Fix the lower y-limit once. set_ylim turns off y-autoscaling, so later
    # scatter calls cannot move it.
    plt.ylim(0.62, None)
    y_min, _ = plt.ylim()

    for model_name in cascade:
        # Accuracy of the model on its own (score thresholded at 0.5), averaged over seeds.
        mean_acc = np.mean([
            accuracy_score(seed_results[model_name]['y_test'],
                           (seed_results[model_name]['S_test'] >= 0.5).astype(int))
            for seed_results in all_seeds_results.values()
        ])
        # Use the cost mapping the caller passed in, matching the cascade curve.
        avg_cost = costs_model[model_name]

        if mean_acc < y_min:
            # `color` already sets the arrow's edge colour; a separate `edgecolor`
            # would be overridden and trigger a matplotlib warning.
            plt.annotate('', xy=(avg_cost, y_min), xytext=(avg_cost, y_min + 0.01),
                         arrowprops=dict(arrowstyle='->', color='gold', lw=2))
            plt.annotate(model_name, (avg_cost + 0.1, y_min + 0.0025),
                         ha='center', fontsize=10, alpha=0.8, color='black')
            continue

        plt.scatter(avg_cost, mean_acc,
                    color='gold', marker='*', s=100, alpha=1,
                    edgecolors='black', linewidth=0.5,
                    zorder=5)
        name_offset = (-33,-8) if model_name == 'Llama 70B' else (30,0)
        plt.annotate(model_name, (avg_cost, mean_acc),
                     textcoords="offset points", xytext=name_offset,
                     ha='center', fontsize=10, alpha=0.8, color='black')

    plt.xlabel('Average Cost per Sample ($)', fontsize=14)
    plt.ylabel('Accuracy', fontsize=14)
    plt.title(f'Confidence-based Cascade Analysis\nCascade: {cascade_str}')
    plt.grid(True, alpha=0.4)
    plt.legend(loc='lower right')
    # plt.savefig('cascade_confidence_analysis_testloadmlp.pdf', dpi=800, bbox_inches='tight')
    plt.show()

    # Print detailed results with means and standard deviations
    print(f"\nConfidence-based Cascade Analysis for Cascade: {cascade_str}")
    print("=" * 80)
    print("\nConfidence-based Method (varying confidence threshold):")
    for (cost_mean, acc_mean, conf), (cost_std, acc_std) in zip(confidence_results_mean, confidence_results_std):
        print(f"  Confidence {conf:5.2f}: Accuracy = {acc_mean:.4f}±{acc_std:.4f}, Cost = ${cost_mean:.3f}±{cost_std:.3f}")
#%%
# Same Llama-family cascade, evaluated with confidence-based routing.
test_cascade = ['Llama 1B', 'Llama 3B', 'Llama 8B', 'Llama 70B']
plot_confidence_analysis(cascade=test_cascade,
                         costs_model=costs_per_model,
                         all_seeds_results=all_seeds_results)

# %%
def _summarize_sweep(param_values, per_seed_results):
    """Average a parameter sweep over seeds.

    ``per_seed_results`` holds one list per seed whose ``i``-th entry is the
    ``(avg_cost, accuracy)`` pair obtained with ``param_values[i]``.

    Returns ``(means, stds)`` where ``means[i] = (mean_cost, mean_acc, param_values[i])``
    and ``stds[i] = (std_cost, std_acc)``. Standard deviations use ``np.std``'s
    default ``ddof=0`` (population standard deviation).
    """
    means, stds = [], []
    for i, value in enumerate(param_values):
        costs = [seed_sweep[i][0] for seed_sweep in per_seed_results]
        accs = [seed_sweep[i][1] for seed_sweep in per_seed_results]
        means.append((np.mean(costs), np.mean(accs), value))
        stds.append((np.std(costs), np.std(accs)))
    return means, stds


def _plot_reference_model(model_name, avg_cost, mean_acc, y_min):
    """Draw one single-model reference marker on the current axes.

    A model whose accuracy is below ``y_min`` (the lower y-limit) is drawn as a
    downward arrow on the bottom edge; otherwise it is drawn as a gold star with
    its name next to it.
    """
    if mean_acc < y_min:
        # `color` already sets both face and edge colour of the arrow patch; also
        # passing `edgecolor` only makes matplotlib warn that it is overridden.
        plt.annotate('', xy=(avg_cost, y_min), xytext=(avg_cost, y_min + 0.01),
                     arrowprops=dict(arrowstyle='->', color='gold', lw=2))
        plt.annotate(model_name, (avg_cost + 0.1, y_min + 0.0025),
                     ha='center', fontsize=10, alpha=0.8, color='black')
        return

    plt.scatter(avg_cost, mean_acc,
                color='gold', marker='*', s=100, alpha=1,
                edgecolors='black', linewidth=0.5,
                zorder=5)
    # Label placement is hand-tuned: Llama 70B's label sits to the lower left,
    # every other model's label to the right of its star.
    label_offset = (-33, -8) if model_name == 'Llama 70B' else (30, 0)
    plt.annotate(model_name, (avg_cost, mean_acc),
                 textcoords="offset points", xytext=label_offset,
                 ha='center', fontsize=10, alpha=0.8, color='black')


def plot_threshold_lambda_analysis_with_confidence(cascade, costs_model, all_seeds_results, *,
                                                   ylim_bottom=0.62,
                                                   save_path='cascade_all_methods_comparison.pdf'):
    """
    Plot cascade performance for different threshold values (RCL/RCL+RGL), lambda values (KNN),
    and confidence thresholds with means and standard deviations across all seeds.

    For every seed, each method is evaluated over its own parameter grid via the
    ``cascade_eval_rcl`` / ``cascade_eval_rcl_rgl`` / ``cascade_eval_knn`` /
    ``cascade_eval_confidence`` helpers, each returning ``(avg_cost, accuracy)``.
    The resulting curves are averaged over seeds and plotted with error bars,
    together with one reference marker per model in ``cascade``. The figure is
    optionally saved and shown, and the best-accuracy setting of each method is printed.

    Parameters
    ----------
    cascade : list of str
        Model names in cascade order. Each must be a key of ``costs_model`` and of
        every ``all_seeds_results[seed]``.
    costs_model : dict
        Maps model name to its cost per sample. It is forwarded to the evaluation
        helpers and used to place the single-model reference markers.
    all_seeds_results : dict
        Maps seed to that seed's results. For each model in ``cascade`` the entry
        must provide ``'y_test'`` labels and ``'S_test'`` scores (array-like,
        thresholded at 0.5 to compute the single-model accuracy).
    ylim_bottom : float, optional
        Lower y-axis limit. Models whose accuracy falls below it are drawn as an
        arrow at the bottom edge instead of a star.
    save_path : str or None, optional
        File the figure is saved to (dpi=800). ``None`` skips saving.

    Raises
    ------
    ValueError
        If ``all_seeds_results`` is empty.
    """
    if not all_seeds_results:
        raise ValueError("all_seeds_results must contain at least one seed")

    # Parameter grids swept for each method.
    thresholds = [0, 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.2, 0.3, 0.5]
    lambdas = [0.1, 1, 2, 5, 10, 20, 50, 100, 200]
    confidence_thresholds = [0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95, 0.99]

    # One entry per seed. Each entry is a list of (avg_cost, accuracy) pairs
    # aligned with the corresponding parameter grid above.
    all_threshold_results_rcl = []
    all_threshold_results_rcl_rgl = []
    all_lambda_results_knn = []
    all_confidence_results = []

    for results in all_seeds_results.values():
        threshold_results_rcl = []
        threshold_results_rcl_rgl = []
        for threshold in thresholds:
            avg_cost, accuracy = cascade_eval_rcl(cascade, costs_model, results, threshold)
            threshold_results_rcl.append((avg_cost, accuracy))
            avg_cost, accuracy = cascade_eval_rcl_rgl(cascade, costs_model, results, threshold)
            threshold_results_rcl_rgl.append((avg_cost, accuracy))

        lambda_results_knn = []
        for lamb in lambdas:
            avg_cost, accuracy = cascade_eval_knn(cascade, costs_model, results, lamb)
            lambda_results_knn.append((avg_cost, accuracy))

        confidence_results = []
        for conf_thresh in confidence_thresholds:
            avg_cost, accuracy = cascade_eval_confidence(cascade, costs_model, results, conf_thresh)
            confidence_results.append((avg_cost, accuracy))

        all_threshold_results_rcl.append(threshold_results_rcl)
        all_threshold_results_rcl_rgl.append(threshold_results_rcl_rgl)
        all_lambda_results_knn.append(lambda_results_knn)
        all_confidence_results.append(confidence_results)

    # Means are (cost, acc, param) triples; stds are (cost, acc) pairs.
    threshold_results_rcl_mean, threshold_results_rcl_std = _summarize_sweep(
        thresholds, all_threshold_results_rcl)
    threshold_results_rcl_rgl_mean, threshold_results_rcl_rgl_std = _summarize_sweep(
        thresholds, all_threshold_results_rcl_rgl)
    lambda_results_knn_mean, lambda_results_knn_std = _summarize_sweep(
        lambdas, all_lambda_results_knn)
    confidence_results_mean, confidence_results_std = _summarize_sweep(
        confidence_thresholds, all_confidence_results)

    plt.figure(figsize=(5, 4))

    costs_rcl_mean, accs_rcl_mean, _ = zip(*threshold_results_rcl_mean)
    costs_rcl_std, accs_rcl_std = zip(*threshold_results_rcl_std)

    costs_rcl_rgl_mean, accs_rcl_rgl_mean, _ = zip(*threshold_results_rcl_rgl_mean)
    costs_rcl_rgl_std, accs_rcl_rgl_std = zip(*threshold_results_rcl_rgl_std)

    costs_knn_mean, accs_knn_mean, _ = zip(*lambda_results_knn_mean)
    costs_knn_std, accs_knn_std = zip(*lambda_results_knn_std)

    costs_conf_mean, accs_conf_mean, _ = zip(*confidence_results_mean)
    costs_conf_std, accs_conf_std = zip(*confidence_results_std)

    # Raw strings: `\w` in the mathtext labels is not a valid Python escape.
    plt.errorbar(costs_rcl_rgl_mean, accs_rcl_rgl_mean,
                 xerr=costs_rcl_rgl_std, yerr=accs_rcl_rgl_std,
                 color="tab:orange", marker='P', alpha=0.8,
                 label=r'Cascade with $\widehat{R}_{CL}+\widehat{R}_{GL}$', capsize=3)

    plt.errorbar(costs_rcl_mean, accs_rcl_mean,
                 xerr=costs_rcl_std, yerr=accs_rcl_std,
                 color="tab:green", marker='o', alpha=0.8,
                 label=r'Cascade with $\widehat{R}_{CL}$', capsize=3)

    # The lambda sweep comes from `cascade_eval_knn` but is labelled "DT" here.
    plt.errorbar(costs_knn_mean, accs_knn_mean,
                 xerr=costs_knn_std, yerr=accs_knn_std,
                 color='tab:blue', marker='^', alpha=0.8, label='Routing with DT', capsize=3)

    plt.errorbar(costs_conf_mean, accs_conf_mean,
                 xerr=costs_conf_std, yerr=accs_conf_std,
                 color='tab:purple', marker='s', alpha=0.8, label='Confidence-based Cascade', capsize=3)

    # Annotate only a sparse subset of parameter values to avoid clutter.
    for i, (cost, acc, thresh) in enumerate(threshold_results_rcl_mean):
        if i % 4 == 1:
            plt.annotate(f't={thresh}', (cost, acc),
                         textcoords="offset points", xytext=(-10, 5),
                         ha='right', fontsize=8, alpha=0.7, color='tab:green')

    for i, (cost, acc, thresh) in enumerate(threshold_results_rcl_rgl_mean):
        if i % 4 == 1:
            plt.annotate(f't={thresh}', (cost, acc),
                         textcoords="offset points", xytext=(-18, 10),
                         ha='left', fontsize=8, alpha=0.7, color='tab:orange')

    for i, (cost, acc, lamb) in enumerate(lambda_results_knn_mean):
        if i % 4 == 1:
            plt.annotate(f'λ={lamb}', (cost, acc),
                         textcoords="offset points", xytext=(10, -5),
                         ha='left', fontsize=8, alpha=0.7, color='tab:blue')

    for i, (cost, acc, conf) in enumerate(confidence_results_mean):
        if i % 2 == 1:
            plt.annotate(f'c={conf}', (cost, acc),
                         textcoords="offset points", xytext=(5, 10),
                         ha='left', fontsize=8, alpha=0.7, color='tab:purple')

    cascade_str = ' → '.join(cascade)

    # Fix the lower y-limit once. Explicit limits switch off y autoscaling, so
    # the reference markers drawn below do not move the axis.
    plt.ylim(ylim_bottom, None)
    y_min, _ = plt.ylim()

    for model_name in cascade:
        # Single-model accuracy (scores thresholded at 0.5), averaged over seeds.
        model_accs = [
            accuracy_score(seed_results[model_name]['y_test'],
                           (seed_results[model_name]['S_test'] >= 0.5).astype(int))
            for seed_results in all_seeds_results.values()
        ]
        _plot_reference_model(model_name, costs_model[model_name], np.mean(model_accs), y_min)

    plt.xlabel('Average Cost per Sample ($)', fontsize=14)
    plt.ylabel('Accuracy', fontsize=14)
    plt.title(f'Cascade Methods Comparison\nCascade: {cascade_str}')
    plt.grid(True, alpha=0.4)
    plt.legend(loc='lower right')
    if save_path is not None:
        plt.savefig(save_path, dpi=800, bbox_inches='tight')
    plt.show()

    print(f"\nAll Methods Comparison for Cascade: {cascade_str}")
    print("=" * 80)

    print("\nBest performance for each method:")

    # Best setting = highest mean accuracy (first one on ties, as np.argmax).
    for method, sweep_mean, param_name in (
        ("RCL", threshold_results_rcl_mean, "Threshold"),
        ("RCL+RGL", threshold_results_rcl_rgl_mean, "Threshold"),
        ("DT", lambda_results_knn_mean, "Lambda"),
        ("Confidence", confidence_results_mean, "Threshold"),
    ):
        best_idx = np.argmax([acc for _, acc, _ in sweep_mean])
        best_cost, best_acc, best_param = sweep_mean[best_idx]
        print(f"{method}: Accuracy = {best_acc:.4f}, Cost = ${best_cost:.3f}, {param_name} = {best_param}")
#%%
# Compare all routing/cascade methods on the four-stage Llama cascade.
test_cascade = ['Llama 1B', 'Llama 3B', 'Llama 8B', 'Llama 70B']
plot_threshold_lambda_analysis_with_confidence(
    test_cascade,
    costs_per_model,
    all_seeds_results,
)
# %%
