#%%
 
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from folktexts.acs import ACSDataset
from sklearn.tree import DecisionTreeRegressor
import glest
from folktexts.acs import ACSTaskMetadata
from deferral_experiment.regret_helpers import compute_regret_CL, get_constant_utilty, get_threshold_from_utility
from utils import honest_tree_pred
from sklearn.metrics import roc_auc_score, accuracy_score, brier_score_loss
import seaborn as sns
from itertools import combinations
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsRegressor
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.neural_network import MLPRegressor, MLPClassifier
from itertools import combinations
from matplotlib.lines import Line2D
from sklearn.linear_model import LinearRegression
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.patches import Polygon
#%%


# Load one table per LLM. The processing loop further down reads the
# 'risk_score' and 'label' columns of each table and treats every remaining
# column as a feature.
llama1 = pd.read_csv('deferral_experiment/llama1_instruct.csv')
llama3 = pd.read_csv('deferral_experiment/llama3_instruct.csv')
llama8 = pd.read_csv('deferral_experiment/llama8_instruct.csv')
llama70 = pd.read_csv('deferral_experiment/llama70_instruct.csv')
phi4 = pd.read_csv('deferral_experiment/phi4_instruct.csv')
gemma27 = pd.read_csv('deferral_experiment/gemma27_instruct.csv')
mixtral8x7b = pd.read_csv('deferral_experiment/mixtral8x7b_instruct.csv')
# Embedding matrix; it is split jointly with every model's rows below, so it is
# presumably row-aligned with each CSV — TODO confirm.
embeddings = np.load('deferral_experiment/sentence_embeddings_MiniLM_L12_v2.npy')
#%%

# Per-query cost of each model, keyed by the display names used in `models`
# and in every cascade below.
costs_per_model = {
    'Llama 1B': 0.04,
    'Llama 3B': 0.06,
    'Llama 8B': 0.18,
    'Llama 70B': 0.88,
    'Gemma 27B' : 0.25,
    'Mixtral8x7B': 0.60,
    'Phi 4' : 0.22
}

# Cost of running each model over its table, scaled by 1/10.
# `*` and `//` share precedence and associate left-to-right, so each entry is
# `(rate * len(df)) // 10`, i.e. a floor division applied to a float.
# NOTE(review): the Mixtral rate here (0.70) differs from costs_per_model
# (0.60), and the last key is 'phi4' whereas every other table uses 'Phi 4' —
# confirm which values are intended. This dict is not referenced in the part
# of the file visible during this review.
costs_of_all_models = {'Llama 1B' : 0.04 * len(llama1)//10,
    'Llama 3B': 0.06 * len(llama3)//10,
    'Llama 8B': 0.18 * len(llama8)//10,
    'Llama 70B': 0.88 * len(llama70)//10,
    'Gemma 27B' : 0.25 * len(gemma27)//10,
    'Mixtral8x7B': 0.70 * len(mixtral8x7b)//10,
    'phi4' : 0.22 * len(phi4)//10
}

#%%

# Process all models
# For every seed and every model: hold out 10% of the rows ("leftover") for the
# cascade evaluation, fit a score calibrator and a residual tree on the rest,
# and store the per-sample regret estimates for the held-out rows.
models = {'Llama 1B' : llama1, 'Llama 3B': llama3, 'Llama 8B': llama8, 'Llama 70B': llama70, 'Gemma 27B': gemma27, 'Mixtral8x7B': mixtral8x7b, 'Phi 4': phi4}
results = {}  # rebound to the first seed's results after the loop
seed = 0  # immediately overwritten by the loop variable below

# Store results for multiple seeds
all_seeds_results = {}
seeds = range(5)  # You can modify this list of seeds

for seed in seeds:
    print(f"\n=== Processing seed {seed} ===")
    seed_results = {}

    for model_name, model_df in models.items():
        print(f"Processing {model_name} with seed {seed}...")

        # half_size = len(model_df) // 4
        # model_df_half = model_df.sample(n=half_size, random_state=seed)
        # embeddings_half = embeddings[model_df_half.index]
        # X: every column except the score and the label; y: labels; S: raw model scores.
        X = model_df.drop(columns=['risk_score', 'label']).values
        y = model_df['label'].values
        S = model_df['risk_score'].values

        # NOTE(review): U and this value of t are never used — t is overwritten
        # with 0.5 before the regrets are computed.
        t_target = [0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975, 0.99]
        U = get_constant_utilty(100, t_target)  # (n_utilities, 2, 2)
        t = get_threshold_from_utility(U) 

        calibrated_classifier = LogisticRegression()

        # Hold out 10% as the evaluation ("leftover") set; X, y, S are rebound to
        # the remaining 90%. The embedding splits are not used afterwards.
        X, X_leftover, y, y_leftover, S, S_leftover, embeddings_kept, embeddings_leftover = train_test_split(
            X, y, S, embeddings, test_size=0.1, random_state=seed
        )


        # 1 where the score thresholded at 0.5 agrees with the label, else 0.
        correct = (y == (S >= 0.5).astype(int)).astype(int)
        # Router baseline: a tree regressing correctness on the features. It is
        # stored under the key 'knn' below even though it is a decision tree.
        dtcorrect = DecisionTreeRegressor(max_depth = None, min_samples_leaf= 15)

        # # dtcorrect = KNeighborsRegressor(n_neighbors=40, metric='cosine', n_jobs=40)
        dtcorrect.fit(X, correct)

        # Precomputed baseline predictions; stored in the results but not read
        # by any evaluator visible in this file.
        mlp_pred = np.load(f'deferral_experiment/baselines/{model_name}_MLP_correct_predictions_seed_{seed}.npy')
        knn_pred = np.load(f'deferral_experiment/baselines/{model_name}_KNN_correct_predictions_seed_{seed}.npy')

        # 50/50 split of the kept rows into train and test.
        X_train, X_test, y_train, y_test, S_train, S_test = train_test_split(
            X, y, S, test_size=0.5, random_state=seed
        )

        # Carve a calibration set out of train: an absolute row count equal to
        # 20% of train or 4000, whichever is larger.
        X_train, X_cal, y_train, y_cal, S_train, S_cal = train_test_split(
            X_train, y_train, S_train, test_size=max(int(len(X_train) * 0.2),4000), random_state=seed
        )

        # Calibrate the raw score with a one-feature logistic regression.
        calibrated_classifier.fit(S_cal.reshape(-1,1), y_cal)

        c_hat_train = calibrated_classifier.predict_proba(S_train.reshape(-1,1))[:, 1]
        c_hat_test = calibrated_classifier.predict_proba(S_test.reshape(-1,1))[:, 1]

        # Tree fitted on the calibration residuals of the train split; its leaves
        # partition the test split for the estimator below.
        residuals_train = y_train - c_hat_train
        residuals_test = y_test - c_hat_test  # not used below
        dt = DecisionTreeRegressor(max_depth = None, min_samples_leaf= 15)
        dt.fit(X_train, residuals_train)
        leaf_ids = dt.apply(X_test)

        # presumably estimates per-leaf residuals on the test split ("honest"
        # estimation) — TODO confirm against the glest documentation
        gle = glest.core.GLEstimatorResiduals(None, None)
        gle.fit(X_test, y_test, y_scores_cal = c_hat_test, partition = leaf_ids)

        # Calibrated score and estimated residual for the held-out rows.
        c_hat_leftover = calibrated_classifier.predict_proba(S_leftover.reshape(-1,1))[:, 1]
        r_hat_leftover = honest_tree_pred(dt, gle.honest_rj, X_leftover)

        t = 0.5
        # Decisions taken from the raw score; [:, None] makes `a` shape (n, 1).
        a = (S_leftover[:, None] >= t).astype(int)
        RCL = compute_regret_CL(c_hat_leftover, t, a)  # (n, k)

        # Decisions taken from the calibrated score, evaluated against the
        # calibrated score corrected by the estimated residual.
        a = (c_hat_leftover[:, None] >= t).astype(int)  # (n, k)
        RGL = compute_regret_CL(c_hat_leftover + r_hat_leftover, t, a)  # (n, k)

        # Store results for this seed
        # The '*_test' keys hold the held-out "leftover" rows, not the 50% test
        # split used above for fitting `gle`.
        # NOTE(review): 'embeddings_test' stores X_leftover, not
        # embeddings_leftover — confirm this is intended.
        seed_results[model_name] = {
            'X_test': X_leftover,
            'y_test': y_leftover, 
            'S_test': S_leftover,
            'embeddings_test': X_leftover,
            'c_hat_test': c_hat_leftover,
            'r_hat': r_hat_leftover,
            'RCL': RCL,
            'RGL': RGL,
            'gle': gle,
            'tree': dt,
            'knn': dtcorrect,
            'mlp_predictions': mlp_pred,
            'knn_predictions': knn_pred
        }

    # Store results for this seed
    all_seeds_results[seed] = seed_results

