#%%
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from folktexts.acs import ACSDataset
from sklearn.tree import DecisionTreeRegressor
import glest
from folktexts.acs import ACSTaskMetadata
from deferral_experiment.regret_helpers import compute_regret_CL, get_constant_utilty, get_threshold_from_utility
from utils import honest_tree_pred
from sklearn.metrics import roc_auc_score, accuracy_score, brier_score_loss
import seaborn as sns
import matplotlib.pyplot as plt
# #%%

# llama3 = pd.read_csv('folktexts-results/folktexts-results/model-Llama-3.2-3B-Instruct_task-ACSIncome/Llama-3.2-3B-Instruct_bench-3406182338/ACSIncome_full_seed-42_hash-1998608642.test_predictions.csv')
# llama8 = pd.read_csv('folktexts-results/folktexts-results/model-Llama-3.1-8B-Instruct_task-ACSIncome/Llama-3.1-8B-Instruct_bench-2720683681/ACSIncome_full_seed-42_hash-1998608642.test_predictions.csv')
# llama1 = pd.read_csv('folktexts-results/folktexts-results/model-Llama-3.2-1B-Instruct_task-ACSIncome/Llama-3.2-1B-Instruct_bench-75857734/ACSIncome_full_seed-42_hash-1998608642.test_predictions.csv')
# llama70 = pd.read_csv('folktexts-results/folktexts-results/model-Meta-Llama-3-70B-Instruct_task-ACSIncome/Meta-Llama-3-70B-Instruct_bench-1000148012/ACSIncome_full_seed-42_hash-1998608642.test_predictions.csv')
# gemma27 = pd.read_csv('folktexts-results/folktexts-results/model-gemma-2-27b-it_task-ACSIncome/gemma-2-27b-it_bench-758589169/ACSIncome_full_seed-42_hash-1998608642.test_predictions.csv')
# mixtral8x7b = pd.read_csv('folktexts-results/folktexts-results/model-Mixtral-8x7B-Instruct-v0.1_task-ACSIncome/Mixtral-8x7B-Instruct-v0.1_bench-3380890101/ACSIncome_full_seed-42_hash-1998608642.test_predictions.csv')
# phi4 = pd.read_csv('folktexts-results/folktexts-results/model-phi-4_task-ACSIncome/phi-4_bench-1406674137/ACSIncome_full_seed-42_hash-1998608642.test_predictions.csv')
# #%%

# DATA_DIR = "notebooks/data"
# TASK_NAME = "ACSIncome"
# task = ACSTaskMetadata.get_task(TASK_NAME, use_numeric_qa=False)

# dataset = ACSDataset.make_from_task(task=task, cache_dir=DATA_DIR)
# # %%
# dataset.data
# %%


# %%
# # Process llama1
# features1 = dataset.get_features_data()
# matched_features1 = features1.loc[llama1["Unnamed: 0"]].copy()
# test1 = llama1.set_index(llama1["Unnamed: 0"].values)
# test1.drop(columns=["Unnamed: 0"], inplace=True)
# merged_df1 = pd.concat([matched_features1, test1], axis=1)
# merged_df1.to_csv('deferral_experiment/llama1_instruct.csv', index=False)

# # Process llama3
# features3 = dataset.get_features_data()
# matched_features3 = features3.loc[llama3["Unnamed: 0"]].copy()
# test3 = llama3.set_index(llama3["Unnamed: 0"].values)
# test3.drop(columns=["Unnamed: 0"], inplace=True)
# merged_df3 = pd.concat([matched_features3, test3], axis=1)
# merged_df3.to_csv('deferral_experiment/llama3_instruct.csv', index=False)

# # Process llama8
# features8 = dataset.get_features_data()
# matched_features8 = features8.loc[llama8["Unnamed: 0"]].copy()
# test8 = llama8.set_index(llama8["Unnamed: 0"].values)
# test8.drop(columns=["Unnamed: 0"], inplace=True)
# merged_df8 = pd.concat([matched_features8, test8], axis=1)
# merged_df8.to_csv('deferral_experiment/llama8_instruct.csv', index=False)

# # Process llama70
# features70 = dataset.get_features_data()
# matched_features70 = features70.loc[llama70["Unnamed: 0"]].copy()
# test70 = llama70.set_index(llama70["Unnamed: 0"].values)
# test70.drop(columns=["Unnamed: 0"], inplace=True)
# merged_df70 = pd.concat([matched_features70, test70], axis=1)
# merged_df70.to_csv('deferral_experiment/llama70_instruct.csv', index=False)

# # Process gemma27
# features_gemma27 = dataset.get_features_data()
# matched_features_gemma27 = features_gemma27.loc[gemma27["Unnamed: 0"]].copy()
# test_gemma27 = gemma27.set_index(gemma27["Unnamed: 0"].values)
# test_gemma27.drop(columns=["Unnamed: 0"], inplace=True)
# merged_df_gemma27 = pd.concat([matched_features_gemma27, test_gemma27], axis=1)
# merged_df_gemma27.to_csv('deferral_experiment/gemma27_instruct.csv', index=False)

# # Process mixtral8x7b
# features_mixtral8x7b = dataset.get_features_data()
# matched_features_mixtral8x7b = features_mixtral8x7b.loc[mixtral8x7b["Unnamed: 0"]].copy()
# test_mixtral8x7b = mixtral8x7b.set_index(mixtral8x7b["Unnamed: 0"].values)
# test_mixtral8x7b.drop(columns=["Unnamed: 0"], inplace=True)
# merged_df_mixtral8x7b = pd.concat([matched_features_mixtral8x7b, test_mixtral8x7b], axis=1)
# merged_df_mixtral8x7b.to_csv('deferral_experiment/mixtral8x7b_instruct.csv', index=False)

# # Process phi4
# features_phi4 = dataset.get_features_data()
# matched_features_phi4 = features_phi4.loc[phi4["Unnamed: 0"]].copy()
# test_phi4 = phi4.set_index(phi4["Unnamed: 0"].values)
# test_phi4.drop(columns=["Unnamed: 0"], inplace=True)
# merged_df_phi4 = pd.concat([matched_features_phi4, test_phi4], axis=1)
# merged_df_phi4.to_csv('deferral_experiment/phi4_instruct.csv', index=False)

# #%%

# # Save the raw dataframes with "non instruct" in filenames for later use
# llama1.to_csv('deferral_experiment/llama1_instruct.csv', index=False)
# llama3.to_csv('deferral_experiment/llama3_instruct.csv', index=False)
# llama8.to_csv('deferral_experiment/llama8_instruct.csv', index=False)
# llama70.to_csv('deferral_experiment/llama70_instruct.csv', index=False)

#%%


# Load the per-model prediction tables from deferral_experiment/.
# Each name on the left is bound to the DataFrame read from
# 'deferral_experiment/<stem>_instruct.csv', in the same order as the stems.
(
    llama1,
    llama3,
    llama8,
    llama70,
    phi4,
    gemma27,
    mixtral8x7b,
) = (
    pd.read_csv(f'deferral_experiment/{stem}_instruct.csv')
    for stem in ('llama1', 'llama3', 'llama8', 'llama70', 'phi4', 'gemma27', 'mixtral8x7b')
)


# Cost charged for one call to each model (keys match the `models` dict used
# by the processing loop).
# NOTE(review): the unit is not stated in this file; later prints format these as dollars.
costs_per_model = {
    'llama1': 0.04,
    'llama3': 0.06,
    'llama8': 0.18,
    'llama70': 0.88,
    'gemma27' : 0.25,
    'Mixtral8x7B': 0.70,
    'phi4' : 0.22
}

# Fraction of every prediction table that the processing loop holds out for
# evaluation (the `test_size=0.1` of its first train_test_split). Keep in sync.
EVAL_FRACTION = 0.1

# Total cost of sending every held-out sample to a single model.
# The previous version mixed `// 2` and `// 10` divisors across models, so
# `costs_of_all_models[name] / len(y_test)` did not give back the per-call cost
# for the models divided by 2. All models now use the same held-out count;
# train_test_split rounds a float test_size up, hence the ceil.
costs_of_all_models = {
    name: costs_per_model[name] * int(np.ceil(EVAL_FRACTION * len(df)))
    for name, df in (
        ('llama1', llama1),
        ('llama3', llama3),
        ('llama8', llama8),
        ('llama70', llama70),
        ('gemma27', gemma27),
        ('Mixtral8x7B', mixtral8x7b),
        ('phi4', phi4),
    )
}

