import pandas as pd
import numpy as np
import json
import ast
import matplotlib.pyplot as plt  # Import for plotting

def compute_2D_hypervolume(scores, set, reference_point):
    """Return the 2-D hypervolume dominated by a subset of arms (maximization).

    Args:
        scores: mapping objective-name -> 1-D array of per-arm scores. Only the
            first two objectives (in dict iteration order) are used as x and y.
        set: sequence of arm indices whose points define the dominated region.
            (The name shadows the builtin; it is kept for caller compatibility.)
        reference_point: per-objective reference point; every score is measured
            relative to it.

    Returns:
        Area (float) of the union of the rectangles [ref_x, x_i] x [ref_y, y_i]
        over the selected points. Dominated points never reduce the area, and
        points that do not strictly exceed the reference in both objectives
        contribute nothing. An empty selection yields 0.0.
    """
    points = np.array([scores[method][set] for method in scores.keys()], dtype=float).T
    points = points - np.asarray(reference_point, dtype=float)
    xs, ys = points[:, 0], points[:, 1]

    # Only points strictly better than the reference in both objectives enclose area.
    inside = (xs > 0) & (ys > 0)
    xs, ys = xs[inside], ys[inside]
    if xs.size == 0:
        return 0.0

    # Sweep from the largest x toward the reference. The vertical strip between
    # consecutive x values is covered up to the tallest y seen so far, because
    # every point with x >= the strip's right edge covers that strip.
    order = np.argsort(-xs, kind="stable")
    xs, ys = xs[order], ys[order]
    heights = np.maximum.accumulate(ys)
    widths = xs - np.append(xs[1:], 0.0)
    return float(np.sum(widths * heights))

def pareto_front(scores, reference_point):
    """Return indices of the arms not Pareto-dominated by any other arm.

    Dominance is for maximization: arm j dominates arm i when j is >= i in every
    objective and > i in at least one.

    Args:
        scores: mapping objective-name -> 1-D array of per-arm scores.
        reference_point: per-objective offset subtracted from every arm. It
            shifts all points equally, so it does not affect dominance.

    Returns:
        Sorted list of int arm indices on the Pareto front. Arms with identical
        scores do not dominate each other, so all of them are kept.
    """
    shifted = np.array([scores[method] for method in scores.keys()]).T - reference_point
    front = []
    for idx, point in enumerate(shifted):
        # Compare this arm against every arm at once. The arm itself never counts
        # as a dominator because it is not strictly better than itself anywhere.
        weakly_better = np.all(shifted >= point, axis=1)
        strictly_better = np.any(shifted > point, axis=1)
        if not np.any(weakly_better & strictly_better):
            front.append(idx)
    return front

def plot_2D_rewards(scores, group1_indices, group2_indices, pareto_indices, reference_point, fig_path=None):
    """Scatter every arm's shifted scores, highlighting two groups and the Pareto front.

    Args:
        scores: mapping objective-name -> per-arm score array; the first two
            objectives (dict order) become the x and y axes.
        group1_indices: arm indices drawn with a "+" marker.
        group2_indices: arm indices drawn as hollow circles (only for arms not
            already in group 1).
        pareto_indices: arm indices drawn in red; all other arms are black.
        reference_point: subtracted from every arm's scores before plotting.
        fig_path: directory/prefix prepended to "2D_rewards_plot.png". ``None``
            (the default) saves into the current working directory.

    Returns:
        The matplotlib Figure that was drawn and saved.
    """
    avg_rewards = np.array([scores[method] for method in scores.keys()]).T - reference_point
    # Sets give O(1) membership tests inside the per-arm loop.
    pareto_set = set(pareto_indices)
    group1_set = set(group1_indices)
    group2_set = set(group2_indices)
    labelled = set()  # legend categories that already have an entry

    def _legend_label(name):
        # Attach each category's legend label to its first drawn point only.
        if name in labelled:
            return None
        labelled.add(name)
        return name

    # Draw on a fresh figure so earlier plots in the same process don't bleed in.
    fig = plt.figure()
    for i, reward in enumerate(avg_rewards):
        color = "red" if i in pareto_set else "black"
        if i in group1_set:
            plt.scatter(reward[0], reward[1], color=color, marker="+", label=_legend_label("Group 1"))
        elif i in group2_set:
            plt.scatter(reward[0], reward[1], edgecolor=color, facecolors="none", marker="o", label=_legend_label("Group 2"))
        else:
            plt.scatter(reward[0], reward[1], color=color, alpha=0.3, marker=".", label=_legend_label("Other"))

    # Add labels and legend
    plt.xlabel("Objective 1")
    plt.ylabel("Objective 2")
    plt.title("2D Rewards Visualization")
    if labelled:  # avoid matplotlib's "no artists with labels" warning on empty input
        plt.legend()
    plt.grid(True)
    plt.savefig((fig_path or "") + '2D_rewards_plot.png')  # Save the plot as an image
    return fig