# Set results to the first seed for backward compatibility
results = all_seeds_results[seeds[0]]

#%%

# Notebook-style display of the collected results.
all_seeds_results
# %%
def cascade_eval_rcl(cascade, costs_model, results, threshold=0.05):
    list_models = []
    for model_name in cascade:
        if model_name not in list_models:
            list_models.append(model_name)
    y_test = results[list_models[0]]['y_test']
    cascade_predictions = np.full(len(results[list_models[0]]["S_test"]), np.nan)
    cascade_model_used = np.full(len(results[list_models[0]]["S_test"]), -1)
    total_cost = 0
    
    for j in range(len(y_test)):
        model_accepted = False

        # Try models in the cascade order
        for model_name in cascade:
            # Calculate regret for this sample and model
            model_regret = results[model_name]['RCL'][j]
            total_cost += costs_model[model_name]
            # If regret is below threshold, accept this model's prediction
            if model_regret <= threshold:
                cascade_predictions[j] = results[model_name]['S_test'][j]
                cascade_model_used[j] = list_models.index(model_name)
                model_accepted = True
                break
        
        # If no model meets threshold, choose model with lowest regret
        if not model_accepted:
            min_regret = float('inf')
            best_model_idx = -1
            
            for i, model_name in enumerate(cascade):
                total_regret = results[model_name]['RCL'][j] + results[model_name]['RGL'][j]
                if total_regret < min_regret:
                    min_regret = total_regret
                    best_model_idx = i
            
            if best_model_idx != -1:
                cascade_predictions[j] = results[cascade[best_model_idx]]['S_test'][j]
                cascade_model_used[j] = best_model_idx
                # total_cost += costs_model[cascade[best_model_idx]]
    
    # Calculate performance metrics
    valid_predictions = ~np.isnan(cascade_predictions)
    if np.sum(valid_predictions) > 0:
        accuracy = accuracy_score(y_test[valid_predictions], 
                                (cascade_predictions[valid_predictions] >= 0.5).astype(int))
        avg_cost = total_cost / len(y_test)
    else:
        accuracy = 0
        avg_cost = 0
    
    return avg_cost, accuracy



def cascade_eval_rcl_rgl(cascade, costs_model, results, threshold=0.05):
    list_models = []
    for model_name in cascade:
        if model_name not in list_models:
            list_models.append(model_name)

    y_test = results[list_models[0]]['y_test']
    cascade_predictions = np.full(len(y_test), np.nan)
    cascade_model_used = np.full(len(y_test), -1)
    total_cost = 0
    
    for j in range(len(y_test)):
        model_accepted = False
        # Try models in the cascade order
        for model_name in cascade:
            # Calculate regret for this sample and model (RCL + RGL)
            model_regret = results[model_name]['RCL'][j] + results[model_name]['RGL'][j]
            total_cost += costs_model[model_name]
            # If regret is below threshold, accept this model's prediction
            if model_regret <= threshold:
                cascade_predictions[j] = results[model_name]['S_test'][j]
                cascade_model_used[j] = list_models.index(model_name)
                # total_cost += costs_model[model_name]
                model_accepted = True
                break
        
        # If no model meets threshold, choose model with lowest regret
        if not model_accepted:
            min_regret = float('inf')
            best_model_idx = -1
            
            for i, model_name in enumerate(cascade):
                total_regret = results[model_name]['RCL'][j] + results[model_name]['RGL'][j]
                if total_regret < min_regret:
                    min_regret = total_regret
                    best_model_idx = i
            
            if best_model_idx != -1:
                cascade_predictions[j] = results[cascade[best_model_idx]]['S_test'][j] #+ results[cascade[best_model_idx]]['r_hat'][j]
                cascade_model_used[j] = best_model_idx
                # total_cost += costs_model[cascade[best_model_idx]]
    
    # Calculate performance metrics
    valid_predictions = ~np.isnan(cascade_predictions)
    if np.sum(valid_predictions) > 0:
        accuracy = accuracy_score(y_test[valid_predictions], 
                                (cascade_predictions[valid_predictions] >= 0.5).astype(int))
        avg_cost = total_cost / len(y_test)
    else:
        accuracy = 0
        avg_cost = 0
    
    return avg_cost, accuracy




def cascade_eval_confidence(cascade, costs_model, results, confidence_threshold=0.9):
    """
    Confidence-based cascade routing: use a model's prediction only if its confidence
    (max probability) exceeds the threshold, otherwise move to next model in cascade.

    Models are queried in cascade order and every queried model is charged its
    cost. A model's confidence is ``max(score, 1 - score)`` of its own score;
    the first model with ``confidence >= confidence_threshold`` supplies the
    prediction. If no model qualifies, the whole cascade has been charged and
    the last model's score is used.

    The previous implementation obtained the score by adding the current
    model's score to itself ``i + 1`` times and dividing by ``i + 1``. That is
    the same value up to floating-point rounding, but the rounding could push a
    score across the threshold at later cascade positions; the score is now
    used directly.
    NOTE(review): the old inline comments described an average over all models
    queried so far — if that ensemble behaviour is what is wanted, it has to be
    implemented explicitly.

    Args:
        cascade: Sequence of model names, in query order.
        costs_model: Mapping from model name to per-query cost.
        results: Mapping from model name to a dict holding 'y_test' and
            'S_test', each with one value per sample.
        confidence_threshold: Minimum confidence required to stop at a model.

    Returns:
        ``(avg_cost, accuracy)``; ``(0, 0)`` when no sample gets a prediction.

    Raises:
        IndexError: If ``cascade`` is empty.
    """
    y_test = np.asarray(results[cascade[0]]['y_test'])
    n_samples = len(y_test)
    if n_samples == 0:
        return 0, 0

    # One column per cascade position.
    scores = np.column_stack([
        np.asarray(results[model_name]['S_test'], dtype=float).reshape(n_samples)
        for model_name in cascade
    ])
    confidence = np.maximum(scores, 1 - scores)
    accepted = confidence >= confidence_threshold

    # First accepted position, or the last model when nothing is accepted.
    last_position = len(cascade) - 1
    chosen = np.where(accepted.any(axis=1), accepted.argmax(axis=1), last_position)
    predictions = scores[np.arange(n_samples), chosen]

    # Every model up to and including the chosen position has been queried.
    cumulative_cost = np.cumsum([costs_model[model_name] for model_name in cascade])
    total_cost = cumulative_cost[chosen].sum()

    valid = ~np.isnan(predictions)
    if not valid.any():
        return 0, 0
    # Fraction of exact matches between labels and scores thresholded at 0.5.
    hits = y_test[valid] == (predictions[valid] >= 0.5).astype(int)
    return float(total_cost) / n_samples, float(np.mean(hits))