#%%

# Process all models
# For every model: recalibrate its risk score, fit a residual tree on the
# features, and compute the two regret terms (RCL, RGL) on a held-out split.
# Everything needed by the later cells is stored in `results[model_name]`.
models = {'llama1' : llama1, 'llama3': llama3, 'llama8': llama8, 'llama70': llama70, 'gemma27': gemma27, 'Mixtral8x7B': mixtral8x7b, 'phi4': phi4}
results = {}

for model_name, model_df in models.items():
    print(f"Processing {model_name}...")

    # Features are every column except the model's score and the label.
    X = model_df.drop(columns=['risk_score', 'label']).values
    y = model_df['label'].values
    S = model_df['risk_score'].values

    # NOTE(review): `t` computed here is overwritten by `t = 0.5` further down and
    # `U` is not read again in this loop, so these three lines look like leftovers
    # from a multi-threshold experiment — confirm before removing.
    t_target = [0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975, 0.99]
    U = get_constant_utilty(100, t_target)  # (n_utilities, 2, 2)
    t = get_threshold_from_utility(U) 

    calibrated_classifier = LogisticRegression()
    # Split 1: hold out 10% ("leftover") — this is the evaluation set stored in `results`.
    X, X_leftover, y, y_leftover, S, S_leftover = train_test_split(
        X, y, S, test_size=0.1, random_state=0
    )
    # Split 2: halve the remaining 90% into a train part and a test part.
    X_train, X_test, y_train, y_test, S_train, S_test = train_test_split(
        X, y, S, test_size=0.5, random_state=0
    )

    # Split 3: carve a calibration set out of the train part
    # (20% of it, but never fewer than 4000 samples).
    X_train, X_cal, y_train, y_cal, S_train, S_cal = train_test_split(
        X_train, y_train, S_train, test_size=max(int(len(X_train) * 0.2),4000), random_state=0
    )

    # Recalibrate the raw score: logistic regression from score to label, fit on the calibration set.
    calibrated_classifier.fit(S_cal.reshape(-1,1), y_cal)

    c_hat_train = calibrated_classifier.predict_proba(S_train.reshape(-1,1))[:, 1]
    c_hat_test = calibrated_classifier.predict_proba(S_test.reshape(-1,1))[:, 1]

    # Residuals of the calibrated score; the tree below is fit to predict them from the features.
    residuals_train = y_train - c_hat_train
    residuals_test = y_test - c_hat_test  # NOTE(review): not used in the visible code
    dt = DecisionTreeRegressor(max_depth = None, min_samples_leaf= 15)
    dt.fit(X_train, residuals_train)
    # Leaf id of every test sample, passed to glest as the partition.
    leaf_ids = dt.apply(X_test)

    # NOTE(review): presumably estimates per-leaf residuals on the test part, which the
    # tree did not see during fitting — confirm against the glest documentation.
    gle = glest.core.GLEstimatorResiduals(None, None)
    gle.fit(X_test, y_test, y_scores_cal = c_hat_test, partition = leaf_ids)


    # Calibrated score and tree-based residual estimate on the leftover split.
    c_hat_leftover = calibrated_classifier.predict_proba(S_leftover.reshape(-1,1))[:, 1]
    r_hat_leftover = honest_tree_pred(dt, gle.honest_rj, X_leftover)

    # Single decision threshold; replaces the `t` derived from `U` above.
    t = 0.5
    # RCL: actions come from thresholding the raw score, evaluated against the calibrated score.
    a = (S_leftover[:, None] >= t).astype(int)
    RCL = compute_regret_CL(c_hat_leftover, t, a)  # (n, k)
    # RGL: actions come from thresholding the calibrated score, evaluated against
    # the calibrated score plus the estimated residual.
    a = (c_hat_leftover[:, None] >= t).astype(int)  # (n, k)
    RGL = compute_regret_CL(c_hat_leftover + r_hat_leftover, t, a)  # (n, k)

    # Store results
    # Note: the '*_test' keys hold the leftover split, not X_test/y_test/S_test from split 2.
    results[model_name] = {
        'X_test': X_leftover,
        'y_test': y_leftover, 
        'S_test': S_leftover,
        'c_hat_test': c_hat_leftover,
        'r_hat': r_hat_leftover,
        'RCL': RCL,
        'RGL': RGL,
        'gle': gle,
        'tree': dt,
    }

# For compatibility with existing code, set variables for llama1
# NOTE(review): the cells below score every model against llama1's labels, which assumes
# all prediction tables contain the same rows in the same order — confirm.
# X_test = results['llama1']['X_test']
y_test = results['llama1']['y_test']
# S_test = results['llama1']['S_test']
# RGL = results['llama1']['RGL']
S_llama70_test = results['llama70']['S_test']

#%%


# ========================================================
# Cascade model selection based on RGL + RCL
# =======================================================

# For each sample, choose the prediction from the model with the smallest total regret (RGL + RCL)
# If several models tie on total regret, choose the prediction from the smallest of them
# Stack the total regret and the raw risk scores of all models, one column per model

model_names = ['llama1', 'llama3', 'llama8', 'llama70', 'gemma27', 'Mixtral8x7B', 'phi4']
model_sizes = [1, 3, 8, 70, 27, 56, 14]  # Model sizes for tie-breaking
# One column per model: total regret (RGL + RCL) and raw risk score on the held-out split.
all_rgl = np.column_stack([
    results[name]['RGL'].flatten() + results[name]['RCL'].flatten()
    for name in model_names
])
all_predictions = np.column_stack([results[name]['S_test'] for name in model_names])


#%%
# Column of the first minimum total regret in every row.
min_rgl_indices = np.argmin(all_rgl, axis=1)

# Tie-break: when several models share the row minimum, switch to the one with
# the smallest size. Rows without a tie keep the argmin column.
sample_rows = np.arange(len(all_rgl))
at_row_min = all_rgl == all_rgl[sample_rows, min_rgl_indices][:, None]
sizes_of_tied = np.where(at_row_min, np.asarray(model_sizes, dtype=float), np.inf)
min_rgl_indices = np.where(
    at_row_min.sum(axis=1) > 1,
    sizes_of_tied.argmin(axis=1),
    min_rgl_indices,
)

# Per-sample prediction taken from the selected model.
optimal_predictions = all_predictions[np.arange(len(all_predictions)), min_rgl_indices]

# How often each model was selected, as counts and as a share of all samples.
model_calls = np.bincount(min_rgl_indices, minlength=len(model_names))
call_proportions = model_calls / len(min_rgl_indices)

print("Casacade proportion of calls to each model:")
for name, share, count in zip(model_names, call_proportions, model_calls):
    print(f"{name}: {share:.4f} ({count} samples)")

# Score the mixed predictions: ROC AUC and Brier on the scores, accuracy at a 0.5 cut.
auc_optimal = roc_auc_score(y_test, optimal_predictions)
acc_optimal = accuracy_score(y_test, (optimal_predictions >= 0.5).astype(int))
brier_optimal = brier_score_loss(y_test, optimal_predictions)

print("\nCascade Mixed Model Performance:")
print(f"AUC: {auc_optimal:.4f}")
print(f"Accuracy: {acc_optimal:.4f}")
print(f"Brier Score: {brier_optimal:.4f}")

print("\n===================================")
# Oracle baseline: per sample, the smallest model whose 0.5-thresholded
# prediction matches the label; column 0 when no model is correct.
hard_votes = (all_predictions >= 0.5).astype(int)
is_correct = hard_votes == np.asarray(y_test)[:, None]
# Incorrect models get an infinite size so argmin skips them; a row with no
# correct model is all-inf and argmin returns column 0.
oracle_indices = np.where(
    is_correct, np.asarray(model_sizes, dtype=float), np.inf
).argmin(axis=1)

# Per-sample prediction taken from the oracle's choice.
oracle_predictions = all_predictions[np.arange(len(all_predictions)), oracle_indices]

# Oracle metrics.
auc_oracle = roc_auc_score(y_test, oracle_predictions)
acc_oracle = accuracy_score(y_test, (oracle_predictions >= 0.5).astype(int))
brier_oracle = brier_score_loss(y_test, oracle_predictions)

