import os
import random
import numpy as np
import torch
from statistics import mean
from collections import namedtuple, deque
import seaborn as sns; sns.set_theme()
import matplotlib.pyplot as plt

def plot_cost_score_tradeoff(costs, scores, max_cost, save_path):
    """Plot score vs. cost with its Pareto front and report the normalized AUC.

    Points are swept in order of increasing cost. A point joins the front
    when its score strictly beats every cheaper point. Consecutive front
    points are joined by straight lines, and the last one is extended
    horizontally to ``max_cost``. The area under that curve (trapezoidal
    rule) is scaled by ``100 / max_cost``, shown in the figure title, and
    returned.

    Args:
        costs: Per-point costs; must have the same shape as ``scores``.
        scores: Per-point scores.
        max_cost: Cost the front is extended to; also the AUC normalizer.
        save_path: Destination passed to ``Figure.savefig``.

    Returns:
        The normalized AUC, ``auc * 100 / max_cost`` (the value in the title).

    Raises:
        ValueError: If ``costs`` and ``scores`` differ in shape, or if no
            point has a comparable (non-NaN) score, so the front is empty.
    """
    costs = np.asarray(costs, dtype=float)
    scores = np.asarray(scores, dtype=float)
    if costs.shape != scores.shape:
        raise ValueError(
            "costs and scores must have the same shape, "
            f"got {costs.shape} and {scores.shape}"
        )

    # Stable sort so points with equal cost keep their input order, making
    # the front (and its plotted vertical segments) deterministic on ties.
    order = np.argsort(costs, kind="stable")
    costs = costs[order]
    scores = scores[order]

    # Running-max (Pareto) front: keep each point whose score strictly beats
    # all cheaper points. NaN scores never compare greater, so they are skipped.
    # NOTE: the points are joined linearly, so the curve is not necessarily
    # a convex hull even though the legend label says so.
    pareto_costs = []
    pareto_scores = []
    best_score = -np.inf
    for c, s in zip(costs, scores):
        if s > best_score:
            pareto_costs.append(c)
            pareto_scores.append(s)
            best_score = s

    if not pareto_scores:
        raise ValueError(
            "cannot build a Pareto front: no points with a comparable score"
        )

    # Extend the best score horizontally to max_cost. Only do so when it
    # actually extends the curve: appending a smaller x would give the last
    # trapezoid a negative width and subtract area from the AUC.
    # NOTE(review): front points beyond max_cost still contribute area past
    # the nominal [.., max_cost] range.
    if pareto_costs[-1] < max_cost:
        pareto_costs.append(max_cost)
        pareto_scores.append(pareto_scores[-1])

    # np.trapezoid exists from NumPy 2.0; older releases only have np.trapz.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    auc = trapezoid(pareto_scores, pareto_costs)
    normalized_auc = auc * 100 / max_cost

    fig, ax = plt.subplots(1, 1)
    try:
        ax.scatter(costs, scores, label='All points')
        ax.plot(pareto_costs, pareto_scores, color='red', label='Non-decreasing convex hull')
        ax.set_title(f"AUC: {normalized_auc}")
        ax.set_xlabel("Cost")
        ax.set_ylabel("Score")
        ax.legend()
        fig.savefig(save_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

    return float(normalized_auc)