def cascade_eval_knn(cascade, costs_model, results, lamb=1):
    """
    Route each sample to a single model chosen by a learned correctness predictor.

    For every model, ``results[model]['knn']`` predicts how likely that model is
    to be correct on each sample. The sample is sent to the model maximising
    ``lamb * predicted_correctness - cost`` (first model on ties), and only that
    model's cost is charged. With a large ``lamb`` this approaches "pick the
    model predicted to be most accurate".

    Args:
        cascade: Sequence of candidate model names.
        costs_model: Mapping from model name to per-query cost.
        results: Mapping from model name to a dict holding 'y_test', 'S_test',
            'X_test' and 'knn' (any object with a ``predict(X)`` method
            returning one value per row).
        lamb: Weight of the predicted correctness relative to the cost.

    Returns:
        ``(avg_cost, accuracy)``; ``(0, 0)`` when no sample gets a prediction.

    Raises:
        IndexError: If ``cascade`` is empty.
    """
    reference = results[cascade[0]]
    y_test = np.asarray(reference['y_test'])
    # Features come from the first model (all models should share X_test).
    features = reference['X_test']
    n_samples = len(features)

    # Cost-adjusted utility of every model for every sample.
    utility = np.zeros((n_samples, len(cascade)))
    for position, model_name in enumerate(cascade):
        predictor = results[model_name]['knn']
        utility[:, position] = lamb * predictor.predict(features) - costs_model[model_name]
    best = np.argmax(utility, axis=1)

    # Score and cost of the selected model for each sample.
    scores = np.column_stack([
        np.asarray(results[model_name]['S_test'], dtype=float).reshape(n_samples)
        for model_name in cascade
    ])
    predictions = scores[np.arange(n_samples), best]
    costs = np.array([costs_model[model_name] for model_name in cascade], dtype=float)
    total_cost = costs[best].sum()

    valid = ~np.isnan(predictions)
    if not valid.any():
        return 0, 0
    # Fraction of exact matches between labels and scores thresholded at 0.5.
    hits = y_test[valid] == (predictions[valid] >= 0.5).astype(int)
    return float(total_cost) / len(y_test), float(np.mean(hits))

#%%
# Define cascade (all models sorted by cost)
# The order below is ascending in costs_per_model.
cascade = ['Llama 1B', 'Llama 3B', 'Llama 8B', 'Phi 4', 'Gemma 27B', 'Mixtral8x7B', 'Llama 70B']

# Get accuracy and cost of biggest model (last model in cascade)
biggest_model = cascade[-1]
biggest_cost = costs_per_model[biggest_model]

# Calculate accuracy for biggest model
# `results` holds the first seed only (see the processing cell above).
y_test = results[biggest_model]['y_test']
biggest_predictions = results[biggest_model]['S_test']
biggest_accuracy = accuracy_score(y_test, (biggest_predictions >= 0.5).astype(int))

# Get points for different methods
rcl_cost, rcl_accuracy = cascade_eval_rcl(cascade, costs_per_model, results, threshold=0.05)
rgl_rcl_cost, rgl_rcl_accuracy = cascade_eval_rcl_rgl(cascade, costs_per_model, results, threshold=0.05)
knn_cost, knn_accuracy = cascade_eval_knn(cascade, costs_per_model, results, lamb=1)

# Calculate relative metrics
# Each point is (average cost / cost of the last model,
#                accuracy - accuracy of the last model).
rcl_point = (rcl_cost / biggest_cost, rcl_accuracy - biggest_accuracy)
rgl_rcl_point = (rgl_rcl_cost / biggest_cost, rgl_rcl_accuracy - biggest_accuracy)
knn_point = (knn_cost / biggest_cost, knn_accuracy - biggest_accuracy)

print(f"RCL point: {rcl_point}")
print(f"RGL+RCL point: {rgl_rcl_point}")
print(f"KNN point: {knn_point}")
# %%


# Generate all possible cascades of all sizes
# Every subset of at least two models, each ordered by ascending cost.

all_models = ['Llama 1B', 'Llama 3B', 'Llama 8B', 'Phi 4', 'Gemma 27B', 'Mixtral8x7B', 'Llama 70B']
all_cascades = []
for size in range(2, len(all_models) + 1):
    for combo in combinations(all_models, size):
        # Skip cascades that contain Llama 3B
        # if 'Llama 3B' in combo:
        #     continue
        # Sort each cascade by cost (ascending order)
        sorted_cascade = sorted(combo, key=lambda x: costs_per_model[x])
        all_cascades.append(sorted_cascade)

print(f"Generated {len(all_cascades)} cascades")
print("First few cascades:")
# Note: this loop rebinds the module-level name `cascade` defined in the
# previous cell.
for i, cascade in enumerate(all_cascades[:5]):
    print(f"{i+1}: {cascade}")

# Evaluate all cascades
cascade_results = []  # re-created for every seed in the loop below

# Evaluate all cascades for all seeds
# Maps seed -> list of per-cascade result dicts, in the order of all_cascades.
all_seeds_cascade_results = {}

for seed in seeds:
    print(f"Evaluating cascades for seed {seed}...")
    cascade_results = []

    for cascade in all_cascades:
        # Get accuracy and cost of biggest model in this cascade
        # (the last entry, i.e. the most expensive model after the cost sort)
        biggest_model = cascade[-1]
        biggest_cost = costs_per_model[biggest_model]

        # Calculate accuracy for biggest model
        y_test = all_seeds_results[seed][biggest_model]['y_test']
        biggest_predictions = all_seeds_results[seed][biggest_model]['S_test']
        biggest_accuracy = accuracy_score(y_test, (biggest_predictions >= 0.5).astype(int))

        # Get points for different methods
        # threshold=0 accepts a model only when its estimated regret is <= 0.
        rcl_cost, rcl_accuracy = cascade_eval_rcl(cascade, costs_per_model, all_seeds_results[seed], threshold=0)
        rgl_rcl_cost, rgl_rcl_accuracy = cascade_eval_rcl_rgl(cascade, costs_per_model, all_seeds_results[seed], threshold=0)
        knn_cost, knn_accuracy = cascade_eval_knn(cascade, costs_per_model, all_seeds_results[seed], lamb=100)

        # Calculate relative metrics
        # (cost as a fraction of the last model's cost, accuracy gain over it)
        rcl_point = (rcl_cost / biggest_cost, rcl_accuracy - biggest_accuracy)
        rgl_rcl_point = (rgl_rcl_cost / biggest_cost, rgl_rcl_accuracy - biggest_accuracy)
        knn_point = (knn_cost / biggest_cost, knn_accuracy - biggest_accuracy)

        cascade_results.append({
            'cascade': cascade,
            'rcl_point': rcl_point,
            'rgl_rcl_point': rgl_rcl_point,
            'knn_point': knn_point
        })

    all_seeds_cascade_results[seed] = cascade_results