# How often the oracle picks each model.
oracle_calls = np.bincount(oracle_indices, minlength=len(model_names))
oracle_proportions = oracle_calls / len(oracle_indices)

print("\nOracle Baseline Performance:")
print(f"AUC: {auc_oracle:.4f}")
print(f"Accuracy: {acc_oracle:.4f}")
print(f"Brier Score: {brier_oracle:.4f}")

print("\nOracle proportion of calls to each model:")
for name, share, count in zip(model_names, oracle_proportions, oracle_calls):
    print(f"{name}: {share:.4f} ({count} samples)")
print("\n===================================")
# Reference numbers: every model used on its own for all samples.
print("\nIndividual Model Performance:")
for name in model_names:
    model_predictions = results[name]['S_test']
    auc = roc_auc_score(y_test, model_predictions)
    acc = accuracy_score(y_test, (model_predictions >= 0.5).astype(int))
    brier = brier_score_loss(y_test, model_predictions)
    print(f"{name} - AUC: {auc:.4f}, Accuracy: {acc:.4f}, Brier Score: {brier:.4f}")
#%%

# ========================================================
# 50% llama1 + 50% llama70 based on highest regret
# =======================================================


S_test = results['llama1']['S_test']
y_test = results['llama1']['y_test']
llama70_predictions = results['llama70']['S_test']
RGL = results['llama1']['RGL']
RCL = results['llama1']['RCL']


# Share of samples whose llama1 prediction is replaced by llama70's.
# The comments and printed messages used to say "top 10%" while the code
# replaced 50%; the logic and the messages are now driven by this one value.
deferral_fraction = 0.5

# Rank samples by llama1's total regret (RGL + RCL).
total_regret_llama1 = RGL.flatten() + RCL.flatten()
n_samples = len(total_regret_llama1)
n_deferred = int(deferral_fraction * n_samples)
top_10_percent_count = n_deferred  # legacy name, kept for any later cell that reads it

# Indices of the n_deferred highest-regret samples. Slicing from
# `n_samples - n_deferred` (rather than `-n_deferred`) yields an empty selection
# when n_deferred is 0 instead of selecting every sample.
top_regret_indices = np.argsort(total_regret_llama1)[n_samples - n_deferred:]

# Replace their values with predictions from llama70
S_test_modified = S_test.copy()
S_test_modified[top_regret_indices] = llama70_predictions[top_regret_indices]

print(f"Selected {len(top_regret_indices)} samples with highest regret (top {deferral_fraction:.0%})")
if len(top_regret_indices) > 0:
    # min()/max() raise on an empty selection, so only report the range when there is one.
    print(f"Regret range: {total_regret_llama1[top_regret_indices].min():.4f} to {total_regret_llama1[top_regret_indices].max():.4f}")
print(f"Replaced {len(top_regret_indices)} predictions with llama70 values")
# Calculate performance metrics for the modified predictions
auc_modified = roc_auc_score(y_test, S_test_modified)
acc_modified = accuracy_score(y_test, (S_test_modified >= 0.5).astype(int))
brier_modified = brier_score_loss(y_test, S_test_modified)

print(f"\nModified Predictions Performance ({1 - deferral_fraction:.0%} llama1 + {deferral_fraction:.0%} llama70):")
print(f"AUC: {auc_modified:.4f}")
print(f"Accuracy: {acc_modified:.4f}")
print(f"Brier Score: {brier_modified:.4f}")

# Compare with original llama1 performance
auc_original = roc_auc_score(y_test, S_test)
acc_original = accuracy_score(y_test, (S_test >= 0.5).astype(int))
brier_original = brier_score_loss(y_test, S_test)

print("\nOriginal Llama1 Performance:")
print(f"AUC: {auc_original:.4f}")
print(f"Accuracy: {acc_original:.4f}")
print(f"Brier Score: {brier_original:.4f}")

# Differences are modified minus original (a negative Brier difference is an improvement).
print("\nImprovement:")
print(f"AUC: {auc_modified - auc_original:+.4f}")
print(f"Accuracy: {acc_modified - acc_original:+.4f}")
print(f"Brier Score: {brier_modified - brier_original:+.4f}")
# %%


# ========================================================
# Real Cascade (no oracle)
# =======================================================


# Real Cascade: Accept model response if regret is below threshold
# with cost constraint and no deferral once cost budget is exceeded
# Single-budget cascade: walk the model ladder from cheapest to priciest and
# accept the first model whose total regret (RGL + RCL) is under `threshold`
# and whose call still fits in `cost_budget`. Otherwise fall back to llama1.
threshold = 0.02
cost_budget = 20000

cascade_predictions = np.full(len(y_test), np.nan)
cascade_model_used = np.full(len(y_test), -1)
total_cost = 0
budget_exceeded_at = None

# Ladder actually used, smallest model first.
ordered_models = ['llama1', 'llama8']


#%%
for i in range(len(y_test)):
    accepted_rank = None

    # Only try the ladder while the budget has not been used up.
    if total_cost < cost_budget:
        for j, model_name in enumerate(ordered_models):
            # Skip models whose call would push the spend over the budget.
            if total_cost + costs_per_model[model_name] <= cost_budget:
                total_regret = results[model_name]['RGL'][i] + results[model_name]['RCL'][i]
                if total_regret < threshold:
                    accepted_rank = j
                    break

    if accepted_rank is None:
        # Fallback to llama1. Its cost is still added to the running total,
        # even when that goes past the budget.
        out_of_budget = (
            total_cost >= cost_budget
            or total_cost + costs_per_model['llama1'] > cost_budget
        )
        if out_of_budget and budget_exceeded_at is None:
            # Remember the first sample at which the budget ran out.
            budget_exceeded_at = i
        accepted_rank = 0
        chosen_model = 'llama1'
    else:
        chosen_model = ordered_models[accepted_rank]

    cascade_predictions[i] = results[chosen_model]['S_test'][i]
    cascade_model_used[i] = accepted_rank
    total_cost += costs_per_model[chosen_model]

# Cascade metrics.
auc_cascade = roc_auc_score(y_test, cascade_predictions)
acc_cascade = accuracy_score(y_test, (cascade_predictions >= 0.5).astype(int))
brier_cascade = brier_score_loss(y_test, cascade_predictions)

# Share of samples answered by each rung of the ladder.
cascade_calls = np.bincount(cascade_model_used, minlength=len(ordered_models))
cascade_proportions = cascade_calls / len(cascade_model_used)

print(f"\nReal Cascade Performance (threshold={threshold}, budget={cost_budget}):")
print(f"Total cost: ${total_cost:.2f}")
print(f"Budget exceeded at sample: {budget_exceeded_at}")
print(f"AUC: {auc_cascade:.4f}")
print(f"Accuracy: {acc_cascade:.4f}")
print(f"Brier Score: {brier_cascade:.4f}")

print("\nReal Cascade proportion of calls to each model:")
for name, share, count in zip(ordered_models, cascade_proportions, cascade_calls):
    print(f"{name}: {share:.4f} ({count} samples)")


print("\nIndividual Model Performance for reference:")
for name in ordered_models:
    model_predictions = results[name]['S_test']
    auc = roc_auc_score(y_test, model_predictions)
    acc = accuracy_score(y_test, (model_predictions >= 0.5).astype(int))
    brier = brier_score_loss(y_test, model_predictions)
    cost = costs_of_all_models[name]
    print(f"{name} - AUC: {auc:.4f}, Accuracy: {acc:.4f}, Brier Score: {brier:.4f}")
    print(f"Total cost if all samples processed by {name}: ${cost:.2f}")

#%%

# Budget sweep, acceptance rule based on RCL only.
# For every budget the cascade is replayed over all samples; the resulting
# (average cost, accuracy) pairs are drawn as the blue curve.
cost_budgets = [2000, 5000, 8000, 10000,12500, 15000, 17500,  20000, 30000, 50000, 60000, 70000, 80000, 90000, 100000, ]
threshold = 0.05
smallest_model = ordered_models[0]
budget_results = []

