import pandas as pd
import numpy as np
import json
import ast
import matplotlib.pyplot as plt  # Import for plotting

def compute_constrained_reward(scores, arm, constraints):
    """Return the reward of ``arm`` if it satisfies every constraint, else 0.0.

    Args:
        scores: Mapping from method name to a per-arm score array. The first
            method (in dict order) is the reward being maximised; every
            following method is a constrained metric.
        arm: Index (int) or index list selecting the arm in each score array.
            NOTE(review): the selected values are flattened, so passing more
            than one arm mixes arms across methods — callers appear to pass a
            single arm (e.g. ``[k]``); confirm.
        constraints: Thresholds for the constrained metrics. They are compared
            element-wise (numpy broadcasting), so a single value applies to
            all constrained metrics.

    Returns:
        The first method's score for ``arm`` when every constrained metric is
        strictly greater than its threshold, otherwise ``0.0``.
    """
    arm_scores = np.array([scores[method][arm] for method in scores]).reshape(-1)
    # Feasibility uses a strict inequality, matching find_constrained_optimal.
    if (arm_scores[1:] > np.asarray(constraints)).all():
        return arm_scores[0]
    return 0.0

def find_constrained_optimal(scores, constraints):
    """Return the index of the highest-reward arm that satisfies all constraints.

    Args:
        scores: Mapping from method name to a per-arm score array. The first
            method (in dict order) is the reward; the rest are constrained
            metrics.
        constraints: Thresholds for the constrained metrics, compared
            element-wise with numpy broadcasting.

    Returns:
        The arm index with the largest reward among arms whose constrained
        metrics are all strictly above their thresholds. Ties go to the lowest
        index. Returns ``-1`` if no arm is feasible.
    """
    # Rows are arms, columns are methods (reward first).
    arm_table = np.array([scores[method] for method in scores]).T
    thresholds = np.asarray(constraints)

    best_arm = -1
    # -inf rather than a magic -1 so feasible arms with negative rewards still count.
    best_reward = -np.inf
    for arm_index, row in enumerate(arm_table):
        if (row[1:] > thresholds).all() and row[0] > best_reward:
            best_reward = row[0]
            best_arm = arm_index
    return best_arm
    

def plot_2D_rewards(scores, group1_indices, group2_indices, pareto_indices, reference_point, fig_path=None):
    """Scatter-plot arms in the plane of the first two objectives and save the figure.

    Args:
        scores: Mapping from method name to a per-arm score array. Only the
            first two methods (in dict order) are plotted.
        group1_indices: Arm indices drawn with a "+" marker (label "Group 1").
        group2_indices: Arm indices drawn as hollow circles (label "Group 2").
            An index present in both groups is drawn as Group 1.
        pareto_indices: Arm indices coloured red; all others are black.
        reference_point: Subtracted from every arm's score vector before
            plotting (numpy broadcasting).
        fig_path: Prefix prepended to ``"2D_rewards_plot.png"``. It is
            expected to end with a path separator. ``None`` saves into the
            current working directory.
    """
    avg_rewards = np.array([scores[method] for method in scores]).T - reference_point

    # Sets give O(1) membership tests inside the loop.
    group1 = set(group1_indices)
    group2 = set(group2_indices)
    pareto = set(pareto_indices)
    labelled = set()  # legend names already used, so each appears exactly once

    def _legend_label(name):
        # Return the label the first time a group is drawn, then None.
        if name in labelled:
            return None
        labelled.add(name)
        return name

    # Start a fresh figure so earlier plots are not drawn over.
    plt.figure()
    for i, reward in enumerate(avg_rewards):
        color = "red" if i in pareto else "black"
        if i in group1:
            plt.scatter(reward[0], reward[1], color=color, marker="+", label=_legend_label("Group 1"))
        elif i in group2:
            plt.scatter(reward[0], reward[1], edgecolor=color, facecolors="none", marker="o", label=_legend_label("Group 2"))
        else:
            plt.scatter(reward[0], reward[1], color=color, alpha=0.3, marker=".", label=_legend_label("Other"))

    # Add labels and legend
    plt.xlabel("Objective 1")
    plt.ylabel("Objective 2")
    plt.title("2D Rewards Visualization")
    plt.legend()
    plt.grid(True)
    prefix = "" if fig_path is None else fig_path
    plt.savefig(prefix + '2D_rewards_plot.png')  # Save the plot as an image
    plt.close()  # release the figure; it has already been written to disk