# Aggregate results across seeds
# For every cascade: mean and standard deviation over seeds of each method's
# (x, y) point. np.std uses its default ddof=0 (population standard deviation).
aggregated_results = []
for i in range(len(all_cascades)):
    cascade = all_cascades[i]

    # Collect points for this cascade across all seeds
    rcl_points_seeds = []
    rgl_rcl_points_seeds = []
    knn_points_seeds = []

    for seed in seeds:
        rcl_points_seeds.append(all_seeds_cascade_results[seed][i]['rcl_point'])
        rgl_rcl_points_seeds.append(all_seeds_cascade_results[seed][i]['rgl_rcl_point'])
        knn_points_seeds.append(all_seeds_cascade_results[seed][i]['knn_point'])

    # Calculate means and stds
    rcl_x_vals = [point[0] for point in rcl_points_seeds]
    rcl_y_vals = [point[1] for point in rcl_points_seeds]
    rgl_rcl_x_vals = [point[0] for point in rgl_rcl_points_seeds]
    rgl_rcl_y_vals = [point[1] for point in rgl_rcl_points_seeds]
    knn_x_vals = [point[0] for point in knn_points_seeds]
    knn_y_vals = [point[1] for point in knn_points_seeds]

    aggregated_results.append({
        'cascade': cascade,
        'last_model': cascade[-1],
        'rcl_mean': (np.mean(rcl_x_vals), np.mean(rcl_y_vals)),
        'rcl_std': (np.std(rcl_x_vals), np.std(rcl_y_vals)),
        'rgl_rcl_mean': (np.mean(rgl_rcl_x_vals), np.mean(rgl_rcl_y_vals)),
        'rgl_rcl_std': (np.std(rgl_rcl_x_vals), np.std(rgl_rcl_y_vals)),
        'knn_mean': (np.mean(knn_x_vals), np.mean(knn_y_vals)),
        'knn_std': (np.std(knn_x_vals), np.std(knn_y_vals))
    })

print(f"Evaluated {len(aggregated_results)} cascades across {len(seeds)} seeds")


#%%
# Create marker mapping for each model
# The marker encodes the last (most expensive) model of the plotted cascades.
marker_map = {
    'Llama 1B': 'o',
    'Llama 3B': 's', 
    'Llama 8B': '^',
    'Phi 4': 'v',
    'Gemma 27B': 'D',
    'Mixtral8x7B': 'P',
    'Llama 70B': '*'
}

# First, find the global min/max for x and y across all results
# so that every subplot shares the same axis limits.
all_x_vals = []
all_y_vals = []

for result in aggregated_results:
    # Collect all x and y values including error bars
    all_x_vals.extend([
        result['rcl_mean'][0] - result['rcl_std'][0],
        result['rcl_mean'][0] + result['rcl_std'][0],
        result['rgl_rcl_mean'][0] - result['rgl_rcl_std'][0],
        result['rgl_rcl_mean'][0] + result['rgl_rcl_std'][0],
        result['knn_mean'][0] - result['knn_std'][0],
        result['knn_mean'][0] + result['knn_std'][0]
    ])

    all_y_vals.extend([
        result['rcl_mean'][1] - result['rcl_std'][1],
        result['rcl_mean'][1] + result['rcl_std'][1],
        result['rgl_rcl_mean'][1] - result['rgl_rcl_std'][1],
        result['rgl_rcl_mean'][1] + result['rgl_rcl_std'][1],
        result['knn_mean'][1] - result['knn_std'][1],
        result['knn_mean'][1] + result['knn_std'][1]
    ])

x_min, x_max = min(all_x_vals), max(all_x_vals)
y_min, y_max = min(all_y_vals), max(all_y_vals)

# Add some padding (5% of the data range on each side)
x_padding = (x_max - x_min) * 0.05
y_padding = (y_max - y_min) * 0.05
x_min -= x_padding
x_max += x_padding
y_min -= y_padding
y_max += y_padding

# Create subplots for each last model (excluding Llama 1B)
# Llama 1B is the cheapest model, so after the cost sort it is never the last
# entry of a cascade with two or more models; the six remaining models fill
# the 2x3 grid.
models_to_plot = [model for model in all_models if model != 'Llama 1B']
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for idx, last_model in enumerate(models_to_plot):
    ax = axes[idx]

    # Filter results for this last model
    filtered_results = [r for r in aggregated_results if r['last_model'] == last_model]

    if not filtered_results:
        ax.set_title(f'Last model: {last_model} (No data)')
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        continue

    rcl_x_mean = [result['rcl_mean'][0] for result in filtered_results]
    rcl_y_mean = [result['rcl_mean'][1] for result in filtered_results]
    rcl_x_std = [result['rcl_std'][0] for result in filtered_results]
    rcl_y_std = [result['rcl_std'][1] for result in filtered_results]

    rgl_rcl_x_mean = [result['rgl_rcl_mean'][0] for result in filtered_results]
    rgl_rcl_y_mean = [result['rgl_rcl_mean'][1] for result in filtered_results]
    rgl_rcl_x_std = [result['rgl_rcl_std'][0] for result in filtered_results]
    rgl_rcl_y_std = [result['rgl_rcl_std'][1] for result in filtered_results]

    knn_x_mean = [result['knn_mean'][0] for result in filtered_results]
    knn_y_mean = [result['knn_mean'][1] for result in filtered_results]
    knn_x_std = [result['knn_std'][0] for result in filtered_results]
    knn_y_std = [result['knn_std'][1] for result in filtered_results]

    marker = marker_map[last_model]

    # Raw strings keep the mathtext backslashes literal: '\w' is not a valid
    # escape sequence and triggers a warning on recent Python versions. The
    # resulting label text is unchanged.
    ax.errorbar(rcl_x_mean, rcl_y_mean, xerr=rcl_x_std, yerr=rcl_y_std, 
                fmt=marker, alpha=0.7, color='blue', capsize=3, markersize=8, label=r'$\widehat{R}_{CL}$')
    ax.errorbar(rgl_rcl_x_mean, rgl_rcl_y_mean, xerr=rgl_rcl_x_std, yerr=rgl_rcl_y_std,
                fmt=marker, alpha=0.7, color='red', capsize=3, markersize=8, label=r'$\widehat{R}_{GL}+\widehat{R}_{CL}$')
    ax.errorbar(knn_x_mean, knn_y_mean, xerr=knn_x_std, yerr=knn_y_std,
                fmt=marker, alpha=0.7, color='green', capsize=3, markersize=8, label='kNN')

    # Draw dotted lines connecting each triplet of points
    # (the three methods evaluated on the same cascade)
    for i in range(len(filtered_results)):
        # Connect the three points for each cascade with dotted lines
        x_coords = [rcl_x_mean[i], rgl_rcl_x_mean[i], knn_x_mean[i]]
        y_coords = [rcl_y_mean[i], rgl_rcl_y_mean[i], knn_y_mean[i]]

        # Draw lines between all pairs in the triplet
        ax.plot([x_coords[0], x_coords[1]], [y_coords[0], y_coords[1]], 
                'k--', alpha=0.3, linewidth=0.5)
        ax.plot([x_coords[1], x_coords[2]], [y_coords[1], y_coords[2]], 
                'k--', alpha=0.3, linewidth=0.5)
        ax.plot([x_coords[0], x_coords[2]], [y_coords[0], y_coords[2]], 
                'k--', alpha=0.3, linewidth=0.5)

    ax.set_xlabel('Relative Cost')
    ax.set_ylabel('Accuracy Improvement')
    ax.set_title(f'Last model: {last_model}')
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.grid(True, alpha=0.5)
    ax.legend()