for budget in cost_budgets:
    cascade_predictions = np.full(len(y_test), np.nan)
    cascade_model_used = np.full(len(y_test), -1)
    total_cost = 0

    for i in range(len(y_test)):
        # Rank 0 (the smallest model) is the fallback when nothing is accepted.
        chosen_rank = 0

        # Try the ladder only while the budget has not been used up.
        if total_cost < budget:
            for j, model_name in enumerate(ordered_models):
                # Skip models whose call would push the spend over the budget.
                if total_cost + costs_per_model[model_name] <= budget:
                    total_regret = results[model_name]['RCL'][i] # + results[model_name]['RGL'][i]
                    if total_regret < threshold:
                        chosen_rank = j
                        break

        # The fallback call is charged as well, even past the budget.
        chosen_model = ordered_models[chosen_rank]
        cascade_predictions[i] = results[chosen_model]['S_test'][i]
        cascade_model_used[i] = chosen_rank
        total_cost += costs_per_model[chosen_model]

    # Accuracy at a 0.5 cut and realised cost per sample for this budget.
    acc = accuracy_score(y_test, (cascade_predictions >= 0.5).astype(int))
    avg_cost = total_cost / len(y_test)

    budget_results.append({
        'budget': budget,
        'total_cost': total_cost,
        'avg_cost': avg_cost,
        'accuracy': acc
    })

# Columns of the sweep table, for plotting.
budgets = [row['budget'] for row in budget_results]
avg_costs = [row['avg_cost'] for row in budget_results]
accuracies = [row['accuracy'] for row in budget_results]

# Reference points: every ladder model used on its own for all samples.
individual_models = [
    {
        'name': name,
        'avg_cost': costs_of_all_models[name] / len(y_test),
        'accuracy': accuracy_score(y_test, (results[name]['S_test'] >= 0.5).astype(int)),
    }
    for name in ordered_models
]

# Open the figure and draw the RCL-only cascade curve.
plt.figure(figsize=(6, 6))

plt.plot(avg_costs, accuracies, 'bo-', linewidth=2, markersize=6, label='Real Cascade no RGL', alpha = 0.3)


# Budget sweep, acceptance rule based on RCL + RGL.
# Same replay as the RCL-only sweep; the result is drawn as the red curve on
# the current figure.
budget_results = []

for budget in cost_budgets:
    cascade_predictions = np.full(len(y_test), np.nan)
    cascade_model_used = np.full(len(y_test), -1)
    total_cost = 0

    for i in range(len(y_test)):
        # Rank 0 (the smallest model) is the fallback when nothing is accepted.
        chosen_rank = 0

        # Try the ladder only while the budget has not been used up.
        if total_cost < budget:
            for j, model_name in enumerate(ordered_models):
                # Skip models whose call would push the spend over the budget.
                if total_cost + costs_per_model[model_name] <= budget:
                    total_regret = results[model_name]['RCL'][i]  + results[model_name]['RGL'][i]
                    if total_regret < threshold:
                        chosen_rank = j
                        break

        # The fallback call is charged as well, even past the budget.
        chosen_model = ordered_models[chosen_rank]
        cascade_predictions[i] = results[chosen_model]['S_test'][i]
        cascade_model_used[i] = chosen_rank
        total_cost += costs_per_model[chosen_model]

    # Accuracy at a 0.5 cut and realised cost per sample for this budget.
    acc = accuracy_score(y_test, (cascade_predictions >= 0.5).astype(int))
    avg_cost = total_cost / len(y_test)

    budget_results.append({
        'budget': budget,
        'total_cost': total_cost,
        'avg_cost': avg_cost,
        'accuracy': acc
    })

# Columns of the sweep table, for plotting.
budgets = [row['budget'] for row in budget_results]
avg_costs = [row['avg_cost'] for row in budget_results]
accuracies = [row['accuracy'] for row in budget_results]

# Reference points: every ladder model used on its own for all samples.
individual_models = [
    {
        'name': name,
        'avg_cost': costs_of_all_models[name] / len(y_test),
        'accuracy': accuracy_score(y_test, (results[name]['S_test'] >= 0.5).astype(int)),
    }
    for name in ordered_models
]

# No new figure here: the curve is added to whichever figure is current.
plt.plot(avg_costs, accuracies, 'ro-', linewidth=2, markersize=6, label='Real Cascade RGL included', alpha = 0.3)


# Oracle baseline: per sample, the smallest model whose 0.5-thresholded
# prediction matches the label; column 0 when no model is correct.
hard_votes = (all_predictions >= 0.5).astype(int)
is_correct = hard_votes == np.asarray(y_test)[:, None]
# Incorrect models get an infinite size so argmin skips them; a row with no
# correct model is all-inf and argmin returns column 0.
oracle_indices = np.where(
    is_correct, np.asarray(model_sizes, dtype=float), np.inf
).argmin(axis=1)

# Per-sample prediction taken from the oracle's choice.
oracle_predictions = all_predictions[np.arange(len(all_predictions)), oracle_indices]

# Oracle metrics.
auc_oracle = roc_auc_score(y_test, oracle_predictions)
acc_oracle = accuracy_score(y_test, (oracle_predictions >= 0.5).astype(int))
brier_oracle = brier_score_loss(y_test, oracle_predictions)

# How often the oracle picks each model.
oracle_calls = np.bincount(oracle_indices, minlength=len(model_names))
oracle_proportions = oracle_calls / len(oracle_indices)

# Expected per-sample cost of the oracle: usage share times per-call cost, summed over models.
oracle_avg_cost = sum(
    share * costs_per_model[name]
    for share, name in zip(oracle_proportions, model_names)
)


# Mark the oracle as a gold star on the current figure.
plt.scatter(
    oracle_avg_cost,
    acc_oracle,
    color='gold',
    s=200,
    marker='*',
    label='Oracle Baseline',
    zorder=10,
    edgecolors='black',
    linewidth=1,
)



# Plot individual models
# One square marker per ladder model, coloured by its position in the ladder.
model_colors = ['red', 'green', 'orange', 'purple']
for color_idx, entry in enumerate(individual_models):
    plt.scatter(
        entry['avg_cost'],
        entry['accuracy'],
        color=model_colors[color_idx],
        s=100,
        marker='s',
        label=f"{entry['name']}",
        zorder=5,
    )

# Axis labels, title, grid and legend for the cost/accuracy figure.
plt.xlabel('Average Cost per Sample ($)')
plt.ylabel('Accuracy')
plt.title('Average Cost vs Accuracy: Real Cascade vs Individual Models')
plt.grid(True, alpha=0.3)
plt.legend()

# Add budget labels for cascade points
# for i, budget in enumerate(budgets):
#     plt.annotate(f'${budget}', 
#                 (avg_costs[i], accuracies[i]), 
#                 textcoords="offset points", 
#                 xytext=(0,10), 
#                 ha='center',
#                 fontsize=8)

plt.tight_layout()
plt.show()

# Text summary of the most recent sweep held in `budget_results`.
print("Budget vs Average Cost vs Accuracy:")
for row in budget_results:
    print(f"Budget: ${row['budget']:,} -> Avg Cost: ${row['avg_cost']:.4f} -> Accuracy: {row['accuracy']:.4f}")

print("\nIndividual Model Performance:")
for entry in individual_models:
    print(f"{entry['name']}: Avg Cost: ${entry['avg_cost']:.4f}, Accuracy: {entry['accuracy']:.4f}")
# %%

# Density contours of RGL against RCL for a single model.
# The title used to say "Llama 8B" while the plotted data came from
# results['llama1']; the data and the title now share one key.
kde_model = 'llama1'

plt.figure(figsize=(8, 6))
# sns.scatterplot(x=results[kde_model]['RCL'].flatten(), y=results[kde_model]['RGL'].flatten(), alpha=0.6)
sns.kdeplot(x=results[kde_model]['RCL'].flatten(), y=results[kde_model]['RGL'].flatten(), levels=6, color='red', alpha=0.7)
plt.xlabel('RCL')
plt.ylabel('RGL')
plt.title(f'RCL vs RGL for {kde_model} with Density Contours')