def main(fig_path, prompts_file, results_file, algorithms, budgets, reference_point, config):
    """Score each bandit run by the hypervolume of its selected arms and plot it vs budget.

    Args:
        fig_path: directory/prefix prepended to the saved figure's file name.
        prompts_file: CSV with one row per prompt (arm) and a
            ``mean_scores_<method>`` column per evaluation method.
        results_file: CSV of bandit runs with ``bandit_algorithm``, ``budget``,
            ``random_seed`` and ``bandit_best_arm`` columns.
        algorithms: algorithm names to keep from ``results_file``.
        budgets: budget values to keep from ``results_file``.
        reference_point: hypervolume reference point, one value per objective.
        config: parsed config; ``config["reward_method_eval"]`` lists the
            evaluation methods (objectives) to load.
    """
    # Load the true prompt evaluation results: objective name -> per-arm scores.
    prompts_df = pd.read_csv(prompts_file)
    print(f"Loaded {len(prompts_df)} prompts from {prompts_file}")
    eval_scores = {method: prompts_df[f"mean_scores_{method}"].to_numpy() for method in config["reward_method_eval"]}

    # Load the algorithm results and keep only the requested algorithms/budgets.
    results_df = pd.read_csv(results_file)
    filtered_rows = results_df[
        results_df["bandit_algorithm"].isin(algorithms) & results_df["budget"].isin(budgets)
    ]

    hypervolume_data = {}  # alg -> budget -> [hypervolume of each run]
    seeds_data = {}  # alg -> budget -> [random seed of each run]
    for _, row in filtered_rows.iterrows():
        alg = row["bandit_algorithm"]
        budget = row["budget"]
        seed = row["random_seed"]
        # The cell holds a Python literal sequence of arm indices; literal_eval
        # only accepts literals, so it cannot execute code from the file.
        best_arms = list(ast.literal_eval(row["bandit_best_arm"]))
        hypervolume = compute_2D_hypervolume(eval_scores, best_arms, reference_point)
        print(f"alg: {alg}, budget: {budget}, seed: {seed}, hypervolume: {hypervolume}")

        hypervolume_data.setdefault(alg, {}).setdefault(budget, []).append(hypervolume)
        seeds_data.setdefault(alg, {}).setdefault(budget, []).append(seed)

    pareto = pareto_front(eval_scores, reference_point)
    pareto_hypervolume = compute_2D_hypervolume(eval_scores, pareto, reference_point)
    print("Pareto front:", pareto)
    print("Hypervolume of Pareto front:", pareto_hypervolume)

    # Print the random seeds for each algorithm-budget pair. Distinct loop names
    # keep the ``budgets`` argument from being rebound.
    print("\nRandom seeds for each algorithm-budget pair:")
    for alg, seeds_by_budget in seeds_data.items():
        for budget, seeds in seeds_by_budget.items():
            print(f"Algorithm: {alg}, Budget: {budget}, Seeds: {sorted(seeds)}")

    # Per-budget averaging and error bars are computed inside the plotting helper.
    plot_hypervolume_vs_budget(hypervolume_data, pareto_hypervolume, fig_path)

def plot_hypervolume_vs_budget(hypervolume_data, pareto_hypervolume, fig_path, error_bars=True):
    """Plot the mean hypervolume per budget for each algorithm and save it as a PNG.

    Args:
        hypervolume_data: mapping algorithm -> {budget: list of hypervolumes},
            one entry per run.
        pareto_hypervolume: hypervolume of the true Pareto front, drawn as a
            dashed horizontal reference line.
        fig_path: directory/prefix prepended to the output file name.
        error_bars: if True, draw the standard error of the mean as error bars;
            otherwise draw plain lines with markers.
    """
    plt.figure()
    for alg, data in hypervolume_data.items():
        print(alg, data)
        budgets = sorted(data.keys())
        avg_hypervolumes = [np.mean(data[budget]) for budget in budgets]
        # Standard error of the mean: sample std (ddof=1) / sqrt(n). It is
        # undefined for a single run, so use 0 (no error bar) instead of the
        # nan and RuntimeWarning that 0/0 would produce.
        sem_hypervolumes = [
            np.std(data[budget], ddof=1) / np.sqrt(len(data[budget])) if len(data[budget]) > 1 else 0.0
            for budget in budgets
        ]
        for budget, avg, sem in zip(budgets, avg_hypervolumes, sem_hypervolumes):
            # NOTE: the value printed as "Std" is the standard error of the mean.
            print(f"Algorithm: {alg}, Budget: {budget}, Avg Hypervolume: {avg}, Std: {sem}")

        # Plot with error bars
        if error_bars:
            plt.errorbar(budgets, avg_hypervolumes, yerr=sem_hypervolumes, label=alg, capsize=5, fmt='-o')
        else:
            plt.plot(budgets, avg_hypervolumes, label=alg, marker='o')

    # Add horizontal line for Pareto front hypervolume
    plt.axhline(y=pareto_hypervolume, color='black', linestyle='--', label='Pareto Front')

    # Add labels and legend
    plt.xlabel("Budget")
    plt.ylabel("Average Hypervolume")
    plt.title("Average Hypervolume vs Budget (with Error Bars)")
    plt.legend()
    plt.grid(True)
    plt.savefig(fig_path + 'average_hypervolume_vs_budget_with_error_bars.png')  # Save the plot as an image

if __name__ == "__main__":
    with open("./config.json", "r") as config_file:
        config = json.load(config_file)

    # Experiment selection.
    task = "xsum"  # "xsum" / "cnn_dailymail" / "samsum"
    model = "Gemma"  # "Gemma" / "Llama3" / 'ChatGPT'
    # Directory suffix: ChatGPT uses "_ChatGPT"; every other model uses "_WhiteBox_<name>".
    model = "_ChatGPT" if model == "ChatGPT" else "_WhiteBox_" + model

    # All inputs and outputs live under one per-task, per-model results directory.
    run_dir = "./results/MO/" + task + model
    prompts_file = run_dir + "/prompts_42.csv"
    results_file = run_dir + "/results_pareto/50/bandit_results.csv"
    fig_path = run_dir + "/"

    # Algorithms and budgets to keep from the results file.
    algorithms = ["Pareto_Uni", "EGE", "MLP_EGE"]
    budgets = [3, 5, 8, 10]

    # Hypervolume reference point, one coordinate per objective.
    reference_point = [0.0, 0.0]

    main(fig_path, prompts_file, results_file, algorithms, budgets, reference_point, config)