plt.tight_layout()
# y=1.02 places the figure title just above the subplot grid.
plt.suptitle('Cascade Performance by Last Model: Cost vs Accuracy Improvement', y=1.02)
plt.show()

# %%
# Notebook-style display of the cascades generated above.
all_cascades
# %%
# Second evaluation pass: same cascades, but the points are stored as absolute
# (cost, accuracy) values, and the confidence baseline is evaluated as well.
all_models = ['Llama 1B', 'Llama 3B', 'Llama 8B', 'Phi 4', 'Gemma 27B', 'Mixtral8x7B', 'Llama 70B']
all_cascades = []
for size in range(2, len(all_models) + 1):
    for combo in combinations(all_models, size):
        # Skip cascades that contain Llama 3B
        # if 'Llama 3B' in combo:
        #     continue
        # Sort each cascade by cost (ascending order)
        sorted_cascade = sorted(combo, key=lambda x: costs_per_model[x])
        all_cascades.append(sorted_cascade)

print(f"Generated {len(all_cascades)} cascades")
print("First few cascades:")
for i, cascade in enumerate(all_cascades[:5]):
    print(f"{i+1}: {cascade}")

# Evaluate all cascades
cascade_results = []  # re-created for every seed in the loop below

# Evaluate all cascades for all seeds
# Maps seed -> list of per-cascade result dicts, in the order of all_cascades.
all_seeds_cascade_results = {}

for seed in seeds:
    print(f"Evaluating cascades for seed {seed}...")
    cascade_results = []

    for cascade in all_cascades:
        # Get accuracy and cost of biggest model in this cascade
        # (the last entry, i.e. the most expensive model after the cost sort)
        biggest_model = cascade[-1]
        biggest_cost = costs_per_model[biggest_model]

        # Calculate accuracy for biggest model
        y_test = all_seeds_results[seed][biggest_model]['y_test']
        biggest_predictions = all_seeds_results[seed][biggest_model]['S_test']
        biggest_accuracy = accuracy_score(y_test, (biggest_predictions >= 0.5).astype(int))

        # Get points for different methods
        rcl_cost, rcl_accuracy = cascade_eval_rcl(cascade, costs_per_model, all_seeds_results[seed], threshold=0)
        rgl_rcl_cost, rgl_rcl_accuracy = cascade_eval_rcl_rgl(cascade, costs_per_model, all_seeds_results[seed], threshold=0)
        knn_cost, knn_accuracy = cascade_eval_knn(cascade, costs_per_model, all_seeds_results[seed], lamb=100)
        confidence_cost, confidence_accuracy = cascade_eval_confidence(cascade, costs_per_model, all_seeds_results[seed], confidence_threshold=0.7)


        # Absolute (average cost, accuracy) points — unlike the earlier cell,
        # these are not normalised by the last model's cost and accuracy.
        rcl_point = (rcl_cost, rcl_accuracy)
        rgl_rcl_point = (rgl_rcl_cost, rgl_rcl_accuracy)
        knn_point = (knn_cost, knn_accuracy)
        confidence_point = (confidence_cost, confidence_accuracy)

        cascade_results.append({
            'cascade': cascade,
            'rcl_point': rcl_point,
            'rgl_rcl_point': rgl_rcl_point,
            'knn_point': knn_point,
            'last_model_accuracy': biggest_accuracy,
            'last_model_cost': biggest_cost,
            'confidence_point': confidence_point
        })

    all_seeds_cascade_results[seed] = cascade_results

# Aggregate results across seeds
# NOTE(review): 'confidence_point', 'last_model_accuracy' and 'last_model_cost'
# are stored per seed above but are not aggregated here.
aggregated_results = []
for i in range(len(all_cascades)):
    cascade = all_cascades[i]

    # Collect points for this cascade across all seeds
    rcl_points_seeds = []
    rgl_rcl_points_seeds = []
    knn_points_seeds = []

    for seed in seeds:
        rcl_points_seeds.append(all_seeds_cascade_results[seed][i]['rcl_point'])
        rgl_rcl_points_seeds.append(all_seeds_cascade_results[seed][i]['rgl_rcl_point'])
        knn_points_seeds.append(all_seeds_cascade_results[seed][i]['knn_point'])

    # Calculate means and stds
    # (np.std uses its default ddof=0, i.e. the population standard deviation)
    rcl_x_vals = [point[0] for point in rcl_points_seeds]
    rcl_y_vals = [point[1] for point in rcl_points_seeds]
    rgl_rcl_x_vals = [point[0] for point in rgl_rcl_points_seeds]
    rgl_rcl_y_vals = [point[1] for point in rgl_rcl_points_seeds]
    knn_x_vals = [point[0] for point in knn_points_seeds]
    knn_y_vals = [point[1] for point in knn_points_seeds]

    aggregated_results.append({
        'cascade': cascade,
        'last_model': cascade[-1],
        'rcl_mean': (np.mean(rcl_x_vals), np.mean(rcl_y_vals)),
        'rcl_std': (np.std(rcl_x_vals), np.std(rcl_y_vals)),
        'rgl_rcl_mean': (np.mean(rgl_rcl_x_vals), np.mean(rgl_rcl_y_vals)),
        'rgl_rcl_std': (np.std(rgl_rcl_x_vals), np.std(rgl_rcl_y_vals)),
        'knn_mean': (np.mean(knn_x_vals), np.mean(knn_y_vals)),
        'knn_std': (np.std(knn_x_vals), np.std(knn_y_vals))
    })

print(f"Evaluated {len(aggregated_results)} cascades across {len(seeds)} seeds")


#%%
# Create marker mapping for each model
# The marker encodes the last (most expensive) model of the plotted cascades.
marker_map = {
    'Llama 1B': 'o',
    'Llama 3B': 's', 
    'Llama 8B': '^',
    'Phi 4': 'v',
    'Gemma 27B': 'D',
    'Mixtral8x7B': 'P',
    'Llama 70B': '*'
}

# First, find the global min/max for x and y across all results
all_x_vals = []
all_y_vals = []