# %%
# Budget sweep with the RCL-only acceptance rule, additionally recording the
# per-sample model choice for every budget.
cost_budgets = [2000, 5000, 8000, 10000,12500, 15000, 17500,  20000, 30000, 50000, 60000, 70000, 80000, 90000, 100000, ]
threshold = 0.05
smallest_model = ordered_models[0]
budget_results = []
decisions = []
for budget in cost_budgets:
    cascade_predictions = np.full(len(y_test), np.nan)
    cascade_model_used = np.full(len(y_test), -1)
    total_cost = 0

    for i in range(len(y_test)):
        # Rank 0 (the smallest model) is the fallback when nothing is accepted.
        chosen_rank = 0

        # Try the ladder only while the budget has not been used up.
        if total_cost < budget:
            for j, model_name in enumerate(ordered_models):
                # Skip models whose call would push the spend over the budget.
                if total_cost + costs_per_model[model_name] <= budget:
                    total_regret = results[model_name]['RCL'][i] # + results[model_name]['RGL'][i]
                    if total_regret < threshold:
                        chosen_rank = j
                        break

        # The fallback call is charged as well, even past the budget.
        chosen_model = ordered_models[chosen_rank]
        cascade_predictions[i] = results[chosen_model]['S_test'][i]
        cascade_model_used[i] = chosen_rank
        total_cost += costs_per_model[chosen_model]

    # One array of ladder ranks per budget (a fresh array is created each iteration).
    decisions.append(cascade_model_used)

    # Accuracy at a 0.5 cut and realised cost per sample for this budget.
    acc = accuracy_score(y_test, (cascade_predictions >= 0.5).astype(int))
    avg_cost = total_cost / len(y_test)

    budget_results.append({
        'budget': budget,
        'total_cost': total_cost,
        'avg_cost': avg_cost,
        'accuracy': acc
    })

# Keep the RCL-only outcome under its own names before the next sweep reuses the generic ones.
avg_costs_RCL = [row['avg_cost'] for row in budget_results]

decision_RCL = decisions

#%%
# Budget sweep with the RCL + RGL acceptance rule, recording the per-sample
# model choice for every budget.
budget_results = []
decisions = []
for budget in cost_budgets:
    cascade_predictions = np.full(len(y_test), np.nan)
    cascade_model_used = np.full(len(y_test), -1)
    total_cost = 0

    for i in range(len(y_test)):
        # Rank 0 (the smallest model) is the fallback when nothing is accepted.
        chosen_rank = 0

        # Try the ladder only while the budget has not been used up.
        if total_cost < budget:
            for j, model_name in enumerate(ordered_models):
                # Skip models whose call would push the spend over the budget.
                if total_cost + costs_per_model[model_name] <= budget:
                    total_regret = results[model_name]['RCL'][i]  + results[model_name]['RGL'][i]
                    if total_regret < threshold:
                        chosen_rank = j
                        break

        # The fallback call is charged as well, even past the budget.
        chosen_model = ordered_models[chosen_rank]
        cascade_predictions[i] = results[chosen_model]['S_test'][i]
        cascade_model_used[i] = chosen_rank
        total_cost += costs_per_model[chosen_model]

    # One array of ladder ranks per budget (a fresh array is created each iteration).
    decisions.append(cascade_model_used)

    # Accuracy at a 0.5 cut and realised cost per sample for this budget.
    acc = accuracy_score(y_test, (cascade_predictions >= 0.5).astype(int))
    avg_cost = total_cost / len(y_test)

    budget_results.append({
        'budget': budget,
        'total_cost': total_cost,
        'avg_cost': avg_cost,
        'accuracy': acc
    })


# Keep the RCL + RGL outcome under its own names.
avg_costs_RCL_RGL = [row['avg_cost'] for row in budget_results]
decision_RCL_RGL = decisions


#%%

# Notebook cell output: show the per-budget routing decisions (one array of
# model indices per budget) recorded for the RCL + RGL cascade.
decision_RCL_RGL
# %%
# For every budget, find the samples that the RCL-only and the RCL + RGL
# cascades route to different models, then score llama70 on just those samples.
for i, rcl_choice in enumerate(decision_RCL):
    diff = rcl_choice != decision_RCL_RGL[i]
    print(diff)
    llama70_labels = (results['llama70']['S_test'][diff] >= 0.5).astype(int)
    acc = accuracy_score(y_test[diff], llama70_labels)
    print(f"Budget: {avg_costs_RCL[i]} vs {avg_costs_RCL_RGL[i]}, Different decisions: {np.sum(diff)}, Accuracy on different decisions: {acc:.4f}")
# %%



# Sweep over per-sample cost caps: a model may serve a sample only when its
# per-sample cost is at most the cap. Acceptance is based on RCL alone (no RGL).
cost_budgets = [0.05, 0.07, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.75, 0.8, 0.9]
threshold = 0.05
smallest_model = ordered_models[0]
budget_results = []

for budget in cost_budgets:
    n_samples = len(y_test)
    cascade_predictions = np.full(n_samples, np.nan)
    cascade_model_used = np.full(n_samples, -1)

    # Affordability does not depend on the sample, so resolve it once per cap.
    affordable = [
        (j, model_name)
        for j, model_name in enumerate(ordered_models)
        if costs_per_model[model_name] <= budget
    ]

    for i in range(n_samples):
        chosen_idx = -1

        # First pass, in `ordered_models` order: take the first affordable
        # model whose regret is under the threshold.
        for j, model_name in affordable:
            if results[model_name]['RCL'][i] < threshold:
                chosen_idx = j
                break
        else:
            # Second pass: nothing qualified, so take the affordable model
            # with the lowest regret (first one wins on ties).
            min_regret = float('inf')
            for j, model_name in affordable:
                total_regret = results[model_name]['RCL'][i]
                if total_regret < min_regret:
                    min_regret = total_regret
                    chosen_idx = j

        # chosen_idx stays -1 when no model fits under the cap; the sample
        # then keeps its NaN prediction and is excluded from the metrics.
        if chosen_idx != -1:
            cascade_predictions[i] = results[ordered_models[chosen_idx]]['S_test'][i]
            cascade_model_used[i] = chosen_idx

    # Score only the samples that received a prediction
    valid_predictions = ~np.isnan(cascade_predictions)
    if valid_predictions.any():
        predicted_labels = (cascade_predictions[valid_predictions] >= 0.5).astype(int)
        acc = accuracy_score(y_test[valid_predictions], predicted_labels)
        avg_cost = np.mean([costs_per_model[ordered_models[idx]] for idx in cascade_model_used[valid_predictions]])
    else:
        acc = 0
        avg_cost = 0

    budget_results.append({
        'budget': budget,
        'total_cost': avg_cost * len(y_test),
        'avg_cost': avg_cost,
        'accuracy': acc,
    })

# Column-wise views of the sweep, used for plotting
budgets = [entry['budget'] for entry in budget_results]
avg_costs = [entry['avg_cost'] for entry in budget_results]
accuracies = [entry['accuracy'] for entry in budget_results]

# Accuracy (at a 0.5 cut-off) and per-sample cost of every model used alone
individual_models = []
for name in ordered_models:
    hard_labels = (results[name]['S_test'] >= 0.5).astype(int)
    individual_models.append({
        'name': name,
        'avg_cost': costs_per_model[name],
        'accuracy': accuracy_score(y_test, hard_labels),
    })

# Open a square figure and draw the cost/accuracy curve of the RCL-only cascade
plt.figure(figsize=(6, 6))
plt.plot(
    avg_costs,
    accuracies,
    'bo-',
    linewidth=2,
    markersize=6,
    label='Real Cascade no RGL',
    alpha=0.3,
)

# Same per-sample cost-cap sweep as the RCL-only variant, but acceptance and
# the fallback ranking both use the total regret RCL + RGL.
budget_results = []