def main(fig_path, prompts_file, results_file, algorithms, budgets, constraints, config):
    """Score bandit runs under constraints and plot constrained reward vs budget.

    Args:
        fig_path: Prefix (directory ending in a separator) for saved plots.
        prompts_file: CSV with one ``mean_scores_<method>`` column per method
            listed in ``config["reward_method_eval"]``.
        results_file: CSV of bandit runs with ``bandit_algorithm``, ``budget``,
            ``random_seed`` and ``bandit_best_arm`` columns.
        algorithms: Algorithm names to keep.
        budgets: Budget values to keep.
        constraints: Thresholds for the constrained metrics, i.e. every method
            after the first in ``config["reward_method_eval"]``.
        config: Parsed configuration. Only ``"reward_method_eval"`` is read.
    """
    # Load the true prompt evaluation scores. Dict order follows the config
    # list, so the first method is the reward and the rest are constrained.
    prompts_df = pd.read_csv(prompts_file)
    print(f"Loaded {len(prompts_df)} prompts from {prompts_file}")
    eval_scores = {
        method: prompts_df[f"mean_scores_{method}"].to_numpy()
        for method in config["reward_method_eval"]
    }

    # Load the algorithm results and keep only the requested algorithm/budget rows.
    results_df = pd.read_csv(results_file)
    filtered_rows = results_df[
        results_df["bandit_algorithm"].isin(algorithms) & results_df["budget"].isin(budgets)
    ]

    constrained_reward_data = {}  # alg -> budget -> list of per-run constrained rewards
    seeds_data = {}  # alg -> budget -> list of random seeds, in the same order
    for _, row in filtered_rows.iterrows():
        alg = row["bandit_algorithm"]
        budget = row["budget"]
        seed = row["random_seed"]
        # The best arm is stored as a Python literal (e.g. a list/tuple of indices).
        best_arm = list(ast.literal_eval(row['bandit_best_arm']))
        constrained_reward = compute_constrained_reward(eval_scores, best_arm, constraints)
        print(f"alg: {alg}, budget: {budget}, seed: {seed}, reward: {constrained_reward}")

        constrained_reward_data.setdefault(alg, {}).setdefault(budget, []).append(constrained_reward)
        seeds_data.setdefault(alg, {}).setdefault(budget, []).append(seed)

    # Best achievable constrained reward, used as the reference line in the plot.
    c_optimal_arm = find_constrained_optimal(eval_scores, constraints)
    if c_optimal_arm == -1:
        # -1 is the "no feasible arm" sentinel. Indexing with it would silently
        # select the last arm, so report 0.0 (the infeasible reward) instead.
        print("No arm satisfies the constraints.")
        optimal_constrained_reward = 0.0
    else:
        for method, method_scores in eval_scores.items():
            print(f"{method}: {method_scores[c_optimal_arm]}")
        optimal_constrained_reward = compute_constrained_reward(eval_scores, c_optimal_arm, constraints)
    print("Constrained optimal arm: ", c_optimal_arm)
    print("Optimal constrained reward:", optimal_constrained_reward)

    # Print the random seeds for each algorithm-budget pair
    print("\nRandom seeds for each algorithm-budget pair:")
    for alg, seeds_by_budget in seeds_data.items():
        for budget, seeds in seeds_by_budget.items():
            print(f"Algorithm: {alg}, Budget: {budget}, Seeds: {seeds}")

    # Plot constrained reward vs budget (per-budget averaging happens in the plotter).
    plot_reward_vs_budget(constrained_reward_data, optimal_constrained_reward, fig_path)

def plot_reward_vs_budget(XYdata, optimal_constrained_reward, fig_path, error_bars=False):
    """Plot the mean reward per budget for each algorithm and save the figure.

    Args:
        XYdata: Mapping ``alg -> {budget: list of rewards}``. Each algorithm
            becomes one line, with budgets sorted along the x-axis.
        optimal_constrained_reward: y-value of the dashed horizontal reference line.
        fig_path: Prefix (expected to end with a path separator) for the
            saved PNG.
        error_bars: If True, draw +/- one population standard deviation
            (``np.std``, ddof=0) around each mean.
    """
    plt.figure()
    for alg, rewards_by_budget in XYdata.items():
        budget_axis = sorted(rewards_by_budget)
        avg_reward = [np.mean(rewards_by_budget[b]) for b in budget_axis]
        if error_bars:
            std_reward = [np.std(rewards_by_budget[b]) for b in budget_axis]
            plt.errorbar(budget_axis, avg_reward, yerr=std_reward, label=alg, capsize=5, fmt='-o')
        else:
            plt.plot(budget_axis, avg_reward, label=alg, marker='o')

    # Horizontal reference line at the optimal constrained reward supplied by the caller.
    plt.axhline(y=optimal_constrained_reward, color='black', linestyle='--', label='optimal constrained reward')

    # Add labels and legend
    plt.xlabel("Budget")
    plt.ylabel("Average Reward")
    # Only claim error bars in the title when they are actually drawn.
    title = "Average Reward vs Budget"
    if error_bars:
        title += " (with Error Bars)"
    plt.title(title)
    plt.legend()
    plt.grid(True)
    # The filename is kept unchanged regardless of error_bars so existing consumers still find it.
    plt.savefig(fig_path + 'average_constrained_reward_vs_budget_with_error_bars.png')
    plt.close()  # avoid accumulating open figures across calls

if __name__ == "__main__":
    config_path = "./config.json"
    # JSON is UTF-8 by specification; don't rely on the platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    task = "cnn_dailymail"  # alternative: "xsum"
    run_dir = f"./results/MO/{task}_WhiteBox_Llama3/"
    prompts_file = run_dir + "prompts_43.csv"
    results_file = run_dir + "bandit_results.csv"
    fig_path = run_dir  # plots are written into the run directory (trailing "/" required)

    # Algorithms and budgets to include in the comparison.
    algorithms = ["Constrained_Uni", "CSR"]
    budgets = [5, 10, 20]

    # Thresholds for the constrained metrics; numpy broadcasts a single value to all of them.
    constraints = [0.7]
    main(fig_path, prompts_file, results_file, algorithms, budgets, constraints, config)