for result in aggregated_results:
    # Collect all x and y values including error bars
    all_x_vals.extend([
        result['rcl_mean'][0] - result['rcl_std'][0],
        result['rcl_mean'][0] + result['rcl_std'][0],
        result['rgl_rcl_mean'][0] - result['rgl_rcl_std'][0],
        result['rgl_rcl_mean'][0] + result['rgl_rcl_std'][0],
        result['knn_mean'][0] - result['knn_std'][0],
        result['knn_mean'][0] + result['knn_std'][0]
    ])

    all_y_vals.extend([
        result['rcl_mean'][1] - result['rcl_std'][1],
        result['rcl_mean'][1] + result['rcl_std'][1],
        result['rgl_rcl_mean'][1] - result['rgl_rcl_std'][1],
        result['rgl_rcl_mean'][1] + result['rgl_rcl_std'][1],
        result['knn_mean'][1] - result['knn_std'][1],
        result['knn_mean'][1] + result['knn_std'][1]
    ])

# Also include the standalone (cost, accuracy) point of every model, for every
# seed, so the axis limits cover the single-model reference points too.
for seed in seeds:
    for model_name in all_models:
        y_test = all_seeds_results[seed][model_name]['y_test']
        predictions = all_seeds_results[seed][model_name]['S_test']
        accuracy = accuracy_score(y_test, (predictions >= 0.5).astype(int))
        cost = costs_per_model[model_name]
        all_x_vals.append(cost)
        all_y_vals.append(accuracy)

x_min, x_max = min(all_x_vals), max(all_x_vals)
y_min, y_max = min(all_y_vals), max(all_y_vals)

# Add some padding
# Only the x-limits use the computed padding; the y-limits are hard-coded
# below, so y_padding and the data-driven y_min/y_max are effectively unused.
x_padding = (x_max - x_min) * 0.05
y_padding = (y_max - y_min) * 0.05
x_min -= x_padding
x_max += x_padding
y_min = 0.68  # Set y_min to fit the models considered
y_max = 0.78

# Only plot for specific models
# One subplot per entry, in a single row.
models_to_plot = ['Gemma 27B', 'Mixtral8x7B', 'Llama 70B']
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for idx, last_model in enumerate(models_to_plot):
    ax = axes[idx]
    
    # Filter results for this last model
    filtered_results = [r for r in aggregated_results if r['last_model'] == last_model]
    
    if not filtered_results:
        ax.set_title(f'Last model: {last_model} (No data)')
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        continue
        
    rcl_x_mean = [result['rcl_mean'][0] for result in filtered_results]
    rcl_y_mean = [result['rcl_mean'][1] for result in filtered_results]
    rcl_x_std = [result['rcl_std'][0] for result in filtered_results]
    rcl_y_std = [result['rcl_std'][1] for result in filtered_results]

    rgl_rcl_x_mean = [result['rgl_rcl_mean'][0] for result in filtered_results]
    rgl_rcl_y_mean = [result['rgl_rcl_mean'][1] for result in filtered_results]
    rgl_rcl_x_std = [result['rgl_rcl_std'][0] for result in filtered_results]
    rgl_rcl_y_std = [result['rgl_rcl_std'][1] for result in filtered_results]

    knn_x_mean = [result['knn_mean'][0] for result in filtered_results]
    knn_y_mean = [result['knn_mean'][1] for result in filtered_results]
    knn_x_std = [result['knn_std'][0] for result in filtered_results]
    knn_y_std = [result['knn_std'][1] for result in filtered_results]
    
    marker = marker_map[last_model]
    
    ax.errorbar(rcl_x_mean, rcl_y_mean, xerr=rcl_x_std, yerr=rcl_y_std, 
                fmt=marker, alpha=0.7, color='blue', capsize=3, markersize=8, label='Cascades with $\widehat{R}_{CL}$')
    ax.errorbar(rgl_rcl_x_mean, rgl_rcl_y_mean, xerr=rgl_rcl_x_std, yerr=rgl_rcl_y_std,
                fmt=marker, alpha=0.7, color='red', capsize=3, markersize=8, label='Cascades with $\widehat{R}_{GL}+\widehat{R}_{CL}$')
    ax.errorbar(knn_x_mean, knn_y_mean, xerr=knn_x_std, yerr=knn_y_std,
                fmt=marker, alpha=0.7, color='green', capsize=3, markersize=8, label='Routings with KNN')
    
    # Add gold star for the biggest model (last_model)
    # Calculate mean accuracy and cost for the biggest model across seeds
    biggest_model_accuracies = []
    for seed in seeds:
        y_test = all_seeds_results[seed][last_model]['y_test']
        predictions = all_seeds_results[seed][last_model]['S_test']
        accuracy = accuracy_score(y_test, (predictions >= 0.5).astype(int))
        biggest_model_accuracies.append(accuracy)
    
    biggest_model_cost = costs_per_model[last_model]
    biggest_model_accuracy_mean = np.mean(biggest_model_accuracies)
    biggest_model_accuracy_std = np.std(biggest_model_accuracies)
    
    # ax.errorbar(biggest_model_cost, biggest_model_accuracy_mean, yerr=biggest_model_accuracy_std,
    #             fmt='*', color='gold', markersize=15, capsize=3, label=f'{last_model}', zorder=5)
    # Add black border around the star
    ax.scatter(biggest_model_cost, biggest_model_accuracy_mean, 
               s=300, marker='*', color='gold',label=f'{last_model}', zorder=4, edgecolors='black', linewidths=1)
    # Draw dotted lines connecting each triplet of points
    for i in range(len(filtered_results)):
        # Connect the three points for each cascade with dotted lines
        x_coords = [rcl_x_mean[i], rgl_rcl_x_mean[i], knn_x_mean[i]]
        y_coords = [rcl_y_mean[i], rgl_rcl_y_mean[i], knn_y_mean[i]]
        
        # Draw lines between all pairs in the triplet
        ax.plot([x_coords[0], x_coords[1]], [y_coords[0], y_coords[1]], 
                'k--', alpha=0.3, linewidth=0.5)
        ax.plot([x_coords[1], x_coords[2]], [y_coords[1], y_coords[2]], 
                'k--', alpha=0.3, linewidth=0.5)
        # ax.plot([x_coords[0], x_coords[2]], [y_coords[0], y_coords[2]], 
        #         'k--', alpha=0.3, linewidth=0.5)
    
    ax.set_xlabel('Cost')
    ax.set_ylabel('Accuracy')
    ax.set_title(f'Last model: {last_model}')
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.grid(True, alpha=0.6)
    ax.legend(loc ="lower left")

plt.tight_layout()
plt.suptitle('Cascade Performance by Last Model: Cost vs Accuracy. Fitting on embeddings', y=1.02)
plt.show()
# %%
# Build one row per (seed, cascade) comparing the R_GL+R_CL cascade against each
# alternative: accuracy differences and cost ratios.
differences_data = []

for seed in seeds:
    for i, cascade in enumerate(all_cascades):
        result = all_seeds_cascade_results[seed][i]

        # In every '*_point' entry, index 0 is used as the cost and index 1 as the accuracy
        ours_cost = result['rgl_rcl_point'][0]
        ours_accuracy = result['rgl_rcl_point'][1]

        row = {'seed': seed, 'cascade': str(cascade)}

        # Accuracy of ours minus accuracy of each alternative
        row['diff_rgl_rcl_vs_rcl'] = ours_accuracy - result['rcl_point'][1]
        row['diff_rgl_rcl_vs_knn'] = ours_accuracy - result['knn_point'][1]
        row['diff_rgl_rcl_vs_final'] = ours_accuracy - result['last_model_accuracy']
        row['diff_rgl_rcl_vs_confidence'] = ours_accuracy - result['confidence_point'][1]

        # Cost of ours divided by the cost of each alternative
        row['rel_cost_rgl_rcl'] = ours_cost / result['last_model_cost']
        row['rel_cost_rgl_rcl_knn'] = ours_cost / result['knn_point'][0]
        row['rel_cost_rgl_rcl_rcl'] = ours_cost / result['rcl_point'][0]
        row['rel_cost_rgl_rcl_confidence'] = ours_cost / result['confidence_point'][0]

        differences_data.append(row)