for budget in cost_budgets:
    cascade_predictions = np.full(len(y_test), np.nan)
    cascade_model_used = np.full(len(y_test), -1)

    for i in range(len(y_test)):
        model_accepted = False

        for j, model_name in enumerate(ordered_models):
            # Check if this model's cost is within budget
            if costs_per_model[model_name] <= budget:
                # Calculate total regret for this sample and model
                total_regret = results[model_name]['RCL'][i] + results[model_name]['RGL'][i]

                # If regret is below threshold, accept this model's prediction
                if total_regret < threshold:
                    cascade_predictions[i] = results[model_name]['S_test'][i]
                    cascade_model_used[i] = j
                    model_accepted = True
                    break

        # If no model meets threshold within budget, choose model with lowest regret within budget
        if not model_accepted:
            min_regret = float('inf')
            best_model_idx = -1

            for j, model_name in enumerate(ordered_models):
                if costs_per_model[model_name] <= budget:
                    total_regret = results[model_name]['RCL'][i] + results[model_name]['RGL'][i]
                    if total_regret < min_regret:
                        min_regret = total_regret
                        best_model_idx = j

            # Assign once, after the search has finished. This check used to
            # sit inside the loop above and rewrote the prediction on every
            # iteration; the final value is the same, and the layout now
            # matches the RCL-only variant.
            if best_model_idx != -1:
                cascade_predictions[i] = results[ordered_models[best_model_idx]]['S_test'][i]
                cascade_model_used[i] = best_model_idx

    # Calculate performance on the samples that received a prediction
    # (none do when every model costs more than the cap)
    valid_predictions = ~np.isnan(cascade_predictions)
    if np.sum(valid_predictions) > 0:
        acc = accuracy_score(y_test[valid_predictions], (cascade_predictions[valid_predictions] >= 0.5).astype(int))
        avg_cost = np.mean([costs_per_model[ordered_models[idx]] for idx in cascade_model_used[valid_predictions]])
    else:
        acc = 0
        avg_cost = 0

    budget_results.append({
        'budget': budget,
        'total_cost': avg_cost * len(y_test),
        'avg_cost': avg_cost,
        'accuracy': acc
    })

# Extract data for plotting
budgets = [r['budget'] for r in budget_results]
avg_costs = [r['avg_cost'] for r in budget_results]
accuracies = [r['accuracy'] for r in budget_results]

# Accuracy (at a 0.5 cut-off) and per-sample cost of every model used alone.
# NOTE(review): the cost is `costs_of_all_models[name]` divided by the number
# of test samples -- presumably a total over the test set; confirm it agrees
# with the per-sample figures in `costs_per_model`.
n_samples = len(y_test)
individual_models = []
for name in ordered_models:
    hard_labels = (results[name]['S_test'] >= 0.5).astype(int)
    individual_models.append({
        'name': name,
        'avg_cost': costs_of_all_models[name] / n_samples,
        'accuracy': accuracy_score(y_test, hard_labels),
    })

# No new figure is created here; the RCL + RGL curve is drawn on the current one.
plt.plot(
    avg_costs,
    accuracies,
    'ro-',
    linewidth=2,
    markersize=6,
    label='Real Cascade RGL included',
    alpha=0.3,
)


# Oracle baseline: per sample, pick the model with the smallest `model_sizes`
# entry among those whose thresholded prediction matches the label; when no
# model is right, fall back to index 0.
n_samples = len(y_test)
oracle_indices = np.full(n_samples, -1)
for i in range(n_samples):
    # Thresholded predictions of all models for this sample
    binary_preds = (all_predictions[i] >= 0.5).astype(int)
    correct_models = np.where(binary_preds == y_test[i])[0]
    if len(correct_models) == 0:
        oracle_indices[i] = 0
        continue
    sizes_of_correct = [model_sizes[j] for j in correct_models]
    oracle_indices[i] = correct_models[np.argmin(sizes_of_correct)]

# Score of the oracle-selected model for every sample
row_ids = np.arange(len(all_predictions))
oracle_predictions = all_predictions[row_ids, oracle_indices]

# Oracle metrics
auc_oracle = roc_auc_score(y_test, oracle_predictions)
acc_oracle = accuracy_score(y_test, (oracle_predictions >= 0.5).astype(int))
brier_oracle = brier_score_loss(y_test, oracle_predictions)

# Share of samples routed to each model and the resulting average cost
oracle_calls = np.bincount(oracle_indices, minlength=len(model_names))
oracle_proportions = oracle_calls / len(oracle_indices)
oracle_avg_cost = sum(
    share * costs_per_model[model_name]
    for share, model_name in zip(oracle_proportions, model_names)
)

# Mark the oracle on the current figure
plt.scatter(
    oracle_avg_cost,
    acc_oracle,
    color='gold',
    s=200,
    marker='*',
    label='Oracle Baseline',
    zorder=10,
    edgecolors='black',
    linewidth=1,
)



# Overlay every stand-alone model as a square marker
model_colors = ['red', 'green', 'orange', 'purple', 'blue', 'brown', 'pink']
square_marker = dict(s=100, marker='s', zorder=5)
for i, model in enumerate(individual_models):
    plt.scatter(
        model['avg_cost'],
        model['accuracy'],
        color=model_colors[i],
        label=f"{model['name']}",
        **square_marker,
    )

# Axis labels, title, grid and legend
plt.xlabel('Average Cost per Sample ($)')
plt.ylabel('Accuracy')
plt.title('Average Cost vs Accuracy: Real Cascade vs Individual Models')
plt.grid(True, alpha=0.3)
plt.legend()

# Add budget labels for cascade points
# for i, budget in enumerate(budgets):
#     plt.annotate(f'${budget}', 
#                 (avg_costs[i], accuracies[i]), 
#                 textcoords="offset points", 
#                 xytext=(0,10), 
#                 ha='center',
#                 fontsize=8)

plt.tight_layout()
plt.show()

# Text summary of the most recent budget sweep
print("Budget vs Average Cost vs Accuracy:")
for r in budget_results:
    print(f"Budget: ${r['budget']:,} -> Avg Cost: ${r['avg_cost']:.4f} -> Accuracy: {r['accuracy']:.4f}")

# Text summary of the stand-alone models
print("\nIndividual Model Performance:")
for model in individual_models:
    print(f"{model['name']}: Avg Cost: ${model['avg_cost']:.4f}, Accuracy: {model['accuracy']:.4f}")
# %%



# Sweep over cascade depth: only the first `num_models` entries of
# `ordered_models` may serve a sample. Acceptance is based on RCL alone (no RGL).
num_models_list = [1, 2, 3, 4, 5, 6]  # Number of models to consider in cascade
threshold = 0.05
budget_results = []

for num_models in num_models_list:
    n_samples = len(y_test)
    cascade_predictions = np.full(n_samples, np.nan)
    cascade_model_used = np.full(n_samples, -1)
    models_to_consider = ordered_models[:num_models]

    for i in range(n_samples):
        chosen_idx = -1

        # First pass: first model in the prefix whose regret is under the threshold
        for j, model_name in enumerate(models_to_consider):
            if results[model_name]['RCL'][i] < threshold:
                chosen_idx = j
                break
        else:
            # Second pass: nothing qualified, so take the lowest-regret model
            # (first one wins on ties).
            min_regret = float('inf')
            for j, model_name in enumerate(models_to_consider):
                total_regret = results[model_name]['RCL'][i]
                if total_regret < min_regret:
                    min_regret = total_regret
                    chosen_idx = j

        if chosen_idx != -1:
            cascade_predictions[i] = results[models_to_consider[chosen_idx]]['S_test'][i]
            cascade_model_used[i] = chosen_idx

    # Score only the samples that received a prediction
    valid_predictions = ~np.isnan(cascade_predictions)
    if valid_predictions.any():
        predicted_labels = (cascade_predictions[valid_predictions] >= 0.5).astype(int)
        acc = accuracy_score(y_test[valid_predictions], predicted_labels)
        avg_cost = np.mean([costs_per_model[models_to_consider[idx]] for idx in cascade_model_used[valid_predictions]])
    else:
        acc = 0
        avg_cost = 0

    budget_results.append({
        'num_models': num_models,
        'avg_cost': avg_cost,
        'accuracy': acc,
    })

# Column-wise views of the sweep, used for plotting
num_models_considered = [entry['num_models'] for entry in budget_results]
accuracies = [entry['accuracy'] for entry in budget_results]
avg_costs = [entry['avg_cost'] for entry in budget_results]

# Accuracy (at a 0.5 cut-off) of every model used alone
individual_models = []
for name in ordered_models:
    hard_labels = (results[name]['S_test'] >= 0.5).astype(int)
    individual_models.append({
        'name': name,
        'accuracy': accuracy_score(y_test, hard_labels),
    })

# Open the figure and draw accuracy against cascade depth for the RCL-only cascade
plt.figure(figsize=(8, 6))
plt.plot(
    num_models_considered,
    accuracies,
    'bo-',
    linewidth=2,
    markersize=3,
    label='Real Cascade no RGL',
)

# plt.scatter(num_models_considered, accuracies, c=avg_costs, s=100, 
#            marker='o', cmap='viridis', label='Real Cascade RGL included', zorder=10)

# Depth sweep again, this time accepting and ranking on RCL + RGL
budget_results_rgl = []

for num_models in num_models_list:
    n_samples = len(y_test)
    cascade_predictions = np.full(n_samples, np.nan)
    cascade_model_used = np.full(n_samples, -1)
    models_to_consider = ordered_models[:num_models]

    for i in range(n_samples):
        chosen_idx = -1

        # First pass: first model in the prefix whose total regret is under the threshold
        for j, model_name in enumerate(models_to_consider):
            model_entry = results[model_name]
            if model_entry['RCL'][i] + model_entry['RGL'][i] < threshold:
                chosen_idx = j
                break
        else:
            # Second pass: nothing qualified, so take the lowest-regret model
            # (first one wins on ties).
            min_regret = float('inf')
            for j, model_name in enumerate(models_to_consider):
                total_regret = results[model_name]['RCL'][i] + results[model_name]['RGL'][i]
                if total_regret < min_regret:
                    min_regret = total_regret
                    chosen_idx = j

        if chosen_idx != -1:
            cascade_predictions[i] = results[models_to_consider[chosen_idx]]['S_test'][i]
            cascade_model_used[i] = chosen_idx

    # Score only the samples that received a prediction
    valid_predictions = ~np.isnan(cascade_predictions)
    if valid_predictions.any():
        predicted_labels = (cascade_predictions[valid_predictions] >= 0.5).astype(int)
        acc = accuracy_score(y_test[valid_predictions], predicted_labels)
        avg_cost = np.mean([costs_per_model[models_to_consider[idx]] for idx in cascade_model_used[valid_predictions]])
    else:
        acc = 0
        avg_cost = 0

    budget_results_rgl.append({'accuracy': acc, 'avg_cost': avg_cost})

accuracies_rgl = [entry['accuracy'] for entry in budget_results_rgl]
avg_costs_rgl = [entry['avg_cost'] for entry in budget_results_rgl]


# Draw the RCL + RGL curve on the current figure
plt.plot(
    num_models_considered,
    accuracies_rgl,
    'ro-',
    linewidth=2,
    markersize=3,
    label='Real Cascade RGL included',
)

# Put scatterplot on top with higher zorder
# plt.scatter(num_models_considered, accuracies_rgl, c=avg_costs_rgl, s=100, 
#            marker='o', cmap='viridis', label='Real Cascade RGL included', zorder=10)
# plt.colorbar(label='Average Cost ($)')


# Oracle accuracy for every cascade depth: per sample, take the first model of
# the prefix whose thresholded prediction matches the label, else index 0.
oracle_accuracies = []

for num_models in num_models_list:
    models_to_consider = ordered_models[:num_models]
    n_samples = len(y_test)
    oracle_indices = np.full(n_samples, -1)

    for i in range(n_samples):
        # Position of the first correct model in the prefix; 0 when none is correct
        oracle_indices[i] = next(
            (
                j
                for j, name in enumerate(models_to_consider)
                if (results[name]['S_test'][i] >= 0.5).astype(int) == y_test[i]
            ),
            0,
        )

    # Score of the oracle-selected model for every sample
    oracle_predictions = np.array([
        results[models_to_consider[oracle_indices[i]]]['S_test'][i]
        for i in range(n_samples)
    ])

    oracle_acc = accuracy_score(y_test, (oracle_predictions >= 0.5).astype(int))
    oracle_accuracies.append(oracle_acc)

# Draw the oracle as a solid line on the current figure
plt.plot(
    num_models_considered,
    oracle_accuracies,
    'g-',
    linewidth=3,
    markersize=8,
    label='Oracle Baseline',
    alpha=0.8,
)



# Plot individual models as horizontal dotted lines.
# `individual_models` holds one entry per `ordered_models` entry, which can be
# more than the six colours listed here, so the palette is cycled instead of
# indexed directly (direct indexing raises IndexError on a seventh model).
model_colors = ['red', 'green', 'orange', 'purple', 'blue', 'brown']
for i, model in enumerate(individual_models):
    plt.axhline(y=model['accuracy'], color=model_colors[i % len(model_colors)], linestyle='--', alpha=0.7,
                label=f"{ordered_models[i]}")

plt.xlabel('Number of Models in Cascade')
plt.ylabel('Accuracy')
plt.title('Accuracy vs Number of Models in Cascade')
plt.grid(True, alpha=0.3)
plt.legend()
plt.xticks(num_models_list)

plt.tight_layout()
plt.show()

# Print the results: one row per cascade depth, pairing the RCL-only accuracy
# stored in `budget_results` with the RCL + RGL accuracy at the same position
print("Number of Models vs Accuracy:")
for i, r in enumerate(budget_results):
    print(f"Models: {r['num_models']} -> Accuracy (no RGL): {r['accuracy']:.4f}, Accuracy (with RGL): {accuracies_rgl[i]:.4f}")

print("\nIndividual Model Performance:")
for i, model in enumerate(individual_models):
    print(f"{ordered_models[i]}: Accuracy: {model['accuracy']:.4f}")


# %%


def cascade_eval_rcl(cascade, costs_model, results, threshold=0.05):
    """Evaluate a cascade that accepts a model's prediction based on RCL alone.

    For every test sample the models in ``cascade`` are tried in order and the
    first one whose ``results[model]['RCL']`` value is below ``threshold`` is
    used. If none qualifies, the model with the lowest RCL for that sample is
    used instead (the first one wins on ties). The chosen model's per-sample
    cost is charged in both cases.

    Labels are read from the module-level ``y_test``; the per-model arrays in
    ``results`` must be indexable by the same sample positions.

    Args:
        cascade: Model names in the order they should be tried.
        costs_model: Mapping from model name to per-sample cost.
        results: Mapping from model name to a mapping holding at least the
            ``'RCL'`` and ``'S_test'`` per-sample arrays.
        threshold: Regret below which a prediction is accepted.

    Returns:
        Tuple ``(avg_cost, accuracy)``. ``avg_cost`` is the total charged cost
        divided by the number of test samples; ``accuracy`` is computed at a
        0.5 cut-off over the samples that received a prediction. Both are 0
        when no sample received one (for example an empty cascade).
    """
    n_samples = len(y_test)
    cascade_predictions = np.full(n_samples, np.nan)
    total_cost = 0

    for i in range(n_samples):
        chosen_model = None

        # Try models in the cascade order
        for model_name in cascade:
            if results[model_name]['RCL'][i] < threshold:
                chosen_model = model_name
                break
        else:
            # No model met the threshold: fall back to the lowest RCL.
            # The previous fallback indexed by an undefined `i` (picking up
            # whatever a module-level loop left behind), ranked by RCL + RGL
            # and never charged a cost for the sample.
            min_regret = float('inf')
            for model_name in cascade:
                model_regret = results[model_name]['RCL'][i]
                if model_regret < min_regret:
                    min_regret = model_regret
                    chosen_model = model_name

        if chosen_model is not None:
            cascade_predictions[i] = results[chosen_model]['S_test'][i]
            total_cost += costs_model[chosen_model]

    # Calculate performance metrics on the samples that received a prediction
    valid_predictions = ~np.isnan(cascade_predictions)
    if np.sum(valid_predictions) > 0:
        accuracy = accuracy_score(y_test[valid_predictions],
                                  (cascade_predictions[valid_predictions] >= 0.5).astype(int))
        avg_cost = total_cost / n_samples
    else:
        accuracy = 0
        avg_cost = 0

    return avg_cost, accuracy