# Tabular form for grouping and plotting
df_differences = pd.DataFrame(differences_data)

# Create the figure with three side-by-side panels
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Ordinary least-squares fit of the accuracy gain (ours vs R_CL only) on the relative cost,
# using every (seed, cascade) row.
X = df_differences['rel_cost_rgl_rcl_rcl'].values.reshape(-1, 1)
y = df_differences['diff_rgl_rcl_vs_rcl'].values
reg = LinearRegression().fit(X, y)

# Fitted parameters and goodness of fit
slope = reg.coef_[0]
intercept = reg.intercept_
r_squared = reg.score(X, y)

# Mean and std of each accuracy difference, grouped by the matching relative-cost value.
# NOTE(review): the grouping key is a float ratio, so rows only share a group when their
# ratios are exactly equal; single-row groups get a NaN std. Confirm this is intended
# (as opposed to grouping by cascade).
rcl_rcl_groups = df_differences.groupby('rel_cost_rgl_rcl_rcl')['diff_rgl_rcl_vs_rcl']
rcl_rcl_means = rcl_rcl_groups.mean()
rcl_rcl_stds = rcl_rcl_groups.std()
rcl_rcl_costs = rcl_rcl_means.index

knn_groups = df_differences.groupby('rel_cost_rgl_rcl_knn')['diff_rgl_rcl_vs_knn']
knn_means = knn_groups.mean()
knn_stds = knn_groups.std()
knn_costs = knn_means.index

final_groups = df_differences.groupby('rel_cost_rgl_rcl')['diff_rgl_rcl_vs_final']
final_means = final_groups.mean()
final_stds = final_groups.std()
final_costs = final_means.index

confidence_groups = df_differences.groupby('rel_cost_rgl_rcl_confidence')['diff_rgl_rcl_vs_confidence']
confidence_means = confidence_groups.mean()
confidence_stds = confidence_groups.std()
confidence_costs = confidence_means.index

# Nothing is drawn on axes[0] here: that axis is removed further down and replaced by
# the broken-axis panel, so any artist added to it would never be shown.

# Through-the-origin linear fit on the group means (ours vs R_CL only)
X_means = rcl_rcl_costs.values.reshape(-1, 1)
y_means = rcl_rcl_means.values
reg_means = LinearRegression(fit_intercept=False).fit(X_means, y_means)
slope_means = reg_means.coef_[0]
intercept_means = 0  # Intercept is forced to 0 by fit_intercept=False
r_squared_means = reg_means.score(X_means, y_means)

# Evenly spaced x grid covering the range of relative costs vs the final model
x_range = np.linspace(final_costs.min(), final_costs.max(), 100)
# Drop the first grid axis; a two-part broken-axis panel is built in its place
axes[0].remove()

# Cascade to highlight, spelled the way the 'cascade' column stores it (str(cascade))
target_cascade = "['Llama 1B', 'Llama 3B', 'Llama 8B', 'Llama 70B']"

# Rows of the target cascade; the same selection is kept under one name per panel
is_target = df_differences['cascade'] == target_cascade
target_data_final = df_differences[is_target]
target_data_rcl = df_differences[is_target]
target_data_knn = df_differences[is_target]

# The two axes forming the broken-axis panel (moved to their final place at the end)
ax1_top = plt.subplot(1, 8, 1)
ax1_bottom = plt.subplot(1, 8, 2)

# y-range hidden by the break: the bottom axis stops at break_low, the top starts at break_high
break_low = 0.08
break_high = 0.23

# Same grouped means/stds (ours vs final model) on both halves; each half shows its own y-range
for panel in (ax1_top, ax1_bottom):
    panel.errorbar(final_costs, final_means, yerr=final_stds,
                   fmt='+', alpha=0.5, capsize=3, markersize=7, color='tab:gray')

# Red marker for the target cascade (mean over its rows, std as error bar)
highlight_style = dict(fmt='o', color='red', capsize=5, markersize=7,
                       markeredgecolor='darkred', markeredgewidth=2)
if not target_data_final.empty:
    target_cost_final = target_data_final['rel_cost_rgl_rcl'].mean()
    target_diff_final = target_data_final['diff_rgl_rcl_vs_final'].mean()
    target_std_final = target_data_final['diff_rgl_rcl_vs_final'].std()

    if target_diff_final > break_high:
        ax1_top.errorbar(target_cost_final, target_diff_final, yerr=target_std_final,
                         label='Target Cascade', **highlight_style)
    elif target_diff_final < break_low:
        ax1_bottom.errorbar(target_cost_final, target_diff_final, yerr=target_std_final,
                            label='Target Cascade', **highlight_style)
    else:
        # The mean falls inside the break: draw on both halves, label only the top one
        ax1_top.errorbar(target_cost_final, target_diff_final, yerr=target_std_final,
                         label='Target Cascade', **highlight_style)
        ax1_bottom.errorbar(target_cost_final, target_diff_final, yerr=target_std_final,
                            **highlight_style)

# Top half: upper y-range, no bottom spine, ticks on top without labels
ax1_top.set_ylim(break_high, 0.25)
ax1_top.spines['bottom'].set_visible(False)
ax1_top.xaxis.tick_top()
ax1_top.tick_params(labeltop=False, labelsize=14)
ax1_top.tick_params(labelsize=14)

# Bottom half: lower y-range, no top spine, ticks at the bottom
ax1_bottom.set_ylim(-0.02, break_low)
ax1_bottom.spines['top'].set_visible(False)
ax1_bottom.xaxis.tick_bottom()
ax1_bottom.tick_params(labelsize=14)

# Dashed vertical line at x=1, then the 2D density of all (seed, cascade) rows
for panel in (ax1_top, ax1_bottom):
    panel.axvline(x=1, color='red', linestyle='--', alpha=0.5)
    sns.kdeplot(data=df_differences, x='rel_cost_rgl_rcl', y='diff_rgl_rcl_vs_final',
                ax=panel, fill=True, alpha=0.6, color='tab:orange')

# Slanted tick marks at the facing edges of the two halves to signal the axis break
slant = .5  # proportion of vertical to horizontal extent of the slanted line
break_marker = dict(marker=[(-1, -slant), (1, slant)], markersize=12,
                    linestyle="none", color='k', mec='k', mew=1, clip_on=False)
ax1_top.plot([0, 1], [0, 0], transform=ax1_top.transAxes, **break_marker)
ax1_bottom.plot([0, 1], [1, 1], transform=ax1_bottom.transAxes, **break_marker)

# Dashed horizontal line at y=0 and a light grid on both halves
for panel in (ax1_top, ax1_bottom):
    panel.axhline(y=0, color='red', linestyle='--', alpha=0.4)
    panel.grid(True, alpha=0.5)

# Title on the top half; axis labels on the bottom half only
ax1_top.set_title('Ours vs biggest model', fontsize=19)
ax1_top.set_ylabel('')
ax1_bottom.set_xlabel('Fraction of cost', fontsize=17)
ax1_bottom.set_ylabel('Gain in accuracy', y=0.6, fontsize=17)
ax1_bottom.set_title('')

# Stack the two halves: the bottom one below, the thin top one directly above it
ax1_bottom.set_position([0.05, 0.11, 0.25, 0.61])
ax1_top.set_position([0.05, 0.76, 0.25, 0.12])

# Plot 2 (middle panel): accuracy gain of R_GL+R_CL over R_CL alone vs. relative cost,
# drawn from the grouped means with std error bars.
axes[1].errorbar(rcl_rcl_costs, rcl_rcl_means, yerr=rcl_rcl_stds,
                fmt='+', alpha=0.6, capsize=3, markersize=7, color='tab:gray')

# Highlight the target cascade (mean over its rows, std as error bar)
if len(target_data_rcl) > 0:
    target_cost_rcl = target_data_rcl['rel_cost_rgl_rcl_rcl'].mean()
    target_diff_rcl = target_data_rcl['diff_rgl_rcl_vs_rcl'].mean()
    target_std_rcl = target_data_rcl['diff_rgl_rcl_vs_rcl'].std()
    axes[1].errorbar(target_cost_rcl, target_diff_rcl, yerr=target_std_rcl,
                    fmt='o', color='red', capsize=5, markersize=7, markeredgecolor='darkred',
                    markeredgewidth=2, label='Target Cascade')

axes[1].set_xlabel('Fraction of cost', fontsize=17)
axes[1].set_ylabel(' ', fontsize=11)
# Raw string: \widehat and \mathcal are not valid Python escapes in a plain literal.
# "only" sits outside the $...$ so it is typeset as normal text rather than as math symbols.
axes[1].set_title(r'Ours vs $\widehat{\mathcal{R}}^{CL}$ only', fontsize=19)
axes[1].set_ylim(-0.002, 0.035)
# Only the left x-limit is fixed; the right one is left to autoscaling
axes[1].set_xlim(0.975, None)
axes[1].grid(True, alpha=0.5)
axes[1].axhline(y=0, color='red', linestyle='--', alpha=0.5)
axes[1].tick_params(labelsize=14)
axes[1].set_position([0.35, 0.11, 0.25, 0.77])

# Dashed vertical line at x=1
axes[1].axvline(x=1, color='red', linestyle='--', alpha=0.5)

# 2D density of all (seed, cascade) rows
sns.kdeplot(data=df_differences, x='rel_cost_rgl_rcl_rcl', y='diff_rgl_rcl_vs_rcl',
            ax=axes[1], fill=True, alpha=0.5, color='tab:green')

# Plot 3 (right panel): accuracy gain of ours over the KNN router vs. relative cost,
# drawn from the grouped means with std error bars.
knn_panel = axes[2]
knn_panel.errorbar(knn_costs, knn_means, yerr=knn_stds,
                   fmt='+', alpha=0.6, capsize=3, markersize=7, color='tab:gray')

# Red marker for the target cascade (mean over its rows, std as error bar)
if not target_data_knn.empty:
    target_cost_knn = target_data_knn['rel_cost_rgl_rcl_knn'].mean()
    target_diff_knn = target_data_knn['diff_rgl_rcl_vs_knn'].mean()
    target_std_knn = target_data_knn['diff_rgl_rcl_vs_knn'].std()
    knn_highlight = dict(fmt='o', color='red', capsize=5, markersize=7,
                         markeredgecolor='darkred', markeredgewidth=2,
                         label='Target Cascade')
    knn_panel.errorbar(target_cost_knn, target_diff_knn, yerr=target_std_knn,
                       **knn_highlight)

# Dashed vertical line at x=1 (cost parity)
knn_panel.axvline(x=1, color='red', linestyle='--', alpha=0.5)

# Fixed viewing window, then the dashed horizontal line at y=0
knn_panel.set_xlim(0.661, 1.69)
knn_panel.set_ylim(-0.002, 0.035)
knn_panel.axhline(y=0, color='red', linestyle='--', alpha=0.5)

# Labels, grid, tick size and final placement of the panel
knn_panel.set_title('Ours vs predictive router', fontsize=19)
knn_panel.set_xlabel('Fraction of cost', fontsize=17)
knn_panel.set_ylabel(' ', fontsize=11)
knn_panel.grid(True, alpha=0.5)
knn_panel.tick_params(labelsize=14)
knn_panel.set_position([0.65, 0.11, 0.25, 0.77])

# 2D density of all (seed, cascade) rows
sns.kdeplot(data=df_differences, x='rel_cost_rgl_rcl_knn', y='diff_rgl_rcl_vs_knn',
            ax=knn_panel, fill=True, alpha=0.5, color='tab:blue')

# Write the finished three-panel figure to disk, then display it
plt.savefig('accuracy_improvement_vs_cost_deferral_final.pdf', bbox_inches='tight')
plt.show()
# plt.subplots_adjust(wspace=0.15, hspace=0.1)
# plt.suptitle('Distribution of Accuracy Differences vs Relative Costs', y=1.02, fontsize=16)
# plt.savefig('accuracy_improvement_vs_cost_deferral.pdf', bbox_inches='tight')
# plt.show()
# # Print some summary statistics
# print("Summary Statistics:")
# print(f"RCL+RGL vs RCL accuracy difference: mean={df_differences['diff_rgl_rcl_vs_rcl'].mean():.4f}, std={df_differences['diff_rgl_rcl_vs_rcl'].std():.4f}")
# print(f"RCL+RGL vs KNN accuracy difference: mean={df_differences['diff_rgl_rcl_vs_knn'].mean():.4f}, std={df_differences['diff_rgl_rcl_vs_knn'].std():.4f}")
# print(f"RCL+RGL vs Final Model accuracy difference: mean={df_differences['diff_rgl_rcl_vs_final'].mean(25):.4f}, std={df_differences['diff_rgl_rcl_vs_final'].std():.4f}")
# print(f"Relative cost RCL+RGL: mean={df_differences['rel_cost_rgl_rcl'].mean():.4f}, std={df_differences['rel_cost_rgl_rcl'].std():.4f}")
# %%

# Share of (seed, cascade) rows where ours beats the KNN comparison on accuracy,
# where it is cheaper (cost ratio below 1), and where both hold at once.
positive_diff = df_differences['diff_rgl_rcl_vs_knn'] > 0
cost_less_than_1 = df_differences['rel_cost_rgl_rcl_knn'] < 1
positive_and_cheaper = positive_diff & cost_less_than_1

# Row counts behind each boolean mask
total_points = len(df_differences)
positive_diff_count, cheaper_count, positive_and_cheaper_count = (
    mask.sum() for mask in (positive_diff, cost_less_than_1, positive_and_cheaper)
)

# Report each count with its fraction of all rows
print("Proportions analysis:")
print(f"Total points: {total_points}")
print(f"Points with positive diff_rgl_rcl_vs_knn: {positive_diff_count} ({positive_diff_count/total_points:.3f})")
print(f"Points with rel_cost_rgl_rcl_knn < 1: {cheaper_count} ({cheaper_count/total_points:.3f})")
print(f"Points with positive diff AND cost < 1: {positive_and_cheaper_count} ({positive_and_cheaper_count/total_points:.3f})")
# %%