def cascade_eval_rcl_rgl(cascade, costs_model, results, threshold=0.05):
    """Evaluate a cascade that accepts a model's prediction based on RCL + RGL.

    For every test sample the models in ``cascade`` are tried in order and the
    first one whose ``RCL + RGL`` value is below ``threshold`` is used. If none
    qualifies, the model with the lowest ``RCL + RGL`` for that sample is used
    instead (the first one wins on ties). The chosen model's per-sample cost
    is charged in both cases.

    Labels are read from the module-level ``y_test``; the per-model arrays in
    ``results`` must be indexable by the same sample positions.

    Args:
        cascade: Model names in the order they should be tried.
        costs_model: Mapping from model name to per-sample cost.
        results: Mapping from model name to a mapping holding at least the
            ``'RCL'``, ``'RGL'`` and ``'S_test'`` per-sample arrays.
        threshold: Total regret below which a prediction is accepted.

    Returns:
        Tuple ``(avg_cost, accuracy)``. ``avg_cost`` is the total charged cost
        divided by the number of test samples; ``accuracy`` is computed at a
        0.5 cut-off over the samples that received a prediction. Both are 0
        when no sample received one (for example an empty cascade).
    """
    n_samples = len(y_test)
    cascade_predictions = np.full(n_samples, np.nan)
    total_cost = 0

    for i in range(n_samples):
        chosen_model = None

        # Try models in the cascade order
        for model_name in cascade:
            model_regret = results[model_name]['RCL'][i] + results[model_name]['RGL'][i]
            if model_regret < threshold:
                chosen_model = model_name
                break
        else:
            # No model met the threshold: fall back to the lowest total regret.
            # The previous fallback indexed by an undefined `i` (picking up
            # whatever a module-level loop left behind) and never charged a
            # cost for the sample.
            min_regret = float('inf')
            for model_name in cascade:
                model_regret = results[model_name]['RCL'][i] + results[model_name]['RGL'][i]
                if model_regret < min_regret:
                    min_regret = model_regret
                    chosen_model = model_name

        if chosen_model is not None:
            cascade_predictions[i] = results[chosen_model]['S_test'][i]
            total_cost += costs_model[chosen_model]

    # Calculate performance metrics on the samples that received a prediction
    valid_predictions = ~np.isnan(cascade_predictions)
    if np.sum(valid_predictions) > 0:
        accuracy = accuracy_score(y_test[valid_predictions],
                                  (cascade_predictions[valid_predictions] >= 0.5).astype(int))
        avg_cost = total_cost / n_samples
    else:
        accuracy = 0
        avg_cost = 0

    return avg_cost, accuracy


def plot_cascade(list_cascades, costs_model, results, threshold=0.05):
    """Plot average cost against accuracy for several cascades.

    Each cascade is evaluated twice, with ``cascade_eval_rcl`` and with
    ``cascade_eval_rcl_rgl``. The two resulting points are drawn in blue and
    red, joined by a dashed line and annotated ``C1``, ``C2``, ... in the
    order of ``list_cascades``. An oracle baseline and one square marker per
    model appearing in any cascade are added to the same figure.

    The oracle is computed from the module-level ``y_test``,
    ``all_predictions``, ``model_sizes``, ``model_names`` and
    ``costs_per_model``.

    Args:
        list_cascades: List of cascades, each a list of model names.
        costs_model: Mapping from model name to per-sample cost; used for the
            cascades and for the individual-model markers.
        results: Mapping from model name to its per-sample arrays.
        threshold: Regret threshold forwarded to both evaluators.

    Returns:
        Tuple ``(cascade_results_rcl, cascade_results_rcl_rgl)``; each is a
        list of ``(avg_cost, accuracy)`` tuples aligned with ``list_cascades``.
    """
    # Unique model names across all cascades, in first-seen order
    list_models = []
    for cascade in list_cascades:
        for model_name in cascade:
            if model_name not in list_models:
                list_models.append(model_name)

    # Calculate performance for each cascade
    cascade_results_rcl = []
    cascade_results_rcl_rgl = []
    for cascade in list_cascades:
        avg_cost_rcl, accuracy_rcl = cascade_eval_rcl(cascade, costs_model, results, threshold)
        cascade_results_rcl.append((avg_cost_rcl, accuracy_rcl))
        print(accuracy_rcl)

        avg_cost_rcl_rgl, accuracy_rcl_rgl = cascade_eval_rcl_rgl(cascade, costs_model, results, threshold)
        cascade_results_rcl_rgl.append((avg_cost_rcl_rgl, accuracy_rcl_rgl))

    # Plotting
    plt.figure(figsize=(6, 6))

    # Plot each cascade with both RCL and RCL+RGL as points. The result lists
    # are read pairwise rather than transposed with zip(*...), which raised on
    # an empty `list_cascades`.
    for i, cascade in enumerate(list_cascades):
        cascade_name = ' → '.join(cascade)
        cost_rcl, acc_rcl = cascade_results_rcl[i]
        cost_rcl_rgl, acc_rcl_rgl = cascade_results_rcl_rgl[i]

        # Only the first cascade gets a label, so the legend shows one entry
        # per colour
        plt.scatter(cost_rcl, acc_rcl,
                    color='blue', s=100, marker='o', alpha=0.8,
                    label=f'{cascade_name} (RCL)' if i == 0 else '')

        plt.scatter(cost_rcl_rgl, acc_rcl_rgl,
                    color='red', s=100, marker='o', alpha=0.8,
                    label=f'{cascade_name} (RCL+RGL)' if i == 0 else '')

        # Connect the two points with a line to show the comparison
        plt.plot([cost_rcl, cost_rcl_rgl],
                 [acc_rcl, acc_rcl_rgl],
                 'k--', alpha=0.3, linewidth=1)

        # Add cascade name annotation above the midpoint of the pair
        mid_x = (cost_rcl + cost_rcl_rgl) / 2
        mid_y = (acc_rcl + acc_rcl_rgl) / 2
        plt.annotate(f'C{i+1}', (mid_x, mid_y),
                     textcoords="offset points", xytext=(0, 15),
                     ha='center', fontsize=8, alpha=0.7)

    # Oracle baseline: per sample, pick the model with the smallest
    # `model_sizes` entry among those whose thresholded prediction is correct.
    # NOTE(review): this ranges over every column of the module-level
    # `all_predictions`, not only the models in `list_cascades`, and is priced
    # with the module-level `costs_per_model` -- confirm that is intended.
    oracle_indices = np.full(len(y_test), -1)
    for i in range(len(y_test)):
        # Get binary predictions for all models at this sample
        binary_preds = (all_predictions[i] >= 0.5).astype(int)
        # Find which models are correct
        correct_models = np.where(binary_preds == y_test[i])[0]
        if len(correct_models) > 0:
            # Choose the smallest model among correct ones
            oracle_indices[i] = correct_models[np.argmin([model_sizes[j] for j in correct_models])]
        else:
            # If no model is correct, choose the smallest model
            oracle_indices[i] = 0

    # Create oracle predictions and score them
    oracle_predictions = all_predictions[np.arange(len(all_predictions)), oracle_indices]
    acc_oracle = accuracy_score(y_test, (oracle_predictions >= 0.5).astype(int))

    # Calculate oracle model usage and average cost
    oracle_calls = np.bincount(oracle_indices, minlength=len(model_names))
    oracle_proportions = oracle_calls / len(oracle_indices)
    oracle_avg_cost = sum(oracle_proportions[i] * costs_per_model[model_names[i]] for i in range(len(model_names)))

    # Add oracle to the plot
    plt.scatter(oracle_avg_cost, acc_oracle,
                color='gold', s=200, marker='*',
                label='Oracle Baseline', zorder=10, edgecolors='black', linewidth=1)

    # Plot individual models, priced with the same `costs_model` the cascades
    # were charged with (previously the module-level `costs_per_model`, which
    # ignored the argument)
    model_colors = plt.cm.tab10(np.linspace(0, 1, len(list_models)))
    for i, model_name in enumerate(list_models):
        model_predictions = results[model_name]['S_test']
        acc = accuracy_score(y_test, (model_predictions >= 0.5).astype(int))
        avg_cost = costs_model[model_name]
        plt.scatter(avg_cost, acc,
                    color=model_colors[i], s=100, marker='s',
                    label=f"{model_name}", zorder=5)

    plt.xlabel('Average Cost per Sample ($)')
    plt.ylabel('Accuracy')
    plt.title('Average Cost vs Accuracy: Cascade vs Individual Models')
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.show()
    return cascade_results_rcl, cascade_results_rcl_rgl

# Example usage: evaluate and plot four cascades built from the available models

list_cascades = [
    ['llama1', 'llama8'],
    ['llama1', 'llama8', 'llama70'],
    ['llama1', 'phi4', 'llama70'],
    ['llama1', 'llama8', 'phi4', 'Mixtral8x7B', 'llama70'],
]

cascade_results_rcl, cascade_results_rcl_rgl = plot_cascade(
    list_cascades,
    costs_per_model,
    results,
    threshold=0.05,
)
# %